Spaces:
Running
Running
File size: 1,001 Bytes
590a604 ee1a8a3 1fbc47b b43ba56 1fbc47b b43ba56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
"""
Checkpoint I/O utilities for LexiMind.
Handles model state serialization with support for torch.compile artifacts.
Author: Oliver Perrin
Date: December 2025
"""
from pathlib import Path
import torch
def save_state(model: torch.nn.Module, path: str) -> None:
    """Write ``model``'s state dict to ``path``.

    Parent directories are created if they do not exist. Every
    ``"_orig_mod."`` segment in the key names is removed before saving.
    The checkpoint therefore uses the plain module's parameter names even
    when ``model`` (or one of its submodules) was wrapped by
    ``torch.compile``.

    Args:
        model: Module whose ``state_dict()`` is serialized.
        path: Destination file path.
    """
    out_file = Path(path)
    out_file.parent.mkdir(parents=True, exist_ok=True)

    # torch.compile registers the wrapped module as ``_orig_mod``; remove
    # that segment wherever it appears so keys match the uncompiled module.
    portable = {
        name.replace("_orig_mod.", ""): tensor
        for name, tensor in model.state_dict().items()
    }
    torch.save(portable, out_file)
def load_state(model: torch.nn.Module, path: str) -> None:
    """Load a state-dict checkpoint from ``path`` into ``model``.

    Every ``"_orig_mod."`` segment (added by ``torch.compile`` wrappers) is
    removed from the checkpoint's keys. Each stripped key is then mapped
    onto the name ``model`` actually uses for that entry. Loading therefore
    succeeds whether the checkpoint was saved from a compiled or an
    uncompiled module, and whether ``model`` itself is compiled or not.

    Args:
        model: Module (compiled or not) that receives the weights.
        path: Checkpoint file path.

    Raises:
        RuntimeError: From ``load_state_dict`` when keys or shapes still do
            not match (loading is strict).
    """
    # weights_only=True limits unpickling to tensors and plain containers;
    # map_location="cpu" avoids requiring the device the file was saved on.
    state = torch.load(path, map_location="cpu", weights_only=True)

    def _strip(key: str) -> str:
        return key.replace("_orig_mod.", "")

    # The target's own keys may contain "_orig_mod." (e.g. `model` is the
    # wrapper returned by torch.compile). Map stripped names back to the
    # real ones so strict loading works in every compiled/uncompiled combo.
    target_names = {_strip(name): name for name in model.state_dict()}

    remapped = {}
    for key, value in state.items():
        plain = _strip(key)
        remapped[target_names.get(plain, plain)] = value
    model.load_state_dict(remapped)
|