# tests/test_models/test_positional_encoding.py

"""
Tests for positional encoding.
"""

import matplotlib
import torch

matplotlib.use("Agg")  # non-interactive backend; must be set before importing modules that may pull in pyplot
from src.models.positional_encoding import PositionalEncoding


class TestPositionalEncoding:
    """Test suite for PositionalEncoding."""

    def test_output_shape(self):
        """Test that output shape matches input shape."""
        d_model, max_len = 512, 5000
        batch_size, seq_len = 2, 100

        pos_enc = PositionalEncoding(d_model, max_len, dropout=0.0)
        x = torch.randn(batch_size, seq_len, d_model)

        output = pos_enc(x)
        assert output.shape == (batch_size, seq_len, d_model)

    def test_different_sequence_lengths(self):
        """Test with various sequence lengths."""
        pos_enc = PositionalEncoding(d_model=256, max_len=1000, dropout=0.0)

        for seq_len in [10, 50, 100, 500]:
            x = torch.randn(1, seq_len, 256)
            output = pos_enc(x)
            assert output.shape == (1, seq_len, 256)

    def test_dropout_changes_output(self):
        """Test that dropout is applied during training."""
        torch.manual_seed(42)
        pos_enc = PositionalEncoding(d_model=128, dropout=0.5)
        pos_enc.train()

        x = torch.randn(2, 10, 128)

        output1 = pos_enc(x)
        output2 = pos_enc(x)

        # Should be different due to dropout
        assert not torch.allclose(output1, output2)

        # In eval mode, should be deterministic
        pos_enc.eval()
        output3 = pos_enc(x)
        output4 = pos_enc(x)
        assert torch.allclose(output3, output4)
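
    def test_additive_in_eval_mode(self):
        """Sketch: with dropout disabled, output should equal input plus
        the stored encoding. Assumes the module uses the usual additive
        design (output = x + pe slice); adjust if it differs.
        """
        pos_enc = PositionalEncoding(d_model=64, max_len=100, dropout=0.0)
        pos_enc.eval()

        x = torch.randn(3, 20, 64)
        output = pos_enc(x)

        # pe is stored with a leading batch dimension (see test_encoding_properties)
        expected = x + pos_enc.pe[:, :20]
        assert torch.allclose(output, expected)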

    def test_encoding_properties(self):
        """Test mathematical properties of encoding."""
        pos_enc = PositionalEncoding(d_model=128, max_len=100, dropout=0.0)

        # Get the raw encoding (without dropout)
        pe = pos_enc.pe[0]  # Remove batch dimension

        # All entries should lie in [-1, 1] (sin/cos range)
        assert (pe >= -1).all() and (pe <= 1).all()

        # Different positions should have different encodings
        assert not torch.allclose(pe[0], pe[1])
        assert not torch.allclose(pe[0], pe[50])
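
    def test_sinusoidal_formula(self):
        """Sketch: check stored values against the canonical sin/cos formula.

        Assumes the implementation follows the standard encoding from
        "Attention Is All You Need": PE[pos, 2i] = sin(pos / 10000^(2i/d)),
        PE[pos, 2i+1] = cos(pos / 10000^(2i/d)). Adjust or drop this test
        if the module uses a different convention.
        """
        d_model, max_len = 64, 50
        pos_enc = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=0.0)
        pe = pos_enc.pe[0]  # (max_len, d_model), as in test_encoding_properties

        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        i = torch.arange(0, d_model, 2, dtype=torch.float32)
        angles = position / (10000.0 ** (i / d_model))  # (max_len, d_model // 2)

        expected = torch.zeros(max_len, d_model)
        expected[:, 0::2] = torch.sin(angles)
        expected[:, 1::2] = torch.cos(angles)

        # Loose tolerance: implementations often compute the frequencies
        # via exp/log, which differs from direct powers by float error.
        assert torch.allclose(pe, expected, atol=1e-5)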