Spaces:
Sleeping
Changes to handle device properly
Browse files
- app.py +1 -1
- src/chatterbox/mtl_tts.py +4 -4
app.py
CHANGED
|
@@ -178,7 +178,7 @@ with gr.Blocks() as demo:
|
|
| 178 |
gr.Markdown(get_supported_languages_display())
|
| 179 |
with gr.Row():
|
| 180 |
with gr.Column():
|
| 181 |
-
initial_lang = "
|
| 182 |
text = gr.Textbox(
|
| 183 |
value=default_text_for_ui(initial_lang),
|
| 184 |
label="Text to synthesize (max chars 300)",
|
|
|
|
| 178 |
gr.Markdown(get_supported_languages_display())
|
| 179 |
with gr.Row():
|
| 180 |
with gr.Column():
|
| 181 |
+
initial_lang = "da"
|
| 182 |
text = gr.Textbox(
|
| 183 |
value=default_text_for_ui(initial_lang),
|
| 184 |
label="Text to synthesize (max chars 300)",
|
src/chatterbox/mtl_tts.py
CHANGED
|
@@ -142,12 +142,12 @@ class ChatterboxMultilingualTTS:
|
|
| 142 |
|
| 143 |
ve = VoiceEncoder()
|
| 144 |
ve.load_state_dict(
|
| 145 |
-
torch.load(ckpt_dir / "ve.pt", weights_only=True)
|
| 146 |
)
|
| 147 |
ve.to(device).eval()
|
| 148 |
|
| 149 |
t3 = T3(T3Config.multilingual())
|
| 150 |
-
t3_state = load_safetensors(ckpt_dir / "t3_23lang.safetensors")
|
| 151 |
if "model" in t3_state.keys():
|
| 152 |
t3_state = t3_state["model"][0]
|
| 153 |
t3.load_state_dict(t3_state)
|
|
@@ -155,7 +155,7 @@ class ChatterboxMultilingualTTS:
|
|
| 155 |
|
| 156 |
s3gen = S3Gen()
|
| 157 |
s3gen.load_state_dict(
|
| 158 |
-
torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
|
| 159 |
)
|
| 160 |
s3gen.to(device).eval()
|
| 161 |
|
|
@@ -165,7 +165,7 @@ class ChatterboxMultilingualTTS:
|
|
| 165 |
|
| 166 |
conds = None
|
| 167 |
if (builtin_voice := ckpt_dir / "conds.pt").exists():
|
| 168 |
-
conds = Conditionals.load(builtin_voice).to(device)
|
| 169 |
|
| 170 |
return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
|
| 171 |
|
|
|
|
| 142 |
|
| 143 |
ve = VoiceEncoder()
|
| 144 |
ve.load_state_dict(
|
| 145 |
+
torch.load(ckpt_dir / "ve.pt", weights_only=True, map_location=device)
|
| 146 |
)
|
| 147 |
ve.to(device).eval()
|
| 148 |
|
| 149 |
t3 = T3(T3Config.multilingual())
|
| 150 |
+
t3_state = load_safetensors(ckpt_dir / "t3_23lang.safetensors", device=str(device))
|
| 151 |
if "model" in t3_state.keys():
|
| 152 |
t3_state = t3_state["model"][0]
|
| 153 |
t3.load_state_dict(t3_state)
|
|
|
|
| 155 |
|
| 156 |
s3gen = S3Gen()
|
| 157 |
s3gen.load_state_dict(
|
| 158 |
+
torch.load(ckpt_dir / "s3gen.pt", weights_only=True, map_location=device)
|
| 159 |
)
|
| 160 |
s3gen.to(device).eval()
|
| 161 |
|
|
|
|
| 165 |
|
| 166 |
conds = None
|
| 167 |
if (builtin_voice := ckpt_dir / "conds.pt").exists():
|
| 168 |
+
conds = Conditionals.load(builtin_voice, map_location=device).to(device)
|
| 169 |
|
| 170 |
return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
|
| 171 |
|