"""
Tests for attention mechanisms.
Run with: pytest tests/test_models/test_attention.py -v
"""
import pytest
import torch
from src.models.attention import MultiHeadAttention, ScaledDotProductAttention
class TestScaledDotProductAttention:
    """Test suite for ScaledDotProductAttention.

    Note: ScaledDotProductAttention expects 4D inputs: (batch, num_heads, seq, d_k)
    """

    def test_output_shape(self):
        """Test that output shapes are correct."""
        attention = ScaledDotProductAttention()
        batch_size, num_heads, seq_len, d_k = 2, 8, 10, 64
        Q = torch.randn(batch_size, num_heads, seq_len, d_k)
        K = torch.randn(batch_size, num_heads, seq_len, d_k)
        V = torch.randn(batch_size, num_heads, seq_len, d_k)
        output, weights = attention(Q, K, V, return_attn_weights=True)
        assert output.shape == (batch_size, num_heads, seq_len, d_k)
        assert weights.shape == (batch_size, num_heads, seq_len, seq_len)

    def test_attention_weights_sum_to_one(self):
        """Test that attention weights are a valid probability distribution."""
        attention = ScaledDotProductAttention()
        batch_size, num_heads, seq_len, d_k = 2, 4, 10, 64
        Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k)
        _, weights = attention(Q, K, V, return_attn_weights=True)
        # Each row should sum to 1 (probability distribution over keys)
        row_sums = weights.sum(dim=-1)
        assert torch.allclose(row_sums, torch.ones(batch_size, num_heads, seq_len), atol=1e-6)

    def test_masking(self):
        """Test that masking properly zeros out attention to masked positions."""
        attention = ScaledDotProductAttention()
        batch_size, num_heads, seq_len, d_k = 1, 4, 5, 64
        Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k)
        # Create mask: only attend to first 3 positions (4D mask)
        mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.bool)
        mask[:, :, :, :3] = True  # Attend to first 3 key positions
        _, weights = attention(Q, K, V, mask, return_attn_weights=True)
        # Key positions 3 and 4 should have zero attention weight
        assert torch.allclose(
            weights[:, :, :, 3:], torch.zeros(batch_size, num_heads, seq_len, 2), atol=1e-6
        )

    # TODO: Add more tests as you understand the mechanism better
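
    # A hedged sketch of one possible extra test, assuming the same call signature
    # used above: when every key vector is identical, all scores in a row are equal,
    # so softmax should give a uniform distribution (1 / seq_len) over key positions.
    def test_uniform_weights_for_identical_keys(self):
        """Identical keys should yield uniform attention weights."""
        attention = ScaledDotProductAttention()
        batch_size, num_heads, seq_len, d_k = 2, 4, 6, 64
        Q = torch.randn(batch_size, num_heads, seq_len, d_k)
        # Repeat a single key vector across all key positions
        K = torch.randn(batch_size, num_heads, 1, d_k).expand(-1, -1, seq_len, -1)
        V = torch.randn(batch_size, num_heads, seq_len, d_k)
        _, weights = attention(Q, K, V, return_attn_weights=True)
        expected = torch.full((batch_size, num_heads, seq_len, seq_len), 1.0 / seq_len)
        assert torch.allclose(weights, expected, atol=1e-6)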
class TestMultiHeadAttention:
    """Test suite for MultiHeadAttention."""

    def test_output_shape(self):
        """Test that output shapes are correct."""
        d_model, num_heads = 512, 8
        batch_size, seq_len = 2, 10
        mha = MultiHeadAttention(d_model, num_heads)
        Q = K = V = torch.randn(batch_size, seq_len, d_model)
        output, attn_weights = mha(Q, K, V, return_attn_weights=True)
        assert output.shape == (batch_size, seq_len, d_model)
        assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len)

    def test_different_qkv(self):
        """Test with different Q, K, V (cross-attention scenario)."""
        d_model, num_heads = 512, 8
        batch_size = 2
        seq_len_q, seq_len_kv = 10, 20
        mha = MultiHeadAttention(d_model, num_heads)
        Q = torch.randn(batch_size, seq_len_q, d_model)
        K = torch.randn(batch_size, seq_len_kv, d_model)
        V = torch.randn(batch_size, seq_len_kv, d_model)
        output, attn_weights = mha(Q, K, V, return_attn_weights=True)
        # Output has same length as query
        assert output.shape == (batch_size, seq_len_q, d_model)
        # Attention is query_len x key_len
        assert attn_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_kv)

    def test_masking(self):
        """Test that masking works correctly."""
        d_model, num_heads = 512, 8
        batch_size, seq_len = 2, 5
        mha = MultiHeadAttention(d_model, num_heads)
        Q = K = V = torch.randn(batch_size, seq_len, d_model)
        # Mask out last 2 positions
        mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)
        mask[:, :, -2:] = False
        _, attn_weights = mha(Q, K, V, mask, return_attn_weights=True)
        # Last 2 positions should have near-zero attention
        assert torch.allclose(
            attn_weights[:, :, :, -2:], torch.zeros(batch_size, num_heads, seq_len, 2), atol=1e-6
        )

    def test_parameters_exist(self):
        """Test that learnable parameters are created."""
        mha = MultiHeadAttention(512, 8)
        # Should have 4 linear layers worth of parameters
        param_names = [name for name, _ in mha.named_parameters()]
        assert any("W_Q" in name or "q_linear" in name.lower() for name in param_names)
        assert any("W_K" in name or "k_linear" in name.lower() for name in param_names)
        assert any("W_V" in name or "v_linear" in name.lower() for name in param_names)
        assert any("W_O" in name or "out" in name.lower() for name in param_names)
    def test_dropout_changes_output(self):
        """Test that dropout is actually applied during training."""
        torch.manual_seed(42)
        mha = MultiHeadAttention(512, 8, dropout=0.5)
        mha.train()  # Enable training mode
        Q = K = V = torch.randn(2, 10, 512)
        # Run twice with the same input - should get different outputs due to dropout
        output1, _ = mha(Q, K, V, return_attn_weights=True)
        output2, _ = mha(Q, K, V, return_attn_weights=True)
        assert not torch.allclose(output1, output2)
        # In eval mode, dropout is disabled, so the output should be deterministic
        mha.eval()
        output3, _ = mha(Q, K, V, return_attn_weights=True)
        output4, _ = mha(Q, K, V, return_attn_weights=True)
        assert torch.allclose(output3, output4)
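
    # A hedged sketch of an additional check, assuming the Q/K/V/output projections
    # are ordinary nn.Module parameters used in forward(): a backward pass from the
    # output should populate a gradient for every learnable parameter.
    def test_gradients_flow_to_all_parameters(self):
        """Every learnable parameter should receive a gradient after backward()."""
        mha = MultiHeadAttention(512, 8)
        Q = K = V = torch.randn(2, 10, 512)
        output, _ = mha(Q, K, V, return_attn_weights=True)
        output.sum().backward()
        for name, param in mha.named_parameters():
            assert param.grad is not None, f"No gradient for parameter {name}"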
if __name__ == "__main__":
    pytest.main([__file__, "-v"])