Enable `cache_params` to work with `generate()` from `GenerationMixin`
#3 · opened by FremyCompany
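This patch renames the `cache_params` argument and output field of `NemotronHForCausalLM.forward()` to `past_key_values`, the name `GenerationMixin` uses when it threads a cache through decoding, so the `HybridMambaAttentionDynamicCache` can round-trip through `generate()`. A minimal usage sketch, assuming a Nemotron-H checkpoint that ships this modeling file (the checkpoint id and prompt below are illustrative, not part of the patch):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint id; any Nemotron-H causal LM using this modeling file should work the same way.
model_id = "nvidia/Nemotron-H-8B-Base-8K"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of France is", return_tensors="pt")

# With the rename in place, generate() can hand the hybrid Mamba/attention cache
# back into forward() as `past_key_values` on every decoding step.
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```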
modeling_nemotron_h.py (CHANGED: +20 -15)
@@ -31,6 +31,9 @@ from transformers.modeling_attn_mask_utils import (
     AttentionMaskConverter,
 )
 from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_outputs import (
+    MoeCausalLMOutputWithPast,
+)
 from transformers.utils import (
     ModelOutput,
     add_code_sample_docstrings,

@@ -168,12 +171,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
 
     def __init__(self, config, batch_size, dtype=torch.float16, device=None):
         super().__init__()
+        self.device = device
         self.dtype = dtype
         self.hybrid_override_pattern = config.hybrid_override_pattern
         self.has_previous_state = False  # only used by mamba
-        intermediate_size = config.expand * config.hidden_size
-        ssm_state_size = config.ssm_state_size
-        conv_kernel_size = config.conv_kernel
+        self.intermediate_size = config.expand * config.hidden_size
+        self.ssm_state_size = config.ssm_state_size
+        self.conv_kernel_size = config.conv_kernel
+        self.conv_dim = self.intermediate_size + 2 * config.n_groups * config.ssm_state_size
         self.conv_states = []
         self.ssm_states = []
         self.transformer_layers = []

@@ -181,10 +186,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
             if self.hybrid_override_pattern[i] == "M":
                 # Mamba layer
                 self.conv_states += [
-                    torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
+                    torch.zeros(batch_size, self.conv_dim, self.conv_kernel_size, device=device, dtype=dtype)
                 ]
                 self.ssm_states += [
-                    torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
+                    torch.zeros(batch_size, self.intermediate_size, self.ssm_state_size, device=device, dtype=dtype)
                 ]
             else:
                 # Attention or MLP layer

@@ -245,14 +250,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
         self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
     ) -> torch.Tensor:
         if cache_init:
-            self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
+            self.conv_states[layer_idx] = new_conv_state.to(self.device)
         else:
             self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
-            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
+            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device)
         return self.conv_states[layer_idx]
 
     def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
-        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
+        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
         return self.ssm_states[layer_idx]
 
     def reset(self):

@@ -413,7 +418,7 @@ class NemotronHMamba2Mixer(nn.Module):
                 dt_softplus=True,
             )
             hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
-            breakpoint()
+            # TODO: why was there a breakpoint() call here?
             hidden_states = self.norm(hidden_states, gate)
 
             # 4. Final linear projection

@@ -560,7 +565,7 @@ class NemotronHMamba2Mixer(nn.Module):
         A = -torch.exp(self.A_log.float())  # [num_heads]
         if cache_params is not None and cache_position is not None and cache_position[0] > 0:
             # We need to guarantee that anything regarding the cache is on the same device
-            cache_device = cache_params.ssm_states.device
+            cache_device = cache_params.ssm_states[0].device if len(cache_params.ssm_states) > 0 else cache_params.device
 
             # Note: there is no need to pad parameter matrices here, as there is just one new token
             # for batched generation

@@ -1185,7 +1190,7 @@ class NemotronHOutput(ModelOutput):
 
 @dataclass
 # Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH
-class NemotronHCausalLMOutput(ModelOutput):
+class NemotronHCausalLMOutput(MoeCausalLMOutputWithPast):
     """
     Base class for causal language model (or autoregressive) outputs.
 

@@ -1208,7 +1213,7 @@ class NemotronHCausalLMOutput(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    cache_params: Optional[HybridMambaAttentionDynamicCache] = None
+    past_key_values: Optional[HybridMambaAttentionDynamicCache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 

@@ -1568,7 +1573,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,

@@ -1593,7 +1598,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
 
         nemotron_h_outputs = self.backbone(
             input_ids,
-            cache_params=cache_params,
+            cache_params=past_key_values,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,

@@ -1626,7 +1631,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
         return NemotronHCausalLMOutput(
             loss=loss,
             logits=logits,
-            cache_params=nemotron_h_outputs.cache_params,
+            past_key_values=nemotron_h_outputs.cache_params,
             hidden_states=nemotron_h_outputs.hidden_states,
             attentions=nemotron_h_outputs.attentions,
         )
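
A note on the device handling in `update_conv_state` and `update_ssm_state`: `conv_states` and `ssm_states` are plain Python lists holding one tensor per Mamba layer, so the previous `self.conv_states.device` / `self.ssm_states.device` lookups could not work (a list has no `.device`); the device has to be read from an individual layer's tensor, or from the new `self.device` attribute. A small sketch with made-up sizes, only to illustrate the list-of-tensors layout:

```python
import torch

# Made-up sizes; the real values come from the model config.
batch_size, conv_dim, conv_kernel_size = 1, 256, 4
conv_states = [torch.zeros(batch_size, conv_dim, conv_kernel_size) for _ in range(3)]

# conv_states.device would raise AttributeError: a list has no .device
layer_device = conv_states[0].device  # the per-layer tensor does

# One decoding step, mirroring the patched update path:
new_conv_state = torch.randn(batch_size, 1, conv_dim)     # a single new token
conv_states[0] = conv_states[0].roll(shifts=-1, dims=-1)  # slide the conv window
conv_states[0][:, :, -1] = new_conv_state[:, 0, :].to(conv_states[0].device)
```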