achyutarajaram committed on
Commit
5cd6989
·
verified ·
1 Parent(s): 7eae0f9

Update gpt.py

Browse files
Files changed (1) hide show
  1. gpt.py +16 -0
gpt.py CHANGED
@@ -657,6 +657,22 @@ class GPT(nn.Module):
657
  # report number of parameters
658
  print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
660
  def get_num_params(self, non_embedding=True):
661
  """
662
  Return the number of parameters in the model.
 
657
  # report number of parameters
658
  print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
659
 
660
+ @torch.no_grad()
661
+ def _initialize_weights(self, module: nn.Module) -> None:
662
+ """
663
+ Compatibility shim for newer `transformers` versions.
664
+
665
+ `transformers.PreTrainedModel.initialize_weights()` will treat any submodule that
666
+ defines `_init_weights` as a nested "sub-model" and will recursively call that
667
+ submodule's `_initialize_weights`. Our core `GPT` module historically only
668
+ implemented `_init_weights`, so we provide this wrapper to match HF's contract.
669
+ """
670
+ if getattr(module, "_is_hf_initialized", False):
671
+ return
672
+ self._init_weights(module)
673
+ module._is_hf_initialized = True
674
+
675
+
676
  def get_num_params(self, non_embedding=True):
677
  """
678
  Return the number of parameters in the model.