| """ | |
| Tests for attention mechanisms. | |
| Run with: pytest tests/test_models/test_attention.py -v | |
| """ | |
| import pytest | |
| import torch | |
| from src.models.attention import MultiHeadAttention, ScaledDotProductAttention | |
| class TestScaledDotProductAttention: | |
| """Test suite for ScaledDotProductAttention. | |
| Note: ScaledDotProductAttention expects 4D inputs: (batch, num_heads, seq, d_k) | |
| """ | |
| def test_output_shape(self): | |
| """Test that output shapes are correct.""" | |
| attention = ScaledDotProductAttention() | |
| batch_size, num_heads, seq_len, d_k = 2, 8, 10, 64 | |
| Q = torch.randn(batch_size, num_heads, seq_len, d_k) | |
| K = torch.randn(batch_size, num_heads, seq_len, d_k) | |
| V = torch.randn(batch_size, num_heads, seq_len, d_k) | |
| output, weights = attention(Q, K, V, return_attn_weights=True) | |
| assert output.shape == (batch_size, num_heads, seq_len, d_k) | |
| assert weights.shape == (batch_size, num_heads, seq_len, seq_len) | |
| def test_attention_weights_sum_to_one(self): | |
| """Test that attention weights are a valid probability distribution.""" | |
| attention = ScaledDotProductAttention() | |
| batch_size, num_heads, seq_len, d_k = 2, 4, 10, 64 | |
| Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k) | |
| _, weights = attention(Q, K, V, return_attn_weights=True) | |
| # Each row should sum to 1 (probability distribution over keys) | |
| row_sums = weights.sum(dim=-1) | |
| assert torch.allclose(row_sums, torch.ones(batch_size, num_heads, seq_len), atol=1e-6) | |
| def test_masking(self): | |
| """Test that masking properly zeros out attention to masked positions.""" | |
| attention = ScaledDotProductAttention() | |
| batch_size, num_heads, seq_len, d_k = 1, 4, 5, 64 | |
| Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k) | |
| # Create mask: only attend to first 3 positions (4D mask) | |
| mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.bool) | |
| mask[:, :, :, :3] = True # Attend to first 3 key positions | |
| _, weights = attention(Q, K, V, mask, return_attn_weights=True) | |
| # Key positions 3 and 4 should have zero attention weight | |
| assert torch.allclose( | |
| weights[:, :, :, 3:], torch.zeros(batch_size, num_heads, seq_len, 2), atol=1e-6 | |
| ) | |
| # TODO: Add more tests as you understand the mechanism better | |
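
    # Sketch of one possible additional test (an assumption, not part of the
    # original suite): it reuses the call signature and the mask convention
    # (True = position may be attended) exercised in test_masking above.
    def test_causal_mask(self):
        """A causal (lower-triangular) mask should zero attention to future keys."""
        attention = ScaledDotProductAttention()
        batch_size, num_heads, seq_len, d_k = 1, 2, 6, 32

        Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k)

        # Lower-triangular mask: query position i may attend to keys 0..i
        mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        mask = mask.view(1, 1, seq_len, seq_len)

        _, weights = attention(Q, K, V, mask, return_attn_weights=True)

        # Strictly-future key positions should receive ~zero weight
        future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        assert torch.allclose(
            weights[..., future], torch.zeros_like(weights[..., future]), atol=1e-6
        )
        # Each row should still be a valid probability distribution
        assert torch.allclose(
            weights.sum(dim=-1), torch.ones(batch_size, num_heads, seq_len), atol=1e-6
        )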


class TestMultiHeadAttention:
    """Test suite for MultiHeadAttention."""

    def test_output_shape(self):
        """Test that output shapes are correct."""
        d_model, num_heads = 512, 8
        batch_size, seq_len = 2, 10

        mha = MultiHeadAttention(d_model, num_heads)
        Q = K = V = torch.randn(batch_size, seq_len, d_model)

        output, attn_weights = mha(Q, K, V, return_attn_weights=True)

        assert output.shape == (batch_size, seq_len, d_model)
        assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len)

    def test_different_qkv(self):
        """Test with different Q, K, V (cross-attention scenario)."""
        d_model, num_heads = 512, 8
        batch_size = 2
        seq_len_q, seq_len_kv = 10, 20

        mha = MultiHeadAttention(d_model, num_heads)
        Q = torch.randn(batch_size, seq_len_q, d_model)
        K = torch.randn(batch_size, seq_len_kv, d_model)
        V = torch.randn(batch_size, seq_len_kv, d_model)

        output, attn_weights = mha(Q, K, V, return_attn_weights=True)

        # Output has same length as the query
        assert output.shape == (batch_size, seq_len_q, d_model)
        # Attention is query_len x key_len
        assert attn_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_kv)

    def test_masking(self):
        """Test that masking works correctly."""
        d_model, num_heads = 512, 8
        batch_size, seq_len = 2, 5

        mha = MultiHeadAttention(d_model, num_heads)
        Q = K = V = torch.randn(batch_size, seq_len, d_model)

        # Mask out the last 2 key positions (3D mask, broadcast across heads)
        mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)
        mask[:, :, -2:] = False

        _, attn_weights = mha(Q, K, V, mask, return_attn_weights=True)

        # The last 2 key positions should have near-zero attention
        assert torch.allclose(
            attn_weights[:, :, :, -2:], torch.zeros(batch_size, num_heads, seq_len, 2), atol=1e-6
        )

    def test_parameters_exist(self):
        """Test that learnable parameters are created."""
        mha = MultiHeadAttention(512, 8)

        # Should have 4 linear layers' worth of parameters (Q, K, V, and output projections)
        param_names = [name for name, _ in mha.named_parameters()]
        assert any("W_Q" in name or "q_linear" in name.lower() for name in param_names)
        assert any("W_K" in name or "k_linear" in name.lower() for name in param_names)
        assert any("W_V" in name or "v_linear" in name.lower() for name in param_names)
        assert any("W_O" in name or "out" in name.lower() for name in param_names)

    def test_dropout_changes_output(self):
        """Test that dropout is actually applied during training."""
        torch.manual_seed(42)
        mha = MultiHeadAttention(512, 8, dropout=0.5)
        mha.train()  # Enable training mode

        Q = K = V = torch.randn(2, 10, 512)

        # Run twice with the same input - outputs should differ due to dropout
        output1, _ = mha(Q, K, V, return_attn_weights=True)
        output2, _ = mha(Q, K, V, return_attn_weights=True)
        assert not torch.allclose(output1, output2)

        # In eval mode, the output should be deterministic
        mha.eval()
        output3, _ = mha(Q, K, V, return_attn_weights=True)
        output4, _ = mha(Q, K, V, return_attn_weights=True)
        assert torch.allclose(output3, output4)
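
    # Sketch of one possible additional check (an assumption, not part of the
    # original suite): verifies that gradients reach every learnable parameter,
    # using the same (output, attn_weights) return convention as the tests above.
    def test_gradients_flow_to_all_parameters(self):
        """A simple backward pass should populate .grad for every parameter."""
        mha = MultiHeadAttention(512, 8)
        Q = K = V = torch.randn(2, 10, 512)

        output, _ = mha(Q, K, V, return_attn_weights=True)
        output.sum().backward()

        for name, param in mha.named_parameters():
            assert param.grad is not None, f"No gradient for {name}"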


if __name__ == "__main__":
    pytest.main([__file__, "-v"])