""" Visualize training metrics from MLflow runs. Generates plots showing: - Loss curves (training/validation) - Task-specific metrics over time - Learning rate schedule - Training speed analysis Author: Oliver Perrin Date: December 2025 """ from __future__ import annotations import json import sys from pathlib import Path import matplotlib.pyplot as plt import mlflow import mlflow.tracking import seaborn as sns PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from src.utils.logging import configure_logging, get_logger configure_logging() logger = get_logger(__name__) # Configure plotting style sns.set_style("whitegrid") plt.rcParams["figure.figsize"] = (12, 8) plt.rcParams["figure.dpi"] = 100 OUTPUTS_DIR = PROJECT_ROOT / "outputs" MLRUNS_DIR = PROJECT_ROOT / "mlruns" def load_training_history() -> dict[str, object] | None: """Load training history from JSON if available.""" history_path = OUTPUTS_DIR / "training_history.json" if history_path.exists(): with open(history_path) as f: data: dict[str, object] = json.load(f) return data return None def get_latest_run(): """Get the most recent MLflow run.""" mlflow.set_tracking_uri(f"file://{MLRUNS_DIR}") client = mlflow.tracking.MlflowClient() # Get the experiment (LexiMind) experiment = client.get_experiment_by_name("LexiMind") if not experiment: logger.error("No 'LexiMind' experiment found") return None # Get all runs, sorted by start time runs = client.search_runs( experiment_ids=[experiment.experiment_id], order_by=["start_time DESC"], max_results=1, ) if not runs: logger.error("No runs found in experiment") return None return runs[0] def plot_loss_curves(run): """Plot training and validation loss over time.""" client = mlflow.tracking.MlflowClient() # Get metrics train_loss = client.get_metric_history(run.info.run_id, "train_total_loss") val_loss = client.get_metric_history(run.info.run_id, "val_total_loss") fig, ax = plt.subplots(figsize=(12, 6)) if not train_loss: # Create placeholder plot ax.text( 0.5, 0.5, "No training data yet\n\nWaiting for first epoch to complete...", ha="center", va="center", fontsize=14, color="gray", ) ax.set_xlim(0, 1) ax.set_ylim(0, 1) else: # Extract steps and values train_steps = [m.step for m in train_loss] train_values = [m.value for m in train_loss] ax.plot(train_steps, train_values, label="Training Loss", linewidth=2, alpha=0.8) if val_loss: val_steps = [m.step for m in val_loss] val_values = [m.value for m in val_loss] ax.plot(val_steps, val_values, label="Validation Loss", linewidth=2, alpha=0.8) ax.legend(fontsize=11) ax.set_xlabel("Epoch", fontsize=12) ax.set_ylabel("Loss", fontsize=12) ax.set_title("Training Progress: Total Loss", fontsize=14, fontweight="bold") ax.grid(True, alpha=0.3) plt.tight_layout() output_path = OUTPUTS_DIR / "training_loss_curve.png" plt.savefig(output_path, dpi=150, bbox_inches="tight") logger.info(f"✓ Saved loss curve to {output_path}") plt.close() def plot_task_metrics(run): """Plot metrics for each task.""" client = mlflow.tracking.MlflowClient() fig, axes = plt.subplots(2, 2, figsize=(14, 10)) fig.suptitle("Task-Specific Training Metrics", fontsize=16, fontweight="bold") # Summarization ax = axes[0, 0] train_sum = client.get_metric_history(run.info.run_id, "train_summarization_loss") val_sum = client.get_metric_history(run.info.run_id, "val_summarization_loss") if train_sum: ax.plot( [m.step for m in train_sum], [m.value for m in train_sum], label="Train", linewidth=2 ) if val_sum: ax.plot([m.step for m in val_sum], [m.value for m in val_sum], label="Val", linewidth=2) ax.set_title("Summarization Loss", fontweight="bold") ax.set_xlabel("Epoch") ax.set_ylabel("Loss") ax.legend() ax.grid(True, alpha=0.3) # Emotion ax = axes[0, 1] train_emo = client.get_metric_history(run.info.run_id, "train_emotion_loss") val_emo = client.get_metric_history(run.info.run_id, "val_emotion_loss") train_f1 = client.get_metric_history(run.info.run_id, "train_emotion_f1") val_f1 = client.get_metric_history(run.info.run_id, "val_emotion_f1") if train_emo: ax.plot( [m.step for m in train_emo], [m.value for m in train_emo], label="Train Loss", linewidth=2, ) if val_emo: ax.plot( [m.step for m in val_emo], [m.value for m in val_emo], label="Val Loss", linewidth=2 ) ax2 = ax.twinx() if train_f1: ax2.plot( [m.step for m in train_f1], [m.value for m in train_f1], label="Train F1", linewidth=2, linestyle="--", alpha=0.7, ) if val_f1: ax2.plot( [m.step for m in val_f1], [m.value for m in val_f1], label="Val F1", linewidth=2, linestyle="--", alpha=0.7, ) ax.set_title("Emotion Detection", fontweight="bold") ax.set_xlabel("Epoch") ax.set_ylabel("Loss") ax2.set_ylabel("F1 Score") ax.legend(loc="upper left") ax2.legend(loc="upper right") ax.grid(True, alpha=0.3) # Topic ax = axes[1, 0] train_topic = client.get_metric_history(run.info.run_id, "train_topic_loss") val_topic = client.get_metric_history(run.info.run_id, "val_topic_loss") train_acc = client.get_metric_history(run.info.run_id, "train_topic_accuracy") val_acc = client.get_metric_history(run.info.run_id, "val_topic_accuracy") if train_topic: ax.plot( [m.step for m in train_topic], [m.value for m in train_topic], label="Train Loss", linewidth=2, ) if val_topic: ax.plot( [m.step for m in val_topic], [m.value for m in val_topic], label="Val Loss", linewidth=2 ) ax2 = ax.twinx() if train_acc: ax2.plot( [m.step for m in train_acc], [m.value for m in train_acc], label="Train Acc", linewidth=2, linestyle="--", alpha=0.7, ) if val_acc: ax2.plot( [m.step for m in val_acc], [m.value for m in val_acc], label="Val Acc", linewidth=2, linestyle="--", alpha=0.7, ) ax.set_title("Topic Classification", fontweight="bold") ax.set_xlabel("Epoch") ax.set_ylabel("Loss") ax2.set_ylabel("Accuracy") ax.legend(loc="upper left") ax2.legend(loc="upper right") ax.grid(True, alpha=0.3) # Summary statistics ax = axes[1, 1] ax.axis("off") # Get final metrics summary_text = "Final Metrics (Last Epoch)\n" + "=" * 35 + "\n\n" if val_topic and val_acc: summary_text += f"Topic Accuracy: {val_acc[-1].value:.1%}\n" if val_emo and val_f1: summary_text += f"Emotion F1: {val_f1[-1].value:.1%}\n" if val_sum: summary_text += f"Summarization Loss: {val_sum[-1].value:.3f}\n" ax.text(0.1, 0.5, summary_text, fontsize=12, family="monospace", verticalalignment="center") plt.tight_layout() output_path = OUTPUTS_DIR / "task_metrics.png" plt.savefig(output_path, dpi=150, bbox_inches="tight") logger.info(f"✓ Saved task metrics to {output_path}") plt.close() def plot_learning_rate(run): """Plot learning rate schedule if available.""" client = mlflow.tracking.MlflowClient() lr_metrics = client.get_metric_history(run.info.run_id, "learning_rate") fig, ax = plt.subplots(figsize=(12, 5)) if not lr_metrics: # Create placeholder ax.text( 0.5, 0.5, "No learning rate data yet\n\n(Will be logged in future training runs)", ha="center", va="center", fontsize=14, color="gray", ) ax.set_xlim(0, 1) ax.set_ylim(0, 1) else: steps = [m.step for m in lr_metrics] values = [m.value for m in lr_metrics] ax.plot(steps, values, linewidth=2, color="darkblue") # Mark warmup region warmup_steps = 1000 # From config if warmup_steps < max(steps): ax.axvline(warmup_steps, color="red", linestyle="--", alpha=0.5, label="Warmup End") ax.legend() ax.set_xlabel("Step", fontsize=12) ax.set_ylabel("Learning Rate", fontsize=12) ax.set_title("Learning Rate Schedule (Cosine with Warmup)", fontsize=14, fontweight="bold") ax.grid(True, alpha=0.3) plt.tight_layout() output_path = OUTPUTS_DIR / "learning_rate_schedule.png" plt.savefig(output_path, dpi=150, bbox_inches="tight") logger.info(f"✓ Saved LR schedule to {output_path}") plt.close() def main(): """Generate all training visualizations.""" logger.info("Loading MLflow data...") run = get_latest_run() if not run: logger.error("No training run found. Make sure training has started.") return logger.info(f"Analyzing run: {run.info.run_id}") OUTPUTS_DIR.mkdir(parents=True, exist_ok=True) logger.info("Generating visualizations...") plot_loss_curves(run) plot_task_metrics(run) plot_learning_rate(run) logger.info("\n" + "=" * 60) logger.info("✓ All visualizations saved to outputs/") logger.info("=" * 60) logger.info(" - training_loss_curve.png") logger.info(" - task_metrics.png") logger.info(" - learning_rate_schedule.png") logger.info("=" * 60) if __name__ == "__main__": main()