# NOTE: extraction residue from the hosting page, kept as comments so the module parses.
# Spaces status: Runtime error | File size: 8,031 Bytes | Revision: 26e0cd3
# (The scraped line-number gutter, 1-190, carried no content and is omitted.)
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
from .configuration_vibevoice_streaming import VibeVoiceStreamingConfig
# Module-level logger named after this module.
logger = logging.get_logger(__name__)

# Ensure `modeling_utils.ALL_PARALLEL_STYLES` is populated: when the attribute
# is missing or set to None, install the default list of style names.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
class BinaryClassifier(nn.Module):
    """Two-layer MLP that maps a feature vector to a single raw score.

    The hidden layer keeps the input width (``hidden_size``) and is followed
    by a ReLU; the output layer projects to one unit. No sigmoid is applied
    inside the module, so the output is an unnormalised logit.
    """

    def __init__(self, hidden_size: int) -> None:
        """Create the two linear layers.

        Args:
            hidden_size: Size of the last dimension of the input; also used
                as the width of the hidden layer.
        """
        # Zero-argument ``super()`` for consistency with the other modules
        # defined in this file.
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one raw logit per input vector.

        Args:
            x: Tensor whose last dimension equals ``hidden_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
class SpeechConnector(nn.Module):
    """Project speech features from ``input_dim`` to ``output_dim``.

    Pipeline: linear (``input_dim`` -> ``output_dim``), RMS normalisation
    (``eps=1e-6``), then a second linear layer that keeps the ``output_dim``
    width.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Sub-modules are created in application order: fc1 -> norm -> fc2.
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        """Run ``features`` through fc1, the norm and fc2.

        Extra keyword arguments are accepted but not used.
        """
        projected = self.fc1(features)
        normalized = self.norm(projected)
        return self.fc2(normalized)
# @auto_docstring
class VibeVoiceStreamingPreTrainedModel(PreTrainedModel):
    """Base class connecting VibeVoice streaming models to ``PreTrainedModel``.

    It declares the configuration class, the base-model prefix and a set of
    capability flags as class attributes, and provides the per-module
    weight-initialisation hook ``_init_weights``.
    """

    config_class = VibeVoiceStreamingConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialise the parameters of a single submodule in place.

        * ``VibeVoiceDiffusionHead``: delegates to its own
          ``initialize_weights()``.
        * ``nn.Linear``: weight drawn from N(0, std); bias zeroed if present.
        * ``nn.LayerNorm``: weight set to 1 and bias set to 0, each only if
          the parameter exists.

        Any other module type is left untouched.

        Args:
            module: The submodule to initialise.
        """
        if isinstance(module, VibeVoiceDiffusionHead):
            module.initialize_weights()
            return

        # Use the language model's initializer_range if available
        if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02  # Default value

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # A LayerNorm built with elementwise_affine=False has no weight or
            # bias, and bias=False leaves bias as None; guard both so such
            # modules do not raise AttributeError here.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
# @auto_docstring
class VibeVoiceStreamingModel(VibeVoiceStreamingPreTrainedModel):
    """Container for the sub-networks of the streaming VibeVoice model.

    The constructor builds, in this order:

    * ``language_model`` - the lower Transformer layers (its final norm is
      replaced by ``nn.Identity``),
    * ``tts_language_model`` - the upper Transformer layers,
    * ``tts_input_types`` - a 2-entry embedding table,
    * ``acoustic_tokenizer`` and ``acoustic_connector``,
    * the ``speech_scaling_factor`` / ``speech_bias_factor`` buffers,
    * ``prediction_head`` and ``noise_scheduler``.

    ``forward`` is deliberately disabled; callers use the sub-modules directly.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        Args:
            config: Model configuration. This constructor reads
                ``torch_dtype`` (optional), ``decoder_config``,
                ``tts_backbone_num_hidden_layers``,
                ``acoustic_tokenizer_config``, ``acoustic_vae_dim`` and
                ``diffusion_head_config`` from it.
        """
        super().__init__(config)

        # Resolve the dtype used for the speech components: accept either a
        # torch.dtype or its attribute name on the torch module (e.g. "float16"),
        # falling back to float32 when unset.
        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Initialize Qwen2 model for language modeling.
        # The lower Transformer layers are only used for encoding text, while the upper
        # Transformer layers are used for encoding text and generating speech.
        # To keep the code clean, we construct two language models.
        # The final norm layer of the first language_model is set to identity and will
        # not be used in inference.
        lm_config = copy.deepcopy(config.decoder_config)
        lm_backbone_num_hidden_layers = getattr(lm_config, 'num_hidden_layers', 24) - config.tts_backbone_num_hidden_layers
        lm_config.num_hidden_layers = lm_backbone_num_hidden_layers
        self.language_model = AutoModel.from_config(lm_config)
        self.language_model.norm = nn.Identity()

        # We only need the Transformer layers here. Note that embed_tokens in
        # tts_language_model is unused.
        tts_lm_config = copy.deepcopy(lm_config)
        tts_lm_config.num_hidden_layers = config.tts_backbone_num_hidden_layers
        self.tts_language_model = AutoModel.from_config(tts_lm_config)

        # Marks the text that needs to be spoken by the TTS model.
        self.tts_input_types = nn.Embedding(num_embeddings=2, embedding_dim=config.decoder_config.hidden_size)

        # Initialize speech components
        self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
        self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)

        # Register scaling factors as buffers, initialised to NaN.
        # NOTE(review): an earlier comment described these as "1D tensors for
        # FSDP compatibility", but torch.tensor(float('nan')) creates 0-dim
        # scalars. The shape is intentionally left unchanged here so that
        # previously saved state dicts keep loading - confirm whether 1-D
        # buffers are actually required.
        self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
        self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))

        # Initialize prediction head for speech generation
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)

        # Initialize noise scheduler
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type
        )

    def get_input_embeddings(self):
        """Return the token-embedding module of ``language_model``.

        Uses ``language_model.embed_tokens`` when that attribute exists.
        Otherwise scans ``language_model.fullmap`` for the entry whose
        ``orig_name`` is ``'embed_tokens.weight'`` and returns the attribute
        stored under that (renamed) key.

        Raises:
            AssertionError: If neither lookup finds the embeddings.
        """
        if hasattr(self.language_model, 'embed_tokens'):
            # If the language model has an embed_tokens attribute, return it
            return self.language_model.embed_tokens

        for name, attr in self.language_model.fullmap.items():  # parallel by nnscaler, the name is changed
            if attr.orig_name == 'embed_tokens.weight':
                return getattr(self.language_model, name)

        # Raised explicitly (instead of `assert False`) so the failure is not
        # stripped under `python -O`, which would make this return None.
        raise AssertionError('should not arrive here')

    def set_input_embeddings(self, value):
        """Replace ``language_model.embed_tokens`` with ``value``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None):
        """Set the acoustic tokenizer and, if one is given, switch it to eval mode.

        Args:
            acoustic_tokenizer: Replacement tokenizer module, or ``None`` to
                clear the current one.
        """
        self.acoustic_tokenizer = acoustic_tokenizer

        # Reset the encoder to evaluation mode
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.eval()

    def forward(self, *args, **kwargs):
        """
        Intentionally not implemented.

        This streaming model is split into two explicit submodules:
        - `language_model` for plain text processing (lower layers).
        - `tts_language_model` for TTS-related upper layers.

        We deliberately avoid a unified `forward` to prevent accidental calls
        that mix responsibilities.

        To use the model:
        - Call `self.language_model(...)` for text embeddings / hidden states.
        - Call `self.tts_language_model(...)` for the TTS portion.
        - Use the dedicated inference class for combined generation logic.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError(
            "VibeVoiceStreamingModel.forward is intentionally disabled. "
            "Use `model.language_model(...)` or `model.tts_language_model(...)` instead."
        )
# Register the config/model pair with AutoModel so the auto class can map
# VibeVoiceStreamingConfig to VibeVoiceStreamingModel.
AutoModel.register(VibeVoiceStreamingConfig, VibeVoiceStreamingModel)

# Public names exported by this module. (The stray trailing `|` that followed
# the closing bracket was extraction residue and a syntax error; removed.)
__all__ = [
    "VibeVoiceStreamingPreTrainedModel",
    "VibeVoiceStreamingModel",
]