import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from models.tts.naturalspeech2.diffusion import Diffusion
from models.tts.naturalspeech2.diffusion_flow import DiffusionFlow
from models.tts.naturalspeech2.wavenet import WaveNet
from models.tts.naturalspeech2.prior_encoder import PriorEncoder
from modules.naturalpseech2.transformers import TransformerEncoder
from encodec import EncodecModel
from einops import rearrange, repeat
import os
import json


class NaturalSpeech2(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.latent_dim = cfg.latent_dim
        self.query_emb_num = cfg.query_emb.query_token_num

        self.prior_encoder = PriorEncoder(cfg.prior_encoder)
        if cfg.diffusion.diffusion_type == "diffusion":
            self.diffusion = Diffusion(cfg.diffusion)
        elif cfg.diffusion.diffusion_type == "flow":
            self.diffusion = DiffusionFlow(cfg.diffusion)

        # Encodes the reference (prompt) latent; projected first if its
        # dimension differs from the prompt encoder's hidden size.
        self.prompt_encoder = TransformerEncoder(cfg=cfg.prompt_encoder)
        if self.latent_dim != cfg.prompt_encoder.encoder_hidden:
            self.prompt_lin = nn.Linear(
                self.latent_dim, cfg.prompt_encoder.encoder_hidden
            )
            self.prompt_lin.weight.data.normal_(0.0, 0.02)
        else:
            self.prompt_lin = None

        # Learnable query tokens that attend over the encoded reference to
        # produce a fixed-length speaker representation.
        self.query_emb = nn.Embedding(self.query_emb_num, cfg.query_emb.hidden_size)
        self.query_attn = nn.MultiheadAttention(
            cfg.query_emb.hidden_size, cfg.query_emb.head_num, batch_first=True
        )

        # Frozen EnCodec (24 kHz, 12 kbps -> 16 codebooks); only its residual
        # vector quantizer is kept, to map codes to/from continuous latents.
        codec_model = EncodecModel.encodec_model_24khz()
        codec_model.set_target_bandwidth(12.0)
        codec_model.requires_grad_(False)
        self.quantizer = codec_model.quantizer
    def code_to_latent(self, code):
        # code: (B, n_q, T) EnCodec indices -> latent: (B, latent_dim, T)
        latent = self.quantizer.decode(code.transpose(0, 1))
        return latent

    def latent_to_code(self, latent, nq=16):
        # Re-quantize a continuous latent with the frozen residual VQ: for each
        # of the nq codebooks, pick the nearest entry (via negative squared
        # Euclidean distance) and subtract the quantized vector from the residual.
        residual = latent
        all_indices = []
        all_dist = []
        for i in range(nq):
            layer = self.quantizer.vq.layers[i]
            x = rearrange(residual, "b d n -> b n d")
            x = layer.project_in(x)
            shape = x.shape
            x = layer._codebook.preprocess(x)
            embed = layer._codebook.embed.t()
            dist = -(
                x.pow(2).sum(1, keepdim=True)
                - 2 * x @ embed
                + embed.pow(2).sum(0, keepdim=True)
            )
            indices = dist.max(dim=-1).indices
            indices = layer._codebook.postprocess_emb(indices, shape)
            dist = dist.reshape(*shape[:-1], dist.shape[-1])
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
            all_dist.append(dist)

        out_indices = torch.stack(all_indices)
        out_dist = torch.stack(all_dist)

        return out_indices, out_dist  # (nq, B, T); (nq, B, T, 1024)

    def latent_to_latent(self, latent, nq=16):
        # Quantize-dequantize round trip: project the latent onto the codec's
        # discrete codebooks and decode it back to a continuous latent.
        codes, _ = self.latent_to_code(latent, nq)
        latent = self.quantizer.vq.decode(codes)
        return latent
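    # Illustrative shapes: with codes of shape (B, 16, T) from the 24 kHz
    # EnCodec model at 12 kbps, code_to_latent returns a (B, 128, T) latent;
    # latent_to_code recovers (16, B, T) indices plus (16, B, T, 1024)
    # distance scores over each 1024-entry codebook.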
    def forward(
        self,
        code=None,
        pitch=None,
        duration=None,
        phone_id=None,
        phone_id_frame=None,
        frame_nums=None,
        ref_code=None,
        ref_frame_nums=None,
        phone_mask=None,
        mask=None,
        ref_mask=None,
    ):
        ref_latent = self.code_to_latent(ref_code)  # (B, latent_dim, T')
        latent = self.code_to_latent(code)  # (B, latent_dim, T)

        # Project the reference latent to the prompt encoder's hidden size.
        if self.prompt_lin is not None:
            ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))
        else:
            ref_latent = ref_latent.transpose(1, 2)

        ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
        spk_emb = ref_latent.transpose(1, 2)  # (B, d, T')

        # Cross-attend learnable query tokens over the reference frames to get
        # a fixed-length speaker representation.
        spk_query_emb = self.query_emb(
            torch.arange(self.query_emb_num).to(latent.device)
        ).repeat(
            latent.shape[0], 1, 1
        )  # (B, query_emb_num, d)
        spk_query_emb, _ = self.query_attn(
            spk_query_emb,
            spk_emb.transpose(1, 2),
            spk_emb.transpose(1, 2),
            key_padding_mask=~(ref_mask.bool()),
        )  # (B, query_emb_num, d)

        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=duration,
            pitch=pitch,
            phone_mask=phone_mask,
            mask=mask,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=False,
        )
        prior_condition = prior_out["prior_out"]  # (B, T, d)

        diff_out = self.diffusion(latent, mask, prior_condition, spk_query_emb)

        return diff_out, prior_out
    def inference(
        self, ref_code=None, phone_id=None, ref_mask=None, inference_steps=1000
    ):
        ref_latent = self.code_to_latent(ref_code)  # (B, latent_dim, T')

        if self.prompt_lin is not None:
            ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))
        else:
            ref_latent = ref_latent.transpose(1, 2)

        ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
        spk_emb = ref_latent.transpose(1, 2)  # (B, d, T')

        spk_query_emb = self.query_emb(
            torch.arange(self.query_emb_num).to(ref_latent.device)
        ).repeat(
            ref_latent.shape[0], 1, 1
        )  # (B, query_emb_num, d)
        spk_query_emb, _ = self.query_attn(
            spk_query_emb,
            spk_emb.transpose(1, 2),
            spk_emb.transpose(1, 2),
            key_padding_mask=~(ref_mask.bool()),
        )  # (B, query_emb_num, d)

        # Durations and pitch are predicted inside the prior encoder at
        # inference time (is_inference=True).
        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=None,
            pitch=None,
            phone_mask=None,
            mask=None,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=True,
        )
        prior_condition = prior_out["prior_out"]  # (B, T, d)

        # Start from scaled Gaussian noise (temperature 1.2) and run the
        # reverse diffusion conditioned on the prior and speaker queries.
        z = torch.randn(
            prior_condition.shape[0], self.latent_dim, prior_condition.shape[1]
        ).to(ref_latent.device) / (1.20)
        x0 = self.diffusion.reverse_diffusion(
            z, None, prior_condition, inference_steps, spk_query_emb
        )

        return x0, prior_out
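    # Usage sketch (variable names are illustrative, not from this repo): for a
    # trained model `ns2`, prompt codes `ref_code` of shape (B, n_q, T_ref),
    # phone ids `phone_id` of shape (B, N_phone), and a binary `ref_mask` of
    # shape (B, T_ref),
    #     latent, prior_out = ns2.inference(ref_code, phone_id, ref_mask,
    #                                       inference_steps=200)
    # yields a (B, latent_dim, T) latent, which would typically be decoded to a
    # waveform with the EnCodec decoder (not part of this module).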
    def reverse_diffusion_from_t(
        self,
        code=None,
        pitch=None,
        duration=None,
        phone_id=None,
        ref_code=None,
        phone_mask=None,
        mask=None,
        ref_mask=None,
        n_timesteps=None,
        t=None,
    ):
        # Only for debugging: noise the ground-truth latent up to step t, then
        # run the reverse process from t.
        ref_latent = self.code_to_latent(ref_code)
        latent = self.code_to_latent(code)

        if self.prompt_lin is not None:
            ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))
        else:
            ref_latent = ref_latent.transpose(1, 2)

        ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
        spk_emb = ref_latent.transpose(1, 2)  # (B, d, T')

        spk_query_emb = self.query_emb(
            torch.arange(self.query_emb_num).to(latent.device)
        ).repeat(
            latent.shape[0], 1, 1
        )  # (B, query_emb_num, d)
        spk_query_emb, _ = self.query_attn(
            spk_query_emb,
            spk_emb.transpose(1, 2),
            spk_emb.transpose(1, 2),
            key_padding_mask=~(ref_mask.bool()),
        )  # (B, query_emb_num, d)

        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=duration,
            pitch=pitch,
            phone_mask=phone_mask,
            mask=mask,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=False,
        )
        prior_condition = prior_out["prior_out"]  # (B, T, d)

        diffusion_step = (
            torch.ones(
                latent.shape[0],
                dtype=latent.dtype,
                device=latent.device,
                requires_grad=False,
            )
            * t
        )
        diffusion_step = torch.clamp(diffusion_step, 1e-5, 1.0 - 1e-5)
        xt, _ = self.diffusion.forward_diffusion(
            x0=latent, diffusion_step=diffusion_step
        )
        # print(torch.abs(xt-latent).max(), torch.abs(xt-latent).mean(), torch.abs(xt-latent).std())
        x0 = self.diffusion.reverse_diffusion_from_t(
            xt, mask, prior_condition, n_timesteps, spk_query_emb, t_start=t
        )

        return x0, prior_out, xt
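

# Standalone sketch of the code <-> latent mapping that code_to_latent relies
# on, using only the `encodec` package (pretrained 24 kHz weights are fetched
# on first use). The dummy waveform and printed shapes are illustrative only.
if __name__ == "__main__":
    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(12.0)  # 16 codebooks at 24 kHz
    codec.eval()

    wav = torch.randn(1, 1, 24000)  # (B, channels, samples): 1 s of noise
    with torch.no_grad():
        frames = codec.encode(wav)  # list of (codes, scale) frames
        codes = frames[0][0]  # (B, n_q, T) integer indices
        latent = codec.quantizer.decode(codes.transpose(0, 1))  # (B, 128, T)
    print(codes.shape, latent.shape)  # e.g. (1, 16, 75) and (1, 128, 75)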