import os

import matplotlib
import torch

matplotlib.use("Agg")  # use non-interactive backend
import matplotlib.pyplot as plt
import seaborn as sns

from src.models.attention import MultiHeadAttention, ScaledDotProductAttention
from src.models.positional_encoding import PositionalEncoding

OUTPUTS_DIR = "outputs"


def ensure_outputs_dir():
    os.makedirs(OUTPUTS_DIR, exist_ok=True)


def test_attention_visualization():
    """Visual test to understand attention patterns."""
    ensure_outputs_dir()
    attention = ScaledDotProductAttention()

    # A simple case: 5 random tokens (with random Q and K the weights are
    # unstructured; the point is reading the heatmap, not a specific pattern)
    batch_size = 1
    seq_len = 5
    d_k = 64

    # Create Q, K, V
    torch.manual_seed(42)
    Q = torch.randn(batch_size, seq_len, d_k)
    K = torch.randn(batch_size, seq_len, d_k)
    V = torch.eye(seq_len, d_k).unsqueeze(0)  # identity-like, so output rows mirror the weights

    # Compute attention
    _output, weights = attention(Q, K, V, return_attn_weights=True)

    # Plot attention weights
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        weights[0].detach().numpy(),
        annot=True,
        fmt=".2f",
        cmap="viridis",
        xticklabels=[f"Key {i}" for i in range(seq_len)],
        yticklabels=[f"Query {i}" for i in range(seq_len)],
    )
    plt.title("Attention Weights Heatmap")
    plt.xlabel("Keys (what we attend TO)")
    plt.ylabel("Queries (what is attending)")
    plt.tight_layout()

    save_path = os.path.join(OUTPUTS_DIR, "attention_visualization.png")
    plt.savefig(save_path)
    print(f"✅ Saved visualization to {save_path}")
    plt.close()


def test_visualize_multihead_attention():
    """
    Visual test to see what different attention heads learn.
    Creates a heatmap showing attention patterns for each head.
    """
    ensure_outputs_dir()

    # Setup
    torch.manual_seed(42)
    d_model, num_heads = 512, 8
    batch_size, seq_len = 1, 10

    mha = MultiHeadAttention(d_model, num_heads, dropout=0.0)
    mha.eval()  # no dropout for visualization

    # Create input with some structure: make nearby tokens more similar by
    # mixing each token with every other token, weighted by 1/(distance + 1).
    # Read from a fixed base copy so earlier updates don't leak into later rows.
    X_base = torch.randn(batch_size, seq_len, d_model)
    X = X_base.clone()
    for i in range(seq_len):
        for j in range(seq_len):
            distance = abs(i - j)
            X[0, i] += 0.5 * X_base[0, j] / (distance + 1)

    # Forward pass
    _output, attn_weights = mha(X, X, X, return_attn_weights=True)
    # attn_weights shape: (1, 8, 10, 10) = (batch, heads, query_pos, key_pos)
    attn_weights = attn_weights[0].detach().numpy()  # remove batch dim: (8, 10, 10)

    # Create visualization: one heatmap per head on a 2x4 grid
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    fig.suptitle("Multi-Head Attention: What Each Head Learns", fontsize=16, y=1.02)

    for head_idx in range(num_heads):
        row = head_idx // 4
        col = head_idx % 4
        ax = axes[row, col]

        # Plot attention heatmap for this head
        sns.heatmap(
            attn_weights[head_idx],
            annot=True,
            fmt=".2f",
            cmap="viridis",
            cbar=True,
            square=True,
            ax=ax,
            vmin=0,
            vmax=attn_weights[head_idx].max(),
            xticklabels=[f"K{i}" for i in range(seq_len)],
            yticklabels=[f"Q{i}" for i in range(seq_len)],
        )
        ax.set_title(f"Head {head_idx}", fontweight="bold")
        ax.set_xlabel("Keys (attend TO)")
        ax.set_ylabel("Queries (attending FROM)")

    plt.tight_layout()
    save_path = os.path.join(OUTPUTS_DIR, "multihead_attention_visualization.png")
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    print(f"✅ Saved visualization to {save_path}")
    plt.close()
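
# A minimal sanity-check sketch (not wired into __main__): it assumes
# ScaledDotProductAttention implements the standard formula
# softmax(Q @ K^T / sqrt(d_k)) @ V and, as in the tests above, that calling it
# with return_attn_weights=True returns (output, weights). If those
# assumptions hold, the module should match this plain-tensor reference.
def sanity_check_scaled_dot_product():
    torch.manual_seed(0)
    d_k = 64
    Q = torch.randn(1, 5, d_k)
    K = torch.randn(1, 5, d_k)
    V = torch.randn(1, 5, d_k)

    # Reference: scores scaled by sqrt(d_k), softmax over keys, weighted sum of V
    scores = Q @ K.transpose(-2, -1) / d_k**0.5  # (1, 5, 5)
    ref_weights = torch.softmax(scores, dim=-1)
    ref_output = ref_weights @ V

    output, weights = ScaledDotProductAttention()(Q, K, V, return_attn_weights=True)
    assert torch.allclose(weights, ref_weights, atol=1e-5)
    assert torch.allclose(output, ref_output, atol=1e-5)
    print("✅ ScaledDotProductAttention matches the reference formula")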
""" ensure_outputs_dir() torch.manual_seed(42) seq_len, d_model = 8, 512 X = torch.randn(1, seq_len, d_model) # Test with 1 head vs 8 heads mha_1head = MultiHeadAttention(d_model, num_heads=1, dropout=0.0) mha_8heads = MultiHeadAttention(d_model, num_heads=8, dropout=0.0) mha_1head.eval() mha_8heads.eval() _, attn_1head = mha_1head(X, X, X, return_attn_weights=True) _, attn_8heads = mha_8heads(X, X, X, return_attn_weights=True) # Plot comparison fig, axes = plt.subplots(1, 2, figsize=(12, 5)) # Single head sns.heatmap( attn_1head[0, 0].detach().numpy(), annot=True, fmt=".2f", cmap="viridis", cbar=True, square=True, ax=axes[0], ) axes[0].set_title("Single-Head Attention\n(Limited expressiveness)", fontweight="bold") axes[0].set_xlabel("Keys") axes[0].set_ylabel("Queries") # Multi-head average avg_attn = attn_8heads[0].mean(dim=0).detach().numpy() sns.heatmap(avg_attn, annot=True, fmt=".2f", cmap="viridis", cbar=True, square=True, ax=axes[1]) axes[1].set_title("8-Head Attention (Average)\n(Richer patterns)", fontweight="bold") axes[1].set_xlabel("Keys") axes[1].set_ylabel("Queries") plt.tight_layout() save_path = os.path.join(OUTPUTS_DIR, "single_vs_multihead.png") plt.savefig(save_path, dpi=150, bbox_inches="tight") print(f"✅ Saved comparison to {save_path}") plt.close() def test_visualize_positional_encoding(): """ Visualize the positional encoding pattern. Creates heatmap showing encoding values. """ ensure_outputs_dir() pos_enc = PositionalEncoding(d_model=128, max_len=100, dropout=0.0) # Get encoding matrix pe = pos_enc.pe.squeeze(0).numpy() # (max_len, d_model) # Plot first 50 positions and 64 dimensions plt.figure(figsize=(12, 8)) sns.heatmap( pe[:50, :64].T, cmap="RdBu_r", center=0, xticklabels=5, yticklabels=8, cbar_kws={"label": "Encoding Value"}, ) plt.xlabel("Position in Sequence") plt.ylabel("Embedding Dimension") plt.title("Positional Encoding Pattern\n(Notice the wave patterns with different frequencies)") plt.tight_layout() save_path = os.path.join(OUTPUTS_DIR, "positional_encoding_heatmap.png") plt.savefig(save_path, dpi=150) print(f"✅ Saved to {save_path}") plt.close() if __name__ == "__main__": test_attention_visualization() test_visualize_multihead_attention() test_compare_single_vs_multihead() test_visualize_positional_encoding()