File size: 3,034 Bytes
1bdd1c1
 
a18e93d
1bdd1c1
 
 
 
a18e93d
1bdd1c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a18e93d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torch.nn as nn

from src.models.heads import (
    ClassificationHead,
    LMHead,
    ProjectionHead,
    TokenClassificationHead,
)


def test_classification_head_shapes_and_dropout():
    """ClassificationHead with mean pooling returns (batch, num_labels) logits.

    In train mode two forward passes over the same input are expected to
    differ because of dropout; in eval mode they must be identical.
    """
    torch.manual_seed(0)
    d_model, num_labels = 64, 5
    batch_size, seq_len = 3, 10
    expected = (batch_size, num_labels)

    head = ClassificationHead(d_model=d_model, num_labels=num_labels, pooler="mean", dropout=0.5)
    head.train()
    x = torch.randn(batch_size, seq_len, d_model)

    # Two stochastic passes (evaluated left to right): dropout masks should
    # make them differ under the fixed seed.
    train_a, train_b = head(x), head(x)
    for logits in (train_a, train_b):
        assert logits.shape == expected
    assert not torch.allclose(train_a, train_b)

    # Dropout is disabled in eval mode, so repeated passes must agree.
    head.eval()
    eval_a, eval_b = head(x), head(x)
    assert torch.allclose(eval_a, eval_b), "Eval mode should be deterministic"


def test_token_classification_head_shapes_and_grads():
    """TokenClassificationHead emits per-token logits and propagates gradients.

    Checks the output shape ``(batch, seq, num_labels)`` and that a backward
    pass reaches both the head's trainable parameters and the input tensor.
    """
    torch.manual_seed(1)
    d_model = 48
    num_labels = 7
    batch_size = 2
    seq_len = 6

    head = TokenClassificationHead(d_model=d_model, num_labels=num_labels, dropout=0.0)
    x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
    out = head(x)
    assert out.shape == (batch_size, seq_len, num_labels)

    loss = out.sum()
    loss.backward()

    # At least one trainable parameter must receive a gradient
    # (any() on an empty list is False, so a parameterless head also fails).
    grads = [p.grad for p in head.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)

    # x was created with requires_grad=True: verify the gradient actually
    # flows back to the input and matches its shape.
    assert x.grad is not None
    assert x.grad.shape == x.shape


def test_lm_head_tie_weights_and_shapes():
    """LMHead yields (batch, seq, vocab) logits whether or not it is tied.

    When tied to an embedding, the head's projection must reuse the very
    same ``embedding.weight`` Parameter, so gradients land on the embedding.
    """
    torch.manual_seed(2)
    vocab_size, d_model = 50, 32
    batch_size, seq_len = 2, 4
    expected = (batch_size, seq_len, vocab_size)

    # Construction order is kept fixed so the seeded RNG is consumed identically.
    embedding = nn.Embedding(vocab_size, d_model)
    lm_tied = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=embedding)
    lm_untied = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=None)

    hidden = torch.randn(batch_size, seq_len, d_model)

    logits_tied = lm_tied(hidden)
    logits_untied = lm_untied(hidden)
    for logits in (logits_tied, logits_untied):
        assert logits.shape == expected

    # Tying is by identity (shared Parameter object), not by copied values.
    assert lm_tied.proj.weight is embedding.weight

    # Backprop through the tied head must populate the embedding's gradient.
    logits_tied.sum().backward()
    assert embedding.weight.grad is not None


def test_projection_head_2d_and_3d_behavior_and_grad():
    """ProjectionHead maps the last dim d_model -> proj_dim for 2D and 3D inputs.

    Also checks that a backward pass through BOTH the 2D and the 3D outputs
    reaches the head's trainable parameters and each input tensor.
    """
    torch.manual_seed(3)
    d_model = 40
    proj_dim = 16
    batch_size = 2
    seq_len = 5

    head = ProjectionHead(d_model=d_model, proj_dim=proj_dim, hidden_dim=64, dropout=0.0)
    # 2D input: (batch, d_model) -> (batch, proj_dim)
    vec = torch.randn(batch_size, d_model, requires_grad=True)
    out2 = head(vec)
    assert out2.shape == (batch_size, proj_dim)

    # 3D input: (batch, seq, d_model) -> (batch, seq, proj_dim)
    seq = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
    out3 = head(seq)
    assert out3.shape == (batch_size, seq_len, proj_dim)

    # Grad flow: include out2 so the 2D path is exercised too (previously
    # vec.requires_grad was set but never backpropagated).
    loss = out2.sum() + out3.sum()
    loss.backward()
    grads = [p.grad for p in head.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)
    assert vec.grad is not None and vec.grad.shape == vec.shape
    assert seq.grad is not None and seq.grad.shape == seq.shape