Biorrith commited on
Commit
191fc47
·
1 Parent(s): 3e92db8

Changes to handle device properly

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. src/chatterbox/mtl_tts.py +4 -4
app.py CHANGED
@@ -178,7 +178,7 @@ with gr.Blocks() as demo:
178
  gr.Markdown(get_supported_languages_display())
179
  with gr.Row():
180
  with gr.Column():
181
- initial_lang = "fr"
182
  text = gr.Textbox(
183
  value=default_text_for_ui(initial_lang),
184
  label="Text to synthesize (max chars 300)",
 
178
  gr.Markdown(get_supported_languages_display())
179
  with gr.Row():
180
  with gr.Column():
181
+ initial_lang = "da"
182
  text = gr.Textbox(
183
  value=default_text_for_ui(initial_lang),
184
  label="Text to synthesize (max chars 300)",
src/chatterbox/mtl_tts.py CHANGED
@@ -142,12 +142,12 @@ class ChatterboxMultilingualTTS:
142
 
143
  ve = VoiceEncoder()
144
  ve.load_state_dict(
145
- torch.load(ckpt_dir / "ve.pt", weights_only=True)
146
  )
147
  ve.to(device).eval()
148
 
149
  t3 = T3(T3Config.multilingual())
150
- t3_state = load_safetensors(ckpt_dir / "t3_23lang.safetensors")
151
  if "model" in t3_state.keys():
152
  t3_state = t3_state["model"][0]
153
  t3.load_state_dict(t3_state)
@@ -155,7 +155,7 @@ class ChatterboxMultilingualTTS:
155
 
156
  s3gen = S3Gen()
157
  s3gen.load_state_dict(
158
- torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
159
  )
160
  s3gen.to(device).eval()
161
 
@@ -165,7 +165,7 @@ class ChatterboxMultilingualTTS:
165
 
166
  conds = None
167
  if (builtin_voice := ckpt_dir / "conds.pt").exists():
168
- conds = Conditionals.load(builtin_voice).to(device)
169
 
170
  return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
171
 
 
142
 
143
  ve = VoiceEncoder()
144
  ve.load_state_dict(
145
+ torch.load(ckpt_dir / "ve.pt", weights_only=True, map_location=device)
146
  )
147
  ve.to(device).eval()
148
 
149
  t3 = T3(T3Config.multilingual())
150
+ t3_state = load_safetensors(ckpt_dir / "t3_23lang.safetensors", device=str(device))
151
  if "model" in t3_state.keys():
152
  t3_state = t3_state["model"][0]
153
  t3.load_state_dict(t3_state)
 
155
 
156
  s3gen = S3Gen()
157
  s3gen.load_state_dict(
158
+ torch.load(ckpt_dir / "s3gen.pt", weights_only=True, map_location=device)
159
  )
160
  s3gen.to(device).eval()
161
 
 
165
 
166
  conds = None
167
  if (builtin_voice := ckpt_dir / "conds.pt").exists():
168
+ conds = Conditionals.load(builtin_voice, map_location=device).to(device)
169
 
170
  return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
171