File size: 1,001 Bytes
590a604
 
 
 
 
 
 
 
ee1a8a3
1fbc47b
 
 
 
 
 
 
 
b43ba56
 
 
 
 
 
 
 
 
1fbc47b
 
 
 
b43ba56
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
Checkpoint I/O utilities for LexiMind.

Handles model state serialization with support for torch.compile artifacts.

Author: Oliver Perrin
Date: December 2025
"""

from pathlib import Path

import torch


def save_state(model: torch.nn.Module, path: str) -> None:
    """Serialize ``model``'s state dict to ``path`` with compile artifacts removed.

    ``torch.compile`` wraps a module so that its state-dict keys gain an
    ``_orig_mod`` path component (at the top level, or under a compiled
    submodule). Every such component is dropped, so the checkpoint uses the
    same key names as the uncompiled model.

    Args:
        model: Module to save; may be, or contain, a ``torch.compile`` wrapper.
        path: Destination file. Missing parent directories are created.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    # Remove '_orig_mod' only as a whole dotted component. A plain substring
    # replace would also corrupt legitimate names such as 'my_orig_mod.weight'.
    clean_state_dict = {
        ".".join(part for part in key.split(".") if part != "_orig_mod"): value
        for key, value in model.state_dict().items()
    }

    torch.save(clean_state_dict, destination)


def load_state(model: torch.nn.Module, path: str) -> None:
    """Load the state dict stored at ``path`` into ``model``.

    Checkpoint keys are canonicalized by dropping every ``_orig_mod`` path
    component, which ``torch.compile`` wrappers introduce. Each key is then
    mapped onto the corresponding key of ``model``'s own state dict. As a
    result, a checkpoint from a compiled or uncompiled model loads into a
    compiled or uncompiled ``model``.

    Args:
        model: Target module; may be, or contain, a ``torch.compile`` wrapper.
        path: Checkpoint file; loaded onto CPU with ``weights_only=True``.

    Raises:
        RuntimeError: From ``model.load_state_dict`` (strict mode) when keys
            are missing or unexpected, or tensor shapes do not match.
    """

    def _canonical(key: str) -> str:
        # Drop '_orig_mod' only as a whole dotted component, never as a substring.
        return ".".join(part for part in key.split(".") if part != "_orig_mod")

    state = torch.load(path, map_location="cpu", weights_only=True)

    # canonical key -> key the model actually expects. The expected key keeps
    # '_orig_mod.' when the model (or a submodule) is currently compiled.
    expected = {_canonical(key): key for key in model.state_dict()}

    remapped = {}
    for key, value in state.items():
        clean_key = _canonical(key)
        # Unknown keys pass through unchanged so strict loading reports them.
        remapped[expected.get(clean_key, clean_key)] = value

    model.load_state_dict(remapped)