Spaces:
Running
Running
File size: 2,211 Bytes
5a20c96 1fbc47b d18b34d 1fbc47b 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 d18b34d 5a20c96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
# tests/test_models/test_positional_encoding.py
"""
Tests for positional encoding.
"""
import matplotlib
import torch
matplotlib.use("Agg") # use non-interactive backend for test environments
from src.models.positional_encoding import PositionalEncoding
class TestPositionalEncoding:
    """Tests for the PositionalEncoding module: shapes, dropout, and encoding table."""

    def test_output_shape(self):
        """The module must return a tensor with exactly the input's shape."""
        model_dim, table_len = 512, 5000
        batch, length = 2, 100
        module = PositionalEncoding(model_dim, table_len, dropout=0.0)
        inputs = torch.randn(batch, length, model_dim)
        expected = (batch, length, model_dim)
        assert module(inputs).shape == expected

    def test_different_sequence_lengths(self):
        """Shapes must be preserved for several sequence lengths below max_len."""
        model_dim = 256
        module = PositionalEncoding(d_model=model_dim, max_len=1000, dropout=0.0)
        for length in (10, 50, 100, 500):
            result = module(torch.randn(1, length, model_dim))
            assert result.shape == (1, length, model_dim)

    def test_dropout_changes_output(self):
        """Training-mode output should vary between calls; eval-mode output should not."""
        torch.manual_seed(42)
        module = PositionalEncoding(d_model=128, dropout=0.5)
        module.train(True)
        inputs = torch.randn(2, 10, 128)

        # Each training-mode pass draws a fresh dropout mask, so two passes
        # over the same input are expected to disagree.
        first, second = module(inputs), module(inputs)
        assert not torch.allclose(first, second)

        # With training disabled (equivalent to .eval()), repeated passes must agree.
        module.train(False)
        first, second = module(inputs), module(inputs)
        assert torch.allclose(first, second)

    def test_encoding_properties(self):
        """Sanity-check the precomputed encoding table stored on the module."""
        module = PositionalEncoding(d_model=128, max_len=100, dropout=0.0)
        # Take index 0 of the ``pe`` buffer; presumably this drops a leading
        # batch axis, i.e. pe is (1, max_len, d_model). TODO confirm.
        table = module.pe[0]

        # The encoding is expected to be sin/cos-based, so every entry must
        # lie in [-1, 1].
        within_bounds = (table >= -1) & (table <= 1)
        assert within_bounds.all()

        # Position 0 must differ from both a neighbouring and a distant position.
        for position in (1, 50):
            assert not torch.allclose(table[0], table[position])
|