change rotary base (#31)

- feat: rotary base as a property (c1200891411b6198ca6448cfebf5123d15bf2c31)
- Merge branch 'main' into pr/31 (c2ead96805f8278295d48fda36eba1d96ed3bffb)

Co-authored-by: Jack Min Ong <[email protected]>

Files changed:
- configuration_xlm_roberta.py +2 -0
- modeling_lora.py +8 -0
- modeling_xlm_roberta.py +14 -3
- rotary.py +12 -1
configuration_xlm_roberta.py
CHANGED

@@ -20,6 +20,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         bos_token_id=0,
         eos_token_id=2,
         position_embedding_type="absolute",
+        rotary_emb_base=10000.0,
         use_cache=True,
         classifier_dropout=None,
         lora_adaptations=None,
@@ -52,6 +53,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.position_embedding_type = position_embedding_type
+        self.rotary_emb_base = rotary_emb_base
         self.use_cache = use_cache
         self.classifier_dropout = classifier_dropout
         self.load_trained_adapters = load_trained_adapters
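Illustrative note (not part of the diff): the new field behaves like any other keyword argument on a PretrainedConfig subclass, so it can be overridden when the config is built. A minimal sketch, assuming the class is imported from this repository's configuration_xlm_roberta module:

# Hypothetical usage sketch; the import path and other defaults are assumptions.
from configuration_xlm_roberta import XLMRobertaFlashConfig

config = XLMRobertaFlashConfig(rotary_emb_base=20000.0)  # default is 10000.0
print(config.rotary_emb_base)  # 20000.0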
modeling_lora.py
CHANGED

@@ -262,6 +262,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         self.main_params_trainable = config.lora_main_params_trainable


+    @property
+    def rotary_emb_base(self):
+        return self.roberta.rotary_emb_base
+
+    @rotary_emb_base.setter
+    def rotary_emb_base(self, base):
+        self.roberta.rotary_emb_base = base
+
     @property
     def main_params_trainable(self):
         return self._main_params_trainable
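Illustrative note (not part of the diff): the LoRA wrapper simply forwards the property to the wrapped backbone, so setting it on the wrapper and setting it on .roberta are equivalent. A standalone sketch of that delegation pattern with generic class names (not the repository's classes):

# Generic sketch of a forwarding property; Backbone/Wrapper are placeholder names.
class Backbone:
    def __init__(self):
        self.rotary_emb_base = 10000.0

class Wrapper:
    def __init__(self):
        self.roberta = Backbone()

    @property
    def rotary_emb_base(self):
        return self.roberta.rotary_emb_base

    @rotary_emb_base.setter
    def rotary_emb_base(self, base):
        self.roberta.rotary_emb_base = base

w = Wrapper()
w.rotary_emb_base = 20000.0
assert w.roberta.rotary_emb_base == 20000.0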
modeling_xlm_roberta.py
CHANGED

@@ -93,7 +93,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         rotary_kwargs["rotary_emb_dim"] = getattr(
             config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
         )
-        rotary_kwargs["rotary_emb_base"] =
+        rotary_kwargs["rotary_emb_base"] = config.rotary_emb_base
         rotary_kwargs["rotary_emb_scale_base"] = getattr(
             config, "rotary_emb_scale_base", None
         )
@@ -450,6 +450,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):

         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
         self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
+        self._rotary_emb_base = config.rotary_emb_base

     @torch.inference_mode()
     def encode(
@@ -599,7 +600,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.train(is_training)
         return all_embeddings

-
     def truncate_embeddings(self, embeddings, truncate_dim):
         if not self.config.matryoshka_dimensions:
             logger.warning(
@@ -622,12 +622,23 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             input_mask_expanded.sum(1), min=1e-9
         )

-
     def cls_pooling(
         self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
     ):
         return token_embeddings[:,0]

+    @property
+    def rotary_emb_base(self):
+        return self._rotary_emb_base
+
+    @rotary_emb_base.setter
+    def rotary_emb_base(self, base):
+        if not isinstance(base, (int, float)):
+            raise TypeError("Base must be an integer or float")
+        logger.info(f'Changing RoPE base value to {base}')
+        for layer in self.encoder.layers:
+            layer.mixer.rotary_emb.base = base
+        self._rotary_emb_base = base

     def forward(
         self,
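Illustrative note (not part of the diff): with the setter in place, the RoPE base can be changed on a loaded model and the value is propagated to every layer's rotary embedding. A minimal usage sketch, assuming a checkpoint that ships this remote code (the model id below is a placeholder, not a real repository):

# Hypothetical usage sketch; "org/checkpoint-with-this-code" is a placeholder id.
from transformers import AutoModel

model = AutoModel.from_pretrained("org/checkpoint-with-this-code", trust_remote_code=True)
print(model.rotary_emb_base)      # initialized from config.rotary_emb_base, e.g. 10000.0
model.rotary_emb_base = 20000.0   # updates layer.mixer.rotary_emb.base for all layers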
rotary.py
CHANGED

@@ -443,7 +443,7 @@ class RotaryEmbedding(torch.nn.Module):
         """
         super().__init__()
         self.dim = dim
-        self.base = float(base)
+        self._base = float(base)
         self.pos_idx_in_fp32 = pos_idx_in_fp32
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = self._compute_inv_freq(device)
@@ -463,6 +463,17 @@ class RotaryEmbedding(torch.nn.Module):
         self._cos_k_cached = None
         self._sin_k_cached = None

+    @property
+    def base(self):
+        return self._base
+
+    @base.setter
+    def base(self, new_base):
+        if new_base > 0:
+            self._base = float(new_base)
+        else:
+            raise ValueError("Rotary base value must be positive")
+
     def _compute_inv_freq(self, device=None):
         return 1.0 / (
             self.base
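Illustrative note (not part of the diff): the base value feeds directly into the inverse-frequency computation in _compute_inv_freq. A standalone sketch of the standard RoPE formula inv_freq[i] = 1 / base**(2i / dim), written independently of the repository's code:

import torch

# Standard RoPE inverse frequencies: a larger base lowers the rotation frequencies,
# which is the usual knob for stretching the effective context length.
def compute_inv_freq(base: float, dim: int) -> torch.Tensor:
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

print(compute_inv_freq(10000.0, 64)[:4])  # default base
print(compute_inv_freq(20000.0, 64)[:4])  # larger base -> slower rotations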