import torch
import torch.nn as nn
from src.models.heads import (
ClassificationHead,
LMHead,
ProjectionHead,
TokenClassificationHead,
)
def test_classification_head_shapes_and_dropout():
torch.manual_seed(0)
d_model = 64
num_labels = 5
batch_size = 3
seq_len = 10
head = ClassificationHead(d_model=d_model, num_labels=num_labels, pooler="mean", dropout=0.5)
head.train()
x = torch.randn(batch_size, seq_len, d_model)
out1 = head(x)
out2 = head(x)
# With dropout in train mode, outputs should usually differ
assert out1.shape == (batch_size, num_labels)
assert out2.shape == (batch_size, num_labels)
assert not torch.allclose(out1, out2)
head.eval()
out3 = head(x)
out4 = head(x)
assert torch.allclose(out3, out4), "Eval mode should be deterministic"
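

# Sketch (not part of the original suite): assuming ClassificationHead's
# [batch, num_labels] output is meant to feed a standard classification loss,
# the logits can be passed straight to nn.CrossEntropyLoss. Constructor
# arguments and shapes mirror the test above; the loss wiring is illustrative.
def test_classification_head_cross_entropy_sketch():
    torch.manual_seed(0)
    head = ClassificationHead(d_model=64, num_labels=5, pooler="mean", dropout=0.0)
    head.eval()
    x = torch.randn(3, 10, 64)
    labels = torch.randint(0, 5, (3,))
    loss = nn.CrossEntropyLoss()(head(x), labels)
    assert loss.ndim == 0
    assert torch.isfinite(loss)
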
def test_token_classification_head_shapes_and_grads():
torch.manual_seed(1)
d_model = 48
num_labels = 7
batch_size = 2
seq_len = 6
head = TokenClassificationHead(d_model=d_model, num_labels=num_labels, dropout=0.0)
x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
out = head(x)
assert out.shape == (batch_size, seq_len, num_labels)
loss = out.sum()
loss.backward()
    grads = [p.grad for p in head.parameters() if p.requires_grad]
assert any(g is not None for g in grads)
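

# Sketch (not part of the original suite): per-token cross-entropy for
# TokenClassificationHead, flattening the [batch, seq, num_labels] logits to
# [batch * seq, num_labels]. The flattening convention is an assumption about
# how these logits would be consumed, not taken from the repository.
def test_token_classification_head_loss_sketch():
    torch.manual_seed(1)
    head = TokenClassificationHead(d_model=48, num_labels=7, dropout=0.0)
    x = torch.randn(2, 6, 48)
    labels = torch.randint(0, 7, (2, 6))
    logits = head(x)
    loss = nn.CrossEntropyLoss()(logits.reshape(-1, 7), labels.reshape(-1))
    assert torch.isfinite(loss)
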
def test_lm_head_tie_weights_and_shapes():
torch.manual_seed(2)
vocab_size = 50
d_model = 32
batch_size = 2
seq_len = 4
embedding = nn.Embedding(vocab_size, d_model)
lm_tied = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=embedding)
lm_untied = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=None)
hidden = torch.randn(batch_size, seq_len, d_model)
# Shapes
logits_tied = lm_tied(hidden)
logits_untied = lm_untied(hidden)
assert logits_tied.shape == (batch_size, seq_len, vocab_size)
assert logits_untied.shape == (batch_size, seq_len, vocab_size)
# Weight tying: projection weight should be the same object as embedding.weight
assert lm_tied.proj.weight is embedding.weight
# Grad flows through tied weights
loss = logits_tied.sum()
loss.backward()
assert embedding.weight.grad is not None
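

# Sketch (not part of the original suite): a next-token language-modelling loss
# through the tied LMHead. The shift of logits against targets is illustrative
# of how [batch, seq, vocab] logits are typically consumed; it is not the
# repository's training code.
def test_lm_head_next_token_loss_sketch():
    torch.manual_seed(2)
    vocab_size, d_model = 50, 32
    embedding = nn.Embedding(vocab_size, d_model)
    lm = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=embedding)
    hidden = torch.randn(2, 4, d_model)
    tokens = torch.randint(0, vocab_size, (2, 4))
    logits = lm(hidden)
    # Predict token t+1 from position t: drop the last logit and the first target.
    loss = nn.CrossEntropyLoss()(
        logits[:, :-1].reshape(-1, vocab_size), tokens[:, 1:].reshape(-1)
    )
    loss.backward()
    # Because the projection is tied, the loss should reach the embedding table.
    assert embedding.weight.grad is not None
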
def test_projection_head_2d_and_3d_behavior_and_grad():
torch.manual_seed(3)
d_model = 40
proj_dim = 16
batch_size = 2
seq_len = 5
head = ProjectionHead(d_model=d_model, proj_dim=proj_dim, hidden_dim=64, dropout=0.0)
# 2D input
vec = torch.randn(batch_size, d_model, requires_grad=True)
out2 = head(vec)
assert out2.shape == (batch_size, proj_dim)
# 3D input
seq = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
out3 = head(seq)
assert out3.shape == (batch_size, seq_len, proj_dim)
# Grad flow
loss = out3.sum()
loss.backward()
grads = [p.grad for p in head.parameters() if p.requires_grad]
assert any(g is not None for g in grads)
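

# Sketch (not part of the original suite): a contrastive-style use of
# ProjectionHead, L2-normalising the 2D projections and forming a cosine
# similarity matrix. The normalisation happens here, outside the head; the
# head itself is not assumed to normalise its output.
def test_projection_head_cosine_similarity_sketch():
    torch.manual_seed(3)
    head = ProjectionHead(d_model=40, proj_dim=16, hidden_dim=64, dropout=0.0)
    vec = torch.randn(4, 40)
    z = nn.functional.normalize(head(vec), dim=-1)
    sim = z @ z.t()
    assert sim.shape == (4, 4)
    assert torch.allclose(sim.diagonal(), torch.ones(4), atol=1e-5)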