File size: 3,105 Bytes
590a604
 
 
 
 
 
 
 
 
ee1a8a3
1fbc47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee1a8a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fbc47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b43ba56
1fbc47b
 
 
 
 
 
 
 
 
 
 
 
b43ba56
1fbc47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
Model export script for LexiMind.

Rebuilds the multitask model from configuration and exports trained weights
for deployment or distribution.

Author: Oliver Perrin
Date: December 2025
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch

from src.data.tokenization import Tokenizer, TokenizerConfig
from src.models.factory import build_multitask_model, load_model_config
from src.utils.config import load_yaml
from src.utils.labels import load_label_metadata


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the export script.

    Returns:
        argparse.Namespace with ``checkpoint``, ``output``, ``labels``,
        ``model_config`` and ``data_config`` attributes.
    """
    parser = argparse.ArgumentParser(description="Export LexiMind model weights")
    # (flag, default, help) triples — registered in a loop to keep the
    # option table easy to scan and extend.
    options = (
        ("--checkpoint", "checkpoints/best.pt", "Path to the trained checkpoint."),
        ("--output", "outputs/model.pt", "Output path for the exported state dict."),
        ("--labels", "artifacts/labels.json", "Label metadata JSON produced after training."),
        ("--model-config", "configs/model/base.yaml", "Model architecture configuration."),
        ("--data-config", "configs/data/datasets.yaml", "Data configuration (for tokenizer settings)."),
    )
    for flag, default, help_text in options:
        parser.add_argument(flag, default=default, help=help_text)
    return parser.parse_args()


def main() -> None:
    """Export multitask model weights from a training checkpoint to a standalone state dict.

    Rebuilds the model architecture from the YAML configuration, loads trained
    weights from the checkpoint (accepting a bare state dict as well as
    trainer-style ``model_state_dict`` / ``state_dict`` wrappers), and writes
    the clean state dict to the requested output path.

    Raises:
        FileNotFoundError: If the checkpoint path does not exist.
        TypeError: If the loaded checkpoint is not a dict.
    """
    args = parse_args()

    checkpoint = Path(args.checkpoint)
    if not checkpoint.exists():
        raise FileNotFoundError(checkpoint)

    labels = load_label_metadata(args.labels)
    data_cfg = load_yaml(args.data_config).data
    tokenizer_section = data_cfg.get("tokenizer", {})
    tokenizer_config = TokenizerConfig(
        pretrained_model_name=tokenizer_section.get("pretrained_model_name", "google/flan-t5-base"),
        max_length=int(tokenizer_section.get("max_length", 512)),
        lower=bool(tokenizer_section.get("lower", False)),
    )
    tokenizer = Tokenizer(tokenizer_config)

    model = build_multitask_model(
        tokenizer,
        num_emotions=labels.emotion_size,
        num_topics=labels.topic_size,
        config=load_model_config(args.model_config),
    )

    # BUGFIX: load onto CPU. Export is a device-agnostic operation, and the
    # previous hard-coded map_location="cuda" crashed on CPU-only machines.
    raw_state = torch.load(checkpoint, map_location="cpu")
    if isinstance(raw_state, dict):
        # Accept common checkpoint layouts: trainer wrappers or a bare state dict.
        if isinstance(raw_state.get("model_state_dict"), dict):
            state_dict = raw_state["model_state_dict"]
        elif isinstance(raw_state.get("state_dict"), dict):
            state_dict = raw_state["state_dict"]
        else:
            state_dict = raw_state
    else:
        raise TypeError(f"Unsupported checkpoint format: expected dict, got {type(raw_state)!r}")
    model.load_state_dict(state_dict)

    output_path = Path(args.output)
    # Ensure the destination directory exists before saving.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_path)
    print(f"Model exported to {output_path}")


# Script entry point: run the export only when executed directly, not on import.
if __name__ == "__main__":
    main()