"""
Tests for attention mechanisms.

Run with: pytest tests/test_models/test_attention.py -v
"""

import pytest
import torch

from src.models.attention import MultiHeadAttention, ScaledDotProductAttention


class TestScaledDotProductAttention:
    """Test suite for ScaledDotProductAttention.

    Note: ScaledDotProductAttention expects 4D inputs: (batch, num_heads, seq, d_k)
    """

    def test_output_shape(self):
        """Test that output shapes are correct."""
        attention = ScaledDotProductAttention()
        batch_size, num_heads, seq_len, d_k = 2, 8, 10, 64

        Q = torch.randn(batch_size, num_heads, seq_len, d_k)
        K = torch.randn(batch_size, num_heads, seq_len, d_k)
        V = torch.randn(batch_size, num_heads, seq_len, d_k)

        output, weights = attention(Q, K, V, return_attn_weights=True)

        assert output.shape == (batch_size, num_heads, seq_len, d_k)
        assert weights.shape == (batch_size, num_heads, seq_len, seq_len)

    def test_attention_weights_sum_to_one(self):
        """Test that attention weights are a valid probability distribution."""
        attention = ScaledDotProductAttention()
        batch_size, num_heads, seq_len, d_k = 2, 4, 10, 64

        Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k)
        _, weights = attention(Q, K, V, return_attn_weights=True)

        # Each row should sum to 1 (probability distribution over keys)
        row_sums = weights.sum(dim=-1)
        assert torch.allclose(row_sums, torch.ones(batch_size, num_heads, seq_len), atol=1e-6)

    def test_masking(self):
        """Test that masking properly zeros out attention to masked positions."""
        attention = ScaledDotProductAttention()
        batch_size, num_heads, seq_len, d_k = 1, 4, 5, 64

        Q = K = V = torch.randn(batch_size, num_heads, seq_len, d_k)

        # Create mask: only attend to the first 3 key positions (4D boolean mask: batch, 1, query, key)
        mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.bool)
        mask[:, :, :, :3] = True  # Attend to first 3 key positions

        _, weights = attention(Q, K, V, mask, return_attn_weights=True)

        # Key positions 3 and 4 should have zero attention weight
        assert torch.allclose(
            weights[:, :, :, 3:], torch.zeros(batch_size, num_heads, seq_len, 2), atol=1e-6
        )
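
    # Sketch test: the output should equal the attention weights applied to V,
    # i.e. output = softmax(QK^T / sqrt(d_k)) @ V. Assumes no dropout is applied
    # to the weights by default (consistent with test_attention_weights_sum_to_one
    # above); adjust if the implementation differs.
    def test_output_is_weighted_average_of_values(self):
        """Test that the output is the attention-weighted combination of V."""
        attention = ScaledDotProductAttention()
        batch_size, num_heads, seq_len, d_k = 2, 4, 6, 32

        Q = torch.randn(batch_size, num_heads, seq_len, d_k)
        K = torch.randn(batch_size, num_heads, seq_len, d_k)
        V = torch.randn(batch_size, num_heads, seq_len, d_k)

        output, weights = attention(Q, K, V, return_attn_weights=True)

        # (batch, heads, q, k) @ (batch, heads, k, d_k) -> (batch, heads, q, d_k)
        assert torch.allclose(output, weights @ V, atol=1e-5)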

    # TODO: Add more tests as you understand the mechanism better


class TestMultiHeadAttention:
    """Test suite for MultiHeadAttention."""

    def test_output_shape(self):
        """Test that output shapes are correct."""
        d_model, num_heads = 512, 8
        batch_size, seq_len = 2, 10

        mha = MultiHeadAttention(d_model, num_heads)

        Q = K = V = torch.randn(batch_size, seq_len, d_model)
        output, attn_weights = mha(Q, K, V, return_attn_weights=True)

        assert output.shape == (batch_size, seq_len, d_model)
        assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len)

    def test_different_qkv(self):
        """Test with different Q, K, V (cross-attention scenario)."""
        d_model, num_heads = 512, 8
        batch_size = 2
        seq_len_q, seq_len_kv = 10, 20

        mha = MultiHeadAttention(d_model, num_heads)

        Q = torch.randn(batch_size, seq_len_q, d_model)
        K = torch.randn(batch_size, seq_len_kv, d_model)
        V = torch.randn(batch_size, seq_len_kv, d_model)

        output, attn_weights = mha(Q, K, V, return_attn_weights=True)

        # Output has same length as query
        assert output.shape == (batch_size, seq_len_q, d_model)
        # Attention is query_len x key_len
        assert attn_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_kv)

    def test_masking(self):
        """Test that masking works correctly."""
        d_model, num_heads = 512, 8
        batch_size, seq_len = 2, 5

        mha = MultiHeadAttention(d_model, num_heads)
        Q = K = V = torch.randn(batch_size, seq_len, d_model)

        # Mask out the last 2 key positions (3D boolean mask: batch, query, key)
        mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)
        mask[:, :, -2:] = False

        _, attn_weights = mha(Q, K, V, mask, return_attn_weights=True)

        # The last 2 key positions should receive near-zero attention weight
        assert torch.allclose(
            attn_weights[:, :, :, -2:], torch.zeros(batch_size, num_heads, seq_len, 2), atol=1e-6
        )
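
    # Sketch test: causal (autoregressive) masking. Assumes the same 3D boolean
    # mask convention as test_masking above, where True means "may attend";
    # adjust if the implementation expects a different shape or polarity.
    def test_causal_masking(self):
        """Test that a lower-triangular mask blocks attention to future positions."""
        d_model, num_heads = 512, 8
        batch_size, seq_len = 2, 5

        mha = MultiHeadAttention(d_model, num_heads)
        Q = K = V = torch.randn(batch_size, seq_len, d_model)

        # Causal mask: position i may attend only to positions <= i
        mask = torch.tril(torch.ones(batch_size, seq_len, seq_len)).bool()

        _, attn_weights = mha(Q, K, V, mask, return_attn_weights=True)

        # All strictly upper-triangular (future) entries should be near zero
        future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        assert torch.all(attn_weights[:, :, future].abs() < 1e-6)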

    def test_parameters_exist(self):
        """Test that learnable parameters are created."""
        mha = MultiHeadAttention(512, 8)

        # Should have parameters for the four linear projections (Q, K, V, output)
        param_names = [name for name, _ in mha.named_parameters()]

        assert any("W_Q" in name or "q_linear" in name.lower() for name in param_names)
        assert any("W_K" in name or "k_linear" in name.lower() for name in param_names)
        assert any("W_V" in name or "v_linear" in name.lower() for name in param_names)
        assert any("W_O" in name or "out" in name.lower() for name in param_names)

    def test_dropout_changes_output(self):
        """Test that dropout is actually applied during training."""
        torch.manual_seed(42)
        mha = MultiHeadAttention(512, 8, dropout=0.5)
        mha.train()  # Enable training mode

        Q = K = V = torch.randn(2, 10, 512)

        # Run twice with the same input: outputs should differ because dropout is stochastic in training mode
        output1, _ = mha(Q, K, V)
        output2, _ = mha(Q, K, V)

        assert not torch.allclose(output1, output2)

        # In eval mode, should be deterministic
        mha.eval()
        output3, _ = mha(Q, K, V)
        output4, _ = mha(Q, K, V)

        assert torch.allclose(output3, output4)
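
    # Sketch test: gradient flow. Assumes forward returns an (output, weights)
    # tuple as in the calls above; the check only requires that every
    # registered parameter receives a gradient from a scalar loss.
    def test_gradients_flow(self):
        """Test that backpropagation reaches all learnable parameters."""
        mha = MultiHeadAttention(512, 8)
        Q = K = V = torch.randn(2, 10, 512)

        output, _ = mha(Q, K, V)
        output.sum().backward()

        for name, param in mha.named_parameters():
            assert param.grad is not None, f"No gradient for {name}"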


if __name__ == "__main__":
    pytest.main([__file__, "-v"])