| """ | |
| Training script for LexiMind. | |
| Orchestrates dataset loading, model construction, torch.compile optimization, | |
| and multi-task training with checkpoint management. | |
| Author: Oliver Perrin | |
| Date: December 2025 | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import sys | |
| import time | |
| import warnings | |
| from pathlib import Path | |
| from typing import Dict, Sequence, cast | |
| # Suppress torch inductor warnings that mess up progress bars | |
| os.environ.setdefault("TORCH_LOGS", "-all") | |
| warnings.filterwarnings("ignore", category=UserWarning, module="torch._inductor") | |
| warnings.filterwarnings("ignore", category=FutureWarning, module="mlflow") | |
| logging.getLogger("torch._inductor").setLevel(logging.ERROR) | |
| logging.getLogger("torch._dynamo").setLevel(logging.ERROR) | |
| import hydra | |
| import torch | |
| from omegaconf import DictConfig, OmegaConf | |
| PROJECT_ROOT = Path(__file__).resolve().parents[1] | |
| if str(PROJECT_ROOT) not in sys.path: | |
| sys.path.insert(0, str(PROJECT_ROOT)) | |
| from src.data.dataloader import ( | |
| build_emotion_dataloader, | |
| build_summarization_dataloader, | |
| build_topic_dataloader, | |
| ) | |
| from src.data.dataset import ( | |
| EmotionDataset, | |
| SummarizationDataset, | |
| TopicDataset, | |
| load_emotion_jsonl, | |
| load_summarization_jsonl, | |
| load_topic_jsonl, | |
| ) | |
| from src.data.tokenization import Tokenizer, TokenizerConfig | |
| from src.models.factory import ModelConfig, build_multitask_model | |
| from src.training.trainer import Trainer, TrainerConfig | |
| from src.training.utils import set_seed | |
| from src.utils.io import load_state, save_state | |
| from src.utils.labels import LabelMetadata, save_label_metadata | |

# --------------- Data Loading ---------------

SPLIT_ALIASES: Dict[str, Sequence[str]] = {
    "train": ("train",),
    "val": ("val", "validation"),
    "test": ("test",),
}


def load_splits(data_dir: Path, loader) -> Dict[str, list]:
    """Load train/val/test splits from a data directory."""
    splits = {}
    for name, aliases in SPLIT_ALIASES.items():
        for alias in aliases:
            for ext in ("jsonl", "json"):
                path = data_dir / f"{alias}.{ext}"
                if path.exists():
                    splits[name] = loader(str(path))
                    break
            if name in splits:
                break
        if name not in splits:
            raise FileNotFoundError(f"Missing {name} split in {data_dir}")
    return splits
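
# Example of the on-disk layout load_splits() resolves (illustrative only; the
# real directories come from cfg.data.processed.* in main()):
#   data/processed/emotion/train.jsonl
#   data/processed/emotion/validation.jsonl  # matched via the "val" alias
#   data/processed/emotion/test.json         # .json is tried after .jsonl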


def limit_samples(splits: Dict[str, list], cfg: DictConfig) -> None:
    """Apply sample limits for dev/debug runs."""
    for split, key in [("train", "max_train_samples"), ("val", "max_val_samples")]:
        limit = cfg.get(key)
        if limit and split in splits and len(splits[split]) > limit:
            splits[split] = splits[split][: int(limit)]
            print(f"  {split}: limited to {limit} samples")

# --------------- Model Compilation ---------------

def compile_model(model: torch.nn.Module) -> torch.nn.Module:
    """Compile model with the inductor backend (optimized for speed)."""
    print(f"  -> Enabling torch.compile for {model.__class__.__name__}...")
    from src.training.safe_compile import apply_safe_config, compile_model_safe

    # Apply safe configuration first
    apply_safe_config()
    # Compile with default mode (inductor) - most stable
    return compile_model_safe(model, mode="default")
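
# compile_model_safe is project-internal; a minimal stand-in, assuming it wraps
# the stock API, would be a plain torch.compile call:
#   compiled = torch.compile(model, mode="default")  # inductor backend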

# --------------- Main ---------------

# config_path/config_name are assumptions; point them at the repo's actual
# Hydra config layout if it differs.
@hydra.main(version_base=None, config_path="../configs", config_name="train")
def main(cfg: DictConfig) -> None:
    start_time = time.perf_counter()
    print(OmegaConf.to_yaml(cfg))
    set_seed(cfg.seed)

    # Benchmark mode: skip saving checkpoints (for speed testing)
    benchmark_mode = cfg.get("benchmark", False)
    if benchmark_mode:
        print("⚡ BENCHMARK MODE: Checkpoints will NOT be saved")

    # Enable TF32 for Ampere+ GPUs (RTX 30xx/40xx) - ~2x matmul speedup
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        print("✓ TF32 enabled for Ampere GPU")
        torch.set_float32_matmul_precision("high")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True  # Auto-tune convolutions
        torch.backends.cuda.enable_flash_sdp(True)  # Flash attention if available
        torch.backends.cuda.enable_mem_efficient_sdp(True)  # Memory-efficient attention

    # Disable autograd anomaly detection for max speed (autograd profiling and
    # NVTX emission are off by default, so they need no explicit disabling)
    torch.autograd.set_detect_anomaly(False)

    # --------------- Load Data ---------------
    data_cfg = cfg.data
    trainer_cfg = cfg.training.get("trainer", {})
    print("\nLoading datasets...")
    summ_splits = load_splits(Path(data_cfg.processed.summarization), load_summarization_jsonl)
    emot_splits = load_splits(Path(data_cfg.processed.emotion), load_emotion_jsonl)
    topic_splits = load_splits(Path(data_cfg.processed.topic), load_topic_jsonl)

    # Apply dev/debug sample limits
    for splits in [summ_splits, emot_splits, topic_splits]:
        limit_samples(splits, trainer_cfg)

    # --------------- Tokenizer & Datasets ---------------
    tok_cfg = data_cfg.get("tokenizer", {})
    # Allow training overrides for max_length to run shorter dev sweeps
    override_max_len = cfg.training.get("tokenizer_max_length")
    tokenizer = Tokenizer(
        TokenizerConfig(
            pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
            max_length=int(override_max_len or tok_cfg.get("max_length", 512)),
            lower=bool(tok_cfg.get("lower", False)),
        )
    )
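
    # Illustrative shape of the data config consumed above (keys inferred from
    # the accesses in this script; path values are examples, not repo defaults):
    #   data:
    #     processed:
    #       summarization: data/processed/summarization
    #       emotion: data/processed/emotion
    #       topic: data/processed/topic
    #     tokenizer:
    #       pretrained_model_name: google/flan-t5-base
    #       max_length: 512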

    summ_train = SummarizationDataset(summ_splits["train"])
    summ_val = SummarizationDataset(summ_splits["val"])
    emot_train = EmotionDataset(emot_splits["train"])
    emot_val = EmotionDataset(emot_splits["val"], binarizer=emot_train.binarizer)
    topic_train = TopicDataset(topic_splits["train"])
    topic_val = TopicDataset(topic_splits["val"], encoder=topic_train.encoder)

    # --------------- DataLoaders ---------------
    dl_cfg = cfg.training.get("dataloader", {})
    batch_size = int(dl_cfg.get("batch_size", 8))
    num_workers = int(dl_cfg.get("num_workers", 4))
    pin_memory = bool(dl_cfg.get("pin_memory", True))
    max_len = tokenizer.config.max_length

    train_loaders = {
        "summarization": build_summarization_dataloader(
            summ_train,
            tokenizer,
            shuffle=True,
            max_source_length=max_len,
            max_target_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
        "emotion": build_emotion_dataloader(
            emot_train,
            tokenizer,
            shuffle=True,
            max_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
        "topic": build_topic_dataloader(
            topic_train,
            tokenizer,
            shuffle=True,
            max_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
    }
    val_loaders = {
        "summarization": build_summarization_dataloader(
            summ_val,
            tokenizer,
            shuffle=False,
            max_source_length=max_len,
            max_target_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
        "emotion": build_emotion_dataloader(
            emot_val,
            tokenizer,
            shuffle=False,
            max_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
        "topic": build_topic_dataloader(
            topic_val,
            tokenizer,
            shuffle=False,
            max_length=max_len,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        ),
    }

    # --------------- Model ---------------
    print("\nBuilding model...")
    device = torch.device(cfg.device)
    model_cfg = ModelConfig(
        d_model=cfg.model.d_model,
        vocab_size=getattr(cfg.model, "vocab_size", None),  # Override tokenizer vocab if specified
        num_encoder_layers=cfg.model.num_encoder_layers,
        num_decoder_layers=cfg.model.num_decoder_layers,
        num_attention_heads=cfg.model.num_attention_heads,
        ffn_dim=cfg.model.ffn_dim,
        dropout=cfg.model.dropout,
        use_pretrained=cfg.model.use_pretrained,
        pretrained_model_name=cfg.model.pretrained_model_name,
        activation=getattr(cfg.model, "activation", "gelu"),
        use_relative_position_bias=getattr(cfg.model, "use_relative_position_bias", False),
    )
    model = build_multitask_model(
        tokenizer,
        num_emotions=len(emot_train.emotion_classes),
        num_topics=len(topic_train.topic_classes),
        config=model_cfg,
    ).to(device)
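
    # Illustrative model config consumed above (field names mirror the
    # ModelConfig accesses; the sizes are example values, not repo defaults):
    #   model:
    #     d_model: 768
    #     num_encoder_layers: 12
    #     num_decoder_layers: 12
    #     num_attention_heads: 12
    #     ffn_dim: 3072
    #     dropout: 0.1
    #     use_pretrained: true
    #     pretrained_model_name: google/flan-t5-base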

    # Resume from a checkpoint if training crashed (load before compile to
    # avoid compiled-module key mismatches in the state dict)
    start_epoch = 1
    resume_path = cfg.get("resume_from")
    if resume_path:
        ckpt_path = Path(resume_path)
        if ckpt_path.exists():
            print(f"\n↩ Resuming from checkpoint: {ckpt_path}")
            load_state(model, str(ckpt_path))
            # Parse the epoch number robustly from the filename (e.g., epoch_5.pt)
            epoch_num = None
            try:
                # Prefer the stem (no suffix); fall back to the last digit run in the name
                digits = re.findall(r"\d+", ckpt_path.stem)
                if digits:
                    epoch_num = int(digits[-1])
            except Exception:
                epoch_num = None
            if epoch_num is not None:
                start_epoch = epoch_num + 1
                print(f"  -> Starting from epoch {start_epoch}")
            else:
                print("  -> Could not parse epoch number; starting from epoch 1")
                start_epoch = 1
        else:
            print(f"⚠ Resume checkpoint not found: {ckpt_path}. Starting from scratch.")

    # Compile encoder/decoder for faster training (skip heads - small overhead)
    compile_encoder = bool(cfg.training.get("compile_encoder", True))
    compile_decoder = bool(cfg.training.get("compile_decoder", True))
    if compile_encoder and model.encoder is not None:
        from src.models.encoder import TransformerEncoder

        model.encoder = cast(TransformerEncoder, compile_model(model.encoder))
    if compile_decoder and model.decoder is not None:
        from src.models.decoder import TransformerDecoder

        model.decoder = cast(TransformerDecoder, compile_model(model.decoder))

    # --------------- Optimizer & Trainer ---------------
    opt_cfg = cfg.training.get("optimizer", {})
    sched_cfg = cfg.training.get("scheduler", {})
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=float(opt_cfg.get("lr", 3e-5)),
        weight_decay=float(opt_cfg.get("weight_decay", 0.01)),
    )

    # Clamp start_epoch to max_epochs to avoid an empty training loop
    max_epochs = int(trainer_cfg.get("max_epochs", 1))
    if start_epoch > max_epochs:
        print(
            f"⚠ resume_from points past max_epochs ({max_epochs}); nothing to train. "
            f"Setting start_epoch to {max_epochs}"
        )
        start_epoch = max_epochs

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        config=TrainerConfig(
            max_epochs=max_epochs,
            gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)),
            task_weights=trainer_cfg.get("task_weights"),
            label_smoothing=float(trainer_cfg.get("label_smoothing", 0.0)),
            gradient_accumulation_steps=int(trainer_cfg.get("gradient_accumulation_steps", 1)),
            scheduler_type=str(sched_cfg.get("name", "constant")),
            warmup_steps=int(sched_cfg.get("warmup_steps", 0)),
        ),
        device=device,
        tokenizer=tokenizer,
    )
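
    # Illustrative training config matching the keys read above (values are
    # examples, not repo defaults):
    #   training:
    #     optimizer: {lr: 3e-5, weight_decay: 0.01}
    #     scheduler: {name: linear, warmup_steps: 500}
    #     trainer:
    #       max_epochs: 5
    #       gradient_clip_norm: 1.0
    #       gradient_accumulation_steps: 2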

    # --------------- Train ---------------
    def save_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
        if benchmark_mode:
            return  # Skip saving in benchmark mode
        path = Path(cfg.checkpoint_out).parent / f"epoch_{epoch}.pt"
        path.parent.mkdir(parents=True, exist_ok=True)
        save_state(model, str(path))

    print("\nStarting training...")
    history = trainer.fit(
        train_loaders,
        val_loaders,
        checkpoint_callback=save_checkpoint,
        start_epoch=start_epoch,
    )

    # --------------- Save Outputs ---------------
    if benchmark_mode:
        total_time = time.perf_counter() - start_time
        print(f"\n{'=' * 50}")
        print(f"⚡ Benchmark complete in {total_time:.1f}s")
        print("  (No files saved in benchmark mode)")
        print(f"{'=' * 50}")
        return

    # Best checkpoint
    ckpt_path = Path(cfg.checkpoint_out)
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    save_state(model, str(ckpt_path))

    # Labels
    labels_path = Path(cfg.labels_out)
    save_label_metadata(
        LabelMetadata(emotion=emot_train.emotion_classes, topic=topic_train.topic_classes),
        labels_path,
    )

    # History
    history_path = Path(cfg.history_out)
    history_path.parent.mkdir(parents=True, exist_ok=True)
    with history_path.open("w") as f:
        json.dump(history, f, indent=2)

    total_time = time.perf_counter() - start_time
    print(f"\n{'=' * 50}")
    print(f"Training complete in {total_time:.1f}s")
    print(f"  Checkpoint: {ckpt_path}")
    print(f"  Labels: {labels_path}")
    print(f"  History: {history_path}")
    print(f"{'=' * 50}")


if __name__ == "__main__":
    main()
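
# Example invocations (Hydra CLI overrides; script name and values illustrative):
#   python train.py                                 # defaults from the config tree
#   python train.py benchmark=true                  # time a run without writing files
#   python train.py training.trainer.max_epochs=3 device=cuda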