"""
Training script for LexiMind.
Orchestrates dataset loading, model construction, torch.compile optimization,
and multi-task training with checkpoint management.
Author: Oliver Perrin
Date: December 2025
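
Example invocations (illustrative; the override keys mirror this script's
cfg lookups, but the authoritative group names and defaults live under configs/,
and the checkpoint path below is only an example):
    python scripts/train.py
    python scripts/train.py benchmark=true training.trainer.max_epochs=1
    python scripts/train.py resume_from=checkpoints/epoch_3.pt device=cuda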
"""
from __future__ import annotations
import json
import logging
import os
import re
import sys
import time
import warnings
from pathlib import Path
from typing import Dict, Sequence, cast
# Suppress torch inductor warnings that mess up progress bars
os.environ.setdefault("TORCH_LOGS", "-all")
warnings.filterwarnings("ignore", category=UserWarning, module="torch._inductor")
warnings.filterwarnings("ignore", category=FutureWarning, module="mlflow")
logging.getLogger("torch._inductor").setLevel(logging.ERROR)
logging.getLogger("torch._dynamo").setLevel(logging.ERROR)
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from src.data.dataloader import (
    build_emotion_dataloader,
    build_summarization_dataloader,
    build_topic_dataloader,
)
from src.data.dataset import (
    EmotionDataset,
    SummarizationDataset,
    TopicDataset,
    load_emotion_jsonl,
    load_summarization_jsonl,
    load_topic_jsonl,
)
from src.data.tokenization import Tokenizer, TokenizerConfig
from src.models.factory import ModelConfig, build_multitask_model
from src.training.trainer import Trainer, TrainerConfig
from src.training.utils import set_seed
from src.utils.io import load_state, save_state
from src.utils.labels import LabelMetadata, save_label_metadata
# --------------- Data Loading ---------------
SPLIT_ALIASES: Dict[str, Sequence[str]] = {
    "train": ("train",),
    "val": ("val", "validation"),
    "test": ("test",),
}
def load_splits(data_dir: Path, loader) -> Dict[str, list]:
    """Load train/val/test splits from data directory."""
    splits = {}
    for name, aliases in SPLIT_ALIASES.items():
        for alias in aliases:
            for ext in ("jsonl", "json"):
                path = data_dir / f"{alias}.{ext}"
                if path.exists():
                    splits[name] = loader(str(path))
                    break
            if name in splits:
                break
        if name not in splits:
            raise FileNotFoundError(f"Missing {name} split in {data_dir}")
    return splits
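# Expected layout per task directory (paths illustrative): <data_dir>/train.jsonl,
# <data_dir>/val.jsonl (or validation.jsonl), and <data_dir>/test.jsonl; a plain
# .json extension is also accepted for any split.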
def limit_samples(splits: Dict[str, list], cfg: DictConfig) -> None:
    """Apply sample limits for dev/debug runs."""
    for split, key in [("train", "max_train_samples"), ("val", "max_val_samples")]:
        limit = cfg.get(key)
        if limit and split in splits and len(splits[split]) > limit:
            splits[split] = splits[split][: int(limit)]
            print(f" {split}: limited to {limit} samples")
# --------------- Model Compilation ---------------
def compile_model(model: torch.nn.Module) -> torch.nn.Module:
    """Compile model with the inductor backend (optimized for speed)."""
    print(f" -> Enabling torch.compile for {model.__class__.__name__}...")
    from src.training.safe_compile import apply_safe_config, compile_model_safe

    # Apply safe configuration first
    apply_safe_config()
    # Compile with default mode (inductor) - most stable
    return compile_model_safe(model, mode="default")
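# Note: assuming compile_model_safe() ultimately calls torch.compile, the returned
# module is an OptimizedModule whose state_dict keys gain an "_orig_mod." prefix;
# that is why main() loads any resume checkpoint before compiling the encoder/decoder.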
# --------------- Main ---------------
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    start_time = time.perf_counter()
    print(OmegaConf.to_yaml(cfg))
    set_seed(cfg.seed)
    # Benchmark mode: skip saving checkpoints (for speed testing)
    benchmark_mode = cfg.get("benchmark", False)
    if benchmark_mode:
        print("⚡ BENCHMARK MODE: Checkpoints will NOT be saved")
    # Enable TF32 for Ampere+ GPUs (RTX 30xx/40xx) - ~2x matmul speedup
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        print("✓ TF32 enabled for Ampere GPU")
        torch.set_float32_matmul_precision("high")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True  # Auto-tune convolutions
        torch.backends.cuda.enable_flash_sdp(True)  # Flash attention if available
        torch.backends.cuda.enable_mem_efficient_sdp(True)  # Memory-efficient attention
    # Disable debug APIs for max speed
    torch.autograd.set_detect_anomaly(False)
    torch.autograd.profiler.profile(False)
    torch.autograd.profiler.emit_nvtx(False)
    # --------------- Load Data ---------------
    data_cfg = cfg.data
    trainer_cfg = cfg.training.get("trainer", {})
    print("\nLoading datasets...")
    summ_splits = load_splits(Path(data_cfg.processed.summarization), load_summarization_jsonl)
    emot_splits = load_splits(Path(data_cfg.processed.emotion), load_emotion_jsonl)
    topic_splits = load_splits(Path(data_cfg.processed.topic), load_topic_jsonl)
    # Apply dev/debug sample limits
    for splits in [summ_splits, emot_splits, topic_splits]:
        limit_samples(splits, trainer_cfg)
    # --------------- Tokenizer & Datasets ---------------
    tok_cfg = data_cfg.get("tokenizer", {})
    # Allow training overrides for max_length to run shorter dev sweeps
    override_max_len = cfg.training.get("tokenizer_max_length")
    tokenizer = Tokenizer(
        TokenizerConfig(
            pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
            max_length=int(override_max_len or tok_cfg.get("max_length", 512)),
            lower=bool(tok_cfg.get("lower", False)),
        )
    )
    summ_train = SummarizationDataset(summ_splits["train"])
    summ_val = SummarizationDataset(summ_splits["val"])
    emot_train = EmotionDataset(emot_splits["train"])
    emot_val = EmotionDataset(emot_splits["val"], binarizer=emot_train.binarizer)
    topic_train = TopicDataset(topic_splits["train"])
    topic_val = TopicDataset(topic_splits["val"], encoder=topic_train.encoder)
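    # The validation datasets reuse the binarizer/encoder fitted on the training
    # split so label indices stay consistent across splits.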
    # --------------- DataLoaders ---------------
    dl_cfg = cfg.training.get("dataloader", {})
    batch_size = int(dl_cfg.get("batch_size", 8))
    num_workers = int(dl_cfg.get("num_workers", 4))
    pin_memory = bool(dl_cfg.get("pin_memory", True))
    max_len = tokenizer.config.max_length
    train_loaders = {
        "summarization": build_summarization_dataloader(
            summ_train,
            tokenizer,
            shuffle=True,
            max_source_length=max_len,
            max_target_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
        "emotion": build_emotion_dataloader(
            emot_train,
            tokenizer,
            shuffle=True,
            max_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
        "topic": build_topic_dataloader(
            topic_train,
            tokenizer,
            shuffle=True,
            max_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
    }
    val_loaders = {
        "summarization": build_summarization_dataloader(
            summ_val,
            tokenizer,
            shuffle=False,
            max_source_length=max_len,
            max_target_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
        "emotion": build_emotion_dataloader(
            emot_val,
            tokenizer,
            shuffle=False,
            max_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
        "topic": build_topic_dataloader(
            topic_val,
            tokenizer,
            shuffle=False,
            max_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
    }
    # --------------- Model ---------------
    print("\nBuilding model...")
    device = torch.device(cfg.device)
    model_cfg = ModelConfig(
        d_model=cfg.model.d_model,
        vocab_size=getattr(cfg.model, "vocab_size", None),  # Override tokenizer vocab if specified
        num_encoder_layers=cfg.model.num_encoder_layers,
        num_decoder_layers=cfg.model.num_decoder_layers,
        num_attention_heads=cfg.model.num_attention_heads,
        ffn_dim=cfg.model.ffn_dim,
        dropout=cfg.model.dropout,
        use_pretrained=cfg.model.use_pretrained,
        pretrained_model_name=cfg.model.pretrained_model_name,
        activation=getattr(cfg.model, "activation", "gelu"),
        use_relative_position_bias=getattr(cfg.model, "use_relative_position_bias", False),
    )
    model = build_multitask_model(
        tokenizer,
        num_emotions=len(emot_train.emotion_classes),
        num_topics=len(topic_train.topic_classes),
        config=model_cfg,
    ).to(device)
    # Resume from a checkpoint if provided (load before compile to avoid state_dict key mismatches)
    start_epoch = 1
    resume_path = cfg.get("resume_from")
    if resume_path:
        ckpt_path = Path(resume_path)
        if ckpt_path.exists():
            print(f"\n↩ Resuming from checkpoint: {ckpt_path}")
            load_state(model, str(ckpt_path))
            # Parse the epoch number robustly from the filename (e.g., epoch_5.pt)
            epoch_num = None
            try:
                # Take the last digit sequence in the filename stem (suffix ignored)
                digits = re.findall(r"\d+", ckpt_path.stem)
                if digits:
                    epoch_num = int(digits[-1])
            except Exception:
                epoch_num = None
            if epoch_num is not None:
                start_epoch = epoch_num + 1
                print(f" -> Starting from epoch {start_epoch}")
            else:
                print(" -> Could not parse epoch number; starting from epoch 1")
                start_epoch = 1
        else:
            print(f"⚠ Resume checkpoint not found: {ckpt_path}. Starting from scratch.")
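    # Epoch parsing assumes the epoch_<n>.pt naming used by save_checkpoint() below;
    # a resume filename without digits falls back to starting at epoch 1.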
    # Compile encoder/decoder for faster training (the heads are small, so they are left uncompiled)
    compile_encoder = bool(cfg.training.get("compile_encoder", True))
    compile_decoder = bool(cfg.training.get("compile_decoder", True))
    if compile_encoder and model.encoder is not None:
        from src.models.encoder import TransformerEncoder

        model.encoder = cast(TransformerEncoder, compile_model(model.encoder))
    if compile_decoder and model.decoder is not None:
        from src.models.decoder import TransformerDecoder

        model.decoder = cast(TransformerDecoder, compile_model(model.decoder))
    # --------------- Optimizer & Trainer ---------------
    opt_cfg = cfg.training.get("optimizer", {})
    sched_cfg = cfg.training.get("scheduler", {})
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=float(opt_cfg.get("lr", 3e-5)),
        weight_decay=float(opt_cfg.get("weight_decay", 0.01)),
    )
    # Clamp start_epoch to max_epochs to avoid an empty training loop
    max_epochs = int(trainer_cfg.get("max_epochs", 1))
    if start_epoch > max_epochs:
        print(
            f"⚠ resume_from points past max_epochs ({max_epochs}); "
            f"nothing to train. Setting start_epoch to {max_epochs}"
        )
        start_epoch = max_epochs
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        config=TrainerConfig(
            max_epochs=max_epochs,
            gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)),
            task_weights=trainer_cfg.get("task_weights"),
            label_smoothing=float(trainer_cfg.get("label_smoothing", 0.0)),
            gradient_accumulation_steps=int(trainer_cfg.get("gradient_accumulation_steps", 1)),
            scheduler_type=str(sched_cfg.get("name", "constant")),
            warmup_steps=int(sched_cfg.get("warmup_steps", 0)),
        ),
        device=device,
        tokenizer=tokenizer,
    )
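    # With gradient accumulation, the effective optimization batch is roughly
    # dataloader batch_size * gradient_accumulation_steps (assuming the Trainer
    # steps the optimizer once per accumulation window, as the config key suggests).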
    # --------------- Train ---------------
    def save_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
        if benchmark_mode:
            return  # Skip saving in benchmark mode
        path = Path(cfg.checkpoint_out).parent / f"epoch_{epoch}.pt"
        path.parent.mkdir(parents=True, exist_ok=True)
        save_state(model, str(path))
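    # Only the model weights are checkpointed here (no optimizer or scheduler state),
    # so resuming restores weights but restarts the optimizer and LR schedule.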
    print("\nStarting training...")
    history = trainer.fit(
        train_loaders,
        val_loaders,
        checkpoint_callback=save_checkpoint,
        start_epoch=start_epoch,
    )
    # --------------- Save Outputs ---------------
    if benchmark_mode:
        total_time = time.perf_counter() - start_time
        print(f"\n{'=' * 50}")
        print(f"⚡ Benchmark complete in {total_time:.1f}s")
        print(" (No files saved in benchmark mode)")
        print(f"{'=' * 50}")
        return
    # Best checkpoint
    ckpt_path = Path(cfg.checkpoint_out)
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    save_state(model, str(ckpt_path))
    # Labels
    labels_path = Path(cfg.labels_out)
    save_label_metadata(
        LabelMetadata(emotion=emot_train.emotion_classes, topic=topic_train.topic_classes),
        labels_path,
    )
    # History
    history_path = Path(cfg.history_out)
    history_path.parent.mkdir(parents=True, exist_ok=True)
    with history_path.open("w") as f:
        json.dump(history, f, indent=2)
    total_time = time.perf_counter() - start_time
    print(f"\n{'=' * 50}")
    print(f"Training complete in {total_time:.1f}s")
    print(f" Checkpoint: {ckpt_path}")
    print(f" Labels: {labels_path}")
    print(f" History: {history_path}")
    print(f"{'=' * 50}")
if __name__ == "__main__":
    main()