# LexiMind: tests/test_models/test_encoder_layer.py
import pytest
import torch

from src.models.encoder import TransformerEncoderLayer


def _take_tensor(output):
    """Return the tensor component regardless of (tensor, attn) tuple output."""
    if isinstance(output, tuple):  # modern layers return (output, attention)
        return output[0]
    return output
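

# The tests below assume the following TransformerEncoderLayer interface, inferred
# from how the layer is called in this file rather than from its documentation:
#   - constructor: TransformerEncoderLayer(d_model=..., num_heads=..., d_ff=..., dropout=...)
#   - forward: layer(x, mask=None) -> Tensor, or (Tensor, attention_weights)
# The _take_tensor helper above normalises both return conventions.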


def test_output_shape_and_grad():
    """
    The encoder layer should preserve the input shape (batch, seq_len, d_model)
    and gradients should flow to parameters.
    """
    d_model, num_heads, d_ff = 64, 8, 256
    batch_size, seq_len = 2, 10
    layer = TransformerEncoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0)
    x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)

    out = _take_tensor(layer(x))  # should accept mask=None by default
    assert out.shape == (batch_size, seq_len, d_model)

    # simple backward to ensure gradients propagate
    loss = out.sum()
    loss.backward()
    grads = [p.grad for p in layer.parameters() if p.requires_grad]
    assert any(g is not None for g in grads), "No gradients found on any parameter"


def test_dropout_behavior_train_vs_eval():
    """
    With dropout > 0, the outputs should differ between two forward calls in train mode
    and be identical in eval mode.
    """
    torch.manual_seed(0)
    d_model, num_heads, d_ff = 64, 8, 256
    batch_size, seq_len = 2, 10
    layer = TransformerEncoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.5)
    x = torch.randn(batch_size, seq_len, d_model)

    layer.train()
    out1 = _take_tensor(layer(x))
    out2 = _take_tensor(layer(x))
    # Training mode with dropout: outputs usually differ
    assert not torch.allclose(out1, out2), "Outputs identical in train mode despite dropout"

    layer.eval()
    out3 = _take_tensor(layer(x))
    out4 = _take_tensor(layer(x))
    # Eval mode deterministic: outputs should be identical
    assert torch.allclose(out3, out4), "Outputs differ in eval mode"
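

# A minimal extra sanity check (a sketch, assuming the same constructor and forward
# signature used above): with dropout=0.0 the layer should be deterministic even in
# train mode, which confirms that the non-determinism in the previous test really
# comes from dropout.
def test_zero_dropout_is_deterministic_in_train_mode():
    torch.manual_seed(0)
    layer = TransformerEncoderLayer(d_model=64, num_heads=8, d_ff=256, dropout=0.0)
    layer.train()
    x = torch.randn(2, 10, 64)
    out1 = _take_tensor(layer(x))
    out2 = _take_tensor(layer(x))
    # With no dropout there is no stochasticity in the forward pass, so two
    # identical calls should produce identical outputs.
    assert torch.allclose(out1, out2), "Outputs differ in train mode with dropout=0"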


def test_mask_broadcasting_accepts_3d_and_4d_mask():
    """
    The encoder layer should accept a 3D mask (batch, seq_q, seq_k) and a 4D mask
    (batch, 1, seq_q, seq_k) and handle broadcasting across heads without error.
    """
    d_model, num_heads, d_ff = 64, 8, 256
    batch_size, seq_len = 2, 7
    layer = TransformerEncoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0)
    x = torch.randn(batch_size, seq_len, d_model)

    # 3D mask: (batch, seq, seq)
    mask3 = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)
    mask3[:, :, -2:] = False  # mask out last two key positions
    out3 = _take_tensor(layer(x, mask=mask3))  # should not raise
    assert out3.shape == (batch_size, seq_len, d_model)

    # 4D mask: (batch, 1, seq, seq) already including head dim for broadcasting
    mask4 = mask3.unsqueeze(1)
    out4 = _take_tensor(layer(x, mask=mask4))
    assert out4.shape == (batch_size, seq_len, d_model)
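

# A minimal additional sketch, assuming the boolean-mask convention used above
# (True = attend, False = masked out): an all-True mask masks nothing, so the
# output should match a call with no mask at all.
def test_all_true_mask_matches_no_mask():
    torch.manual_seed(0)
    d_model, num_heads, d_ff = 64, 8, 256
    batch_size, seq_len = 2, 7
    layer = TransformerEncoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0)
    layer.eval()
    x = torch.randn(batch_size, seq_len, d_model)
    full_mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)
    out_masked = _take_tensor(layer(x, mask=full_mask))
    out_plain = _take_tensor(layer(x))
    # Under the assumed semantics an all-True mask is a no-op.
    assert torch.allclose(out_masked, out_plain, atol=1e-6), "All-True mask changed the output"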


if __name__ == "__main__":
    # Run tests interactively if needed
    pytest.main([__file__, "-q"])