Spaces:
Running
Running
| """ | |
| Model export script for LexiMind. | |
| Rebuilds the multitask model from configuration and exports trained weights | |
| for deployment or distribution. | |
| Author: Oliver Perrin | |
| Date: December 2025 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| import torch | |
| from src.data.tokenization import Tokenizer, TokenizerConfig | |
| from src.models.factory import build_multitask_model, load_model_config | |
| from src.utils.config import load_yaml | |
| from src.utils.labels import load_label_metadata | |
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the export script.

    Every option is optional and falls back to a repository-relative default.

    Returns:
        Namespace exposing ``checkpoint``, ``output``, ``labels``,
        ``model_config`` and ``data_config``.
    """
    # (flag, default value, help text) -- registered in this order so the
    # ``--help`` listing keeps the same layout.
    option_specs = (
        ("--checkpoint", "checkpoints/best.pt", "Path to the trained checkpoint."),
        ("--output", "outputs/model.pt", "Output path for the exported state dict."),
        ("--labels", "artifacts/labels.json", "Label metadata JSON produced after training."),
        ("--model-config", "configs/model/base.yaml", "Model architecture configuration."),
        (
            "--data-config",
            "configs/data/datasets.yaml",
            "Data configuration (for tokenizer settings).",
        ),
    )

    cli = argparse.ArgumentParser(description="Export LexiMind model weights")
    for flag, fallback, description in option_specs:
        cli.add_argument(flag, default=fallback, help=description)
    return cli.parse_args()
def main() -> None:
    """Export multitask model weights from a training checkpoint to a standalone state dict.

    Steps:
        1. Parse the CLI options and verify the checkpoint file exists.
        2. Rebuild the tokenizer from the ``tokenizer`` section of the data config.
        3. Rebuild the multitask model from the label metadata and model config.
        4. Load the checkpoint weights into the model.
        5. Save ``model.state_dict()`` to the output path, creating parent
           directories as needed.

    Raises:
        FileNotFoundError: If the checkpoint path does not exist.
        TypeError: If the loaded checkpoint object is not a dict.
    """
    args = parse_args()

    checkpoint = Path(args.checkpoint)
    if not checkpoint.exists():
        raise FileNotFoundError(checkpoint)

    labels = load_label_metadata(args.labels)
    data_cfg = load_yaml(args.data_config).data
    # ``or {}`` also covers a ``tokenizer:`` key that is present but empty
    # (parsed as None), so the defaults below still apply.
    tokenizer_section = data_cfg.get("tokenizer") or {}
    tokenizer_config = TokenizerConfig(
        pretrained_model_name=tokenizer_section.get("pretrained_model_name", "google/flan-t5-base"),
        max_length=int(tokenizer_section.get("max_length", 512)),
        lower=bool(tokenizer_section.get("lower", False)),
    )
    tokenizer = Tokenizer(tokenizer_config)

    model = build_multitask_model(
        tokenizer,
        num_emotions=labels.emotion_size,
        num_topics=labels.topic_size,
        config=load_model_config(args.model_config),
    )

    # Map every tensor to the CPU: the weights are only copied into the model
    # and re-saved, so no GPU is needed, and a checkpoint written on a GPU
    # machine must still be exportable on a CPU-only host.
    raw_state = torch.load(checkpoint, map_location="cpu")
    if not isinstance(raw_state, dict):
        raise TypeError(f"Unsupported checkpoint format: expected dict, got {type(raw_state)!r}")

    # A checkpoint may wrap the weights under "model_state_dict" or
    # "state_dict"; otherwise the loaded dict itself is used as the state dict.
    state_dict = raw_state
    for wrapper_key in ("model_state_dict", "state_dict"):
        candidate = raw_state.get(wrapper_key)
        if isinstance(candidate, dict):
            state_dict = candidate
            break
    model.load_state_dict(state_dict)

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_path)
    print(f"Model exported to {output_path}")
# Run the export only when executed as a script, so importing this module
# has no side effects.
if __name__ == "__main__":
    main()