Update gpt.py
gpt.py (changed)
@@ -657,6 +657,22 @@ class GPT(nn.Module):
         # report number of parameters
         print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
 
+    @torch.no_grad()
+    def _initialize_weights(self, module: nn.Module) -> None:
+        """
+        Compatibility shim for newer `transformers` versions.
+
+        `transformers.PreTrainedModel.initialize_weights()` will treat any submodule that
+        defines `_init_weights` as a nested "sub-model" and will recursively call that
+        submodule's `_initialize_weights`. Our core `GPT` module historically only
+        implemented `_init_weights`, so we provide this wrapper to match HF's contract.
+        """
+        if getattr(module, "_is_hf_initialized", False):
+            return
+        self._init_weights(module)
+        module._is_hf_initialized = True
+
+
     def get_num_params(self, non_embedding=True):
         """
         Return the number of parameters in the model.
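For reference, here is the guard pattern in isolation as a minimal, runnable sketch. The `TinyModel` class and its single `nn.Linear` layer are illustrative stand-ins for `GPT`, not part of this patch; the init style mirrors what `_init_weights` typically does for linear layers.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Toy stand-in for GPT, used only to illustrate the shim's guard logic."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def _init_weights(self, module: nn.Module) -> None:
        # Illustrative per-module init, similar in spirit to GPT's _init_weights.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def _initialize_weights(self, module: nn.Module) -> None:
        # Same pattern as the patch: skip modules HF has already handled.
        if getattr(module, "_is_hf_initialized", False):
            return
        self._init_weights(module)
        module._is_hf_initialized = True

m = TinyModel()
m._initialize_weights(m.fc)   # initializes weights and sets the flag
m._initialize_weights(m.fc)   # no-op: the guard returns early
assert m.fc._is_hf_initialized

Because each module carries its own `_is_hf_initialized` flag, repeated or overlapping initialization passes (such as HF's recursive `initialize_weights()` visiting a submodule from more than one path) become no-ops after the first call.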