"""
Tests for attention mechanisms.
Run with: pytest tests/test_models/test_attention.py -v
"""
import pytest
import torch
from src.models.attention import MultiHeadAttention, ScaledDotProductAttention
class TestScaledDotProductAttention:
"""Test suite for ScaledDotProductAttention.
Note: ScaledDotProductAttention expects 4D inputs: (batch, num_heads, seq, d_k)
"""

    def test_output_shape(self):
"""Test that output shapes are correct."""
attention = ScaledDotProductAttention()
batch_size, num_heads, seq_len, d_k = 2, 8, 10, 64
Q = torch.randn(batch_size, num_heads, seq_len, d_k)
K = torch.randn(batch_size, num_heads, seq_len, d_k)
V = torch.randn(batch_size, num_heads, seq_len, d_k)
output, weights = attention(Q, K, V, return_attn_weights=True)
assert output.shape == (batch_size, num_heads, seq_len, d_k)
assert weights.shape == (batch_size, num_heads, seq_len, seq_len)

    def test_attention_weights_sum_to_one(self):
"""Test that attention weights are a valid probability distribution."""
attention = ScaledDotProductAttention()
batch_size, num_heads, seq_len, d_k = 2, 4, 10, 64
Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k)
_, weights = attention(Q, K, V, return_attn_weights=True)
# Each row should sum to 1 (probability distribution over keys)
row_sums = weights.sum(dim=-1)
assert torch.allclose(row_sums, torch.ones(batch_size, num_heads, seq_len), atol=1e-6)
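
    # Sketch of an additional property check (assumes the softmax(QK^T / sqrt(d_k))
    # forward exercised above): all-zero queries give all-zero scores, so the
    # attention distribution over keys should be uniform.
    def test_uniform_weights_for_zero_queries(self):
        """Zero queries should yield uniform attention over all key positions."""
        attention = ScaledDotProductAttention()
        batch_size, num_heads, seq_len, d_k = 2, 4, 6, 64
        Q = torch.zeros(batch_size, num_heads, seq_len, d_k)
        K = V = torch.randn(batch_size, num_heads, seq_len, d_k)
        _, weights = attention(Q, K, V, return_attn_weights=True)
        expected = torch.full((batch_size, num_heads, seq_len, seq_len), 1.0 / seq_len)
        assert torch.allclose(weights, expected, atol=1e-6)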

    def test_masking(self):
"""Test that masking properly zeros out attention to masked positions."""
attention = ScaledDotProductAttention()
batch_size, num_heads, seq_len, d_k = 1, 4, 5, 64
Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k)
# Create mask: only attend to first 3 positions (4D mask)
mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.bool)
mask[:, :, :, :3] = True # Attend to first 3 key positions
_, weights = attention(Q, K, V, mask, return_attn_weights=True)
# Key positions 3 and 4 should have zero attention weight
assert torch.allclose(
weights[:, :, :, 3:], torch.zeros(batch_size, num_heads, seq_len, 2), atol=1e-6
)
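
    # Sketch of a causal-attention check (assumes the same mask convention as
    # test_masking above, i.e. True = the key position may be attended to):
    # a lower-triangular mask should block all attention to future positions.
    def test_causal_masking(self):
        """A lower-triangular (causal) mask should zero out future key positions."""
        attention = ScaledDotProductAttention()
        batch_size, num_heads, seq_len, d_k = 1, 2, 6, 32
        Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k)
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        mask = causal.expand(batch_size, 1, seq_len, seq_len)  # broadcast over heads
        _, weights = attention(Q, K, V, mask, return_attn_weights=True)
        # Every entry strictly above the diagonal must be (near-)zero
        future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        assert torch.allclose(weights[..., future], torch.zeros_like(weights[..., future]), atol=1e-6)
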
# TODO: Add more tests as you understand the mechanism better


class TestMultiHeadAttention:
"""Test suite for MultiHeadAttention."""

    def test_output_shape(self):
"""Test that output shapes are correct."""
d_model, num_heads = 512, 8
batch_size, seq_len = 2, 10
mha = MultiHeadAttention(d_model, num_heads)
Q = K = V = torch.randn(batch_size, seq_len, d_model)
output, attn_weights = mha(Q, K, V, return_attn_weights=True)
assert output.shape == (batch_size, seq_len, d_model)
assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len)

    def test_different_qkv(self):
"""Test with different Q, K, V (cross-attention scenario)."""
d_model, num_heads = 512, 8
batch_size = 2
seq_len_q, seq_len_kv = 10, 20
mha = MultiHeadAttention(d_model, num_heads)
Q = torch.randn(batch_size, seq_len_q, d_model)
K = torch.randn(batch_size, seq_len_kv, d_model)
V = torch.randn(batch_size, seq_len_kv, d_model)
output, attn_weights = mha(Q, K, V, return_attn_weights=True)
# Output has same length as query
assert output.shape == (batch_size, seq_len_q, d_model)
# Attention is query_len x key_len
assert attn_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_kv)
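
    # Sketch mirroring the ScaledDotProductAttention check above: even after the
    # multi-head projections, each head's attention row should still be a
    # probability distribution (run in eval mode so any dropout is disabled).
    def test_attention_weights_sum_to_one(self):
        """Per-head attention rows should sum to 1 when no mask is applied."""
        d_model, num_heads = 512, 8
        batch_size, seq_len = 2, 7
        mha = MultiHeadAttention(d_model, num_heads)
        mha.eval()
        x = torch.randn(batch_size, seq_len, d_model)
        _, attn_weights = mha(x, x, x, return_attn_weights=True)
        row_sums = attn_weights.sum(dim=-1)
        assert torch.allclose(row_sums, torch.ones(batch_size, num_heads, seq_len), atol=1e-6)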

    def test_masking(self):
"""Test that masking works correctly."""
d_model, num_heads = 512, 8
batch_size, seq_len = 2, 5
mha = MultiHeadAttention(d_model, num_heads)
Q = K = V = torch.randn(batch_size, seq_len, d_model)
        # Mask out the last 2 key positions (3D mask: batch, query_len, key_len)
mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)
mask[:, :, -2:] = False
_, attn_weights = mha(Q, K, V, mask, return_attn_weights=True)
# Last 2 positions should have near-zero attention
assert torch.allclose(
attn_weights[:, :, :, -2:], torch.zeros(batch_size, num_heads, seq_len, 2), atol=1e-6
)
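
    # Sketch of a batch-independence check (assumes attention mixes information
    # across positions but not across batch elements): in eval mode, a batch of
    # two should match running the first element on its own.
    def test_batch_independence(self):
        """Each batch element should be processed independently of the others."""
        torch.manual_seed(0)
        mha = MultiHeadAttention(512, 8)
        mha.eval()
        x = torch.randn(2, 10, 512)
        batched, _ = mha(x, x, x, return_attn_weights=True)
        single, _ = mha(x[:1], x[:1], x[:1], return_attn_weights=True)
        assert torch.allclose(batched[:1], single, atol=1e-5)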

    def test_parameters_exist(self):
"""Test that learnable parameters are created."""
mha = MultiHeadAttention(512, 8)
# Should have 4 linear layers worth of parameters
param_names = [name for name, _ in mha.named_parameters()]
assert any("W_Q" in name or "q_linear" in name.lower() for name in param_names)
assert any("W_K" in name or "k_linear" in name.lower() for name in param_names)
assert any("W_V" in name or "v_linear" in name.lower() for name in param_names)
assert any("W_O" in name or "out" in name.lower() for name in param_names)

    def test_dropout_changes_output(self):
"""Test that dropout is actually applied during training."""
torch.manual_seed(42)
mha = MultiHeadAttention(512, 8, dropout=0.5)
mha.train() # Enable training mode
Q = K = V = torch.randn(2, 10, 512)
# Run twice with same input - should get different outputs due to dropout
        output1, _ = mha(Q, K, V, return_attn_weights=True)
        output2, _ = mha(Q, K, V, return_attn_weights=True)
assert not torch.allclose(output1, output2)
# In eval mode, should be deterministic
mha.eval()
        output3, _ = mha(Q, K, V, return_attn_weights=True)
        output4, _ = mha(Q, K, V, return_attn_weights=True)
assert torch.allclose(output3, output4)
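
    # Sketch of a gradient-flow check (makes no assumption about how the internal
    # projections are named): a backward pass from the output should reach the
    # query, key, and value inputs.
    def test_gradients_flow_to_inputs(self):
        """A backward pass should populate gradients on Q, K, and V."""
        mha = MultiHeadAttention(512, 8)
        mha.eval()
        Q = torch.randn(2, 10, 512, requires_grad=True)
        K = torch.randn(2, 10, 512, requires_grad=True)
        V = torch.randn(2, 10, 512, requires_grad=True)
        output, _ = mha(Q, K, V, return_attn_weights=True)
        output.sum().backward()
        assert Q.grad is not None and K.grad is not None and V.grad is not None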


if __name__ == "__main__":
pytest.main([__file__, "-v"])