import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from modules.naturalpseech2.transformers import (
    TransformerEncoder,
    DurationPredictor,
    PitchPredictor,
    LengthRegulator,
)
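

# PriorEncoder: encodes a phoneme sequence, predicts per-phoneme durations and a
# frame-level log-F0 contour, expands the phoneme encoding to frame resolution
# with the length regulator, and adds a quantized-pitch embedding to produce the
# frame-level prior ("prior_out").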
class PriorEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.enc_emb_tokens = nn.Embedding(
            cfg.vocab_size, cfg.encoder.encoder_hidden, padding_idx=0
        )
        # Initialize the token embedding near zero.
        self.enc_emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)

        self.encoder = TransformerEncoder(
            enc_emb_tokens=self.enc_emb_tokens, cfg=cfg.encoder
        )
        self.duration_predictor = DurationPredictor(cfg.duration_predictor)
        self.pitch_predictor = PitchPredictor(cfg.pitch_predictor)
        self.length_regulator = LengthRegulator()

        self.pitch_min = cfg.pitch_min
        self.pitch_max = cfg.pitch_max
        self.pitch_bins_num = cfg.pitch_bins_num
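
        # pitch_bins_num - 1 log-spaced edges define pitch_bins_num buckets for
        # torch.bucketize, matching the size of self.pitch_embedding below.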
        pitch_bins = torch.exp(
            torch.linspace(
                np.log(self.pitch_min), np.log(self.pitch_max), self.pitch_bins_num - 1
            )
        )
        self.register_buffer("pitch_bins", pitch_bins)

        self.pitch_embedding = nn.Embedding(
            self.pitch_bins_num, cfg.encoder.encoder_hidden
        )

    def forward(
        self,
        phone_id,
        duration=None,
        pitch=None,
        phone_mask=None,
        mask=None,
        ref_emb=None,
        ref_mask=None,
        is_inference=False,
    ):
        """
        input:
            phone_id: (B, N) phoneme token ids
            duration: (B, N) per-phoneme durations in frames; used when not in inference
            pitch: (B, T) frame-level pitch contour; used when not in inference
            phone_mask: (B, N); 0 at padded positions
            mask: (B, T); 0 at padded frames
            ref_emb: (B, d, T') reference embedding
            ref_mask: (B, T'); 0 at padded positions
        output (dict):
            prior_out: (B, T, d) frame-level prior embedding
            dur_pred / dur_pred_log / dur_pred_round: (B, N)
            pitch_pred_log: (B, T)
            pitch_token: (B, T)
            mel_len: frame lengths per utterance
        """
        x = self.encoder(phone_id, phone_mask, ref_emb.transpose(1, 2))

        # dur_pred_out keys: dur_pred_log, dur_pred, dur_pred_round
        dur_pred_out = self.duration_predictor(x, phone_mask, ref_emb, ref_mask)

        if is_inference or duration is None:
            # Expand phoneme features with the predicted (rounded) durations.
            x, mel_len = self.length_regulator(
                x,
                dur_pred_out["dur_pred_round"],
                max_len=torch.max(torch.sum(dur_pred_out["dur_pred_round"], dim=1)),
            )
        else:
            # Teacher forcing: expand with ground-truth durations to the target length.
            x, mel_len = self.length_regulator(x, duration, max_len=pitch.shape[1])

        pitch_pred_log = self.pitch_predictor(x, mask, ref_emb, ref_mask)
        if is_inference or pitch is None:
            # The predictor outputs log-F0: exponentiate back to Hz, then quantize.
            pitch_tokens = torch.bucketize(pitch_pred_log.exp(), self.pitch_bins)
        else:
            pitch_tokens = torch.bucketize(pitch, self.pitch_bins)
        pitch_embedding = self.pitch_embedding(pitch_tokens)
        x = x + pitch_embedding

        if (not is_inference) and (mask is not None):
            # Zero out padded frames (mask is 0 at padding).
            x = x * mask.to(x.dtype)[:, :, None]

        prior_out = {
            "dur_pred_round": dur_pred_out["dur_pred_round"],
            "dur_pred_log": dur_pred_out["dur_pred_log"],
            "dur_pred": dur_pred_out["dur_pred"],
            "pitch_pred_log": pitch_pred_log,
            "pitch_token": pitch_tokens,
            "mel_len": mel_len,
            "prior_out": x,
        }
        return prior_out
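

# A minimal standalone sketch (not part of the original module) showing the pitch
# quantization used above: log-spaced bin edges plus torch.bucketize. The
# pitch_min / pitch_max / pitch_bins_num values here are hypothetical, not cfg values.
if __name__ == "__main__":
    pitch_min, pitch_max, pitch_bins_num = 50.0, 1100.0, 256
    bins = torch.exp(
        torch.linspace(np.log(pitch_min), np.log(pitch_max), pitch_bins_num - 1)
    )
    f0 = torch.tensor([[0.0, 80.0, 220.0, 440.0, 2000.0]])  # (B, T) contour in Hz
    tokens = torch.bucketize(f0, bins)  # indices in [0, pitch_bins_num - 1]
    emb = nn.Embedding(pitch_bins_num, 8)(tokens)  # (B, T, 8), like pitch_embedding
    print(tokens, emb.shape)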