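"""Unit tests for the task heads in src.models.heads.

Assumed interface (inferred from the imports and calls below): ClassificationHead,
TokenClassificationHead, LMHead, and ProjectionHead accept the keyword arguments
used in these tests, and LMHead exposes its output projection as `proj`.
Run with pytest.
"""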
import torch
import torch.nn as nn

from src.models.heads import (
    ClassificationHead,
    LMHead,
    ProjectionHead,
    TokenClassificationHead,
)
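
# ClassificationHead: pooled [batch, num_labels] output, stochastic under dropout
# in train mode, deterministic in eval mode.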
def test_classification_head_shapes_and_dropout():
    torch.manual_seed(0)
    d_model = 64
    num_labels = 5
    batch_size = 3
    seq_len = 10

    head = ClassificationHead(d_model=d_model, num_labels=num_labels, pooler="mean", dropout=0.5)
    head.train()

    x = torch.randn(batch_size, seq_len, d_model)
    out1 = head(x)
    out2 = head(x)

    # With dropout in train mode, outputs should usually differ
    assert out1.shape == (batch_size, num_labels)
    assert out2.shape == (batch_size, num_labels)
    assert not torch.allclose(out1, out2)

    head.eval()
    out3 = head(x)
    out4 = head(x)
    assert torch.allclose(out3, out4), "Eval mode should be deterministic"
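
# TokenClassificationHead: per-token logits of shape [batch, seq, num_labels];
# gradients should reach the head's parameters.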
def test_token_classification_head_shapes_and_grads():
    torch.manual_seed(1)
    d_model = 48
    num_labels = 7
    batch_size = 2
    seq_len = 6

    head = TokenClassificationHead(d_model=d_model, num_labels=num_labels, dropout=0.0)
    x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)

    out = head(x)
    assert out.shape == (batch_size, seq_len, num_labels)

    loss = out.sum()
    loss.backward()
    grads = [p.grad for p in head.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)
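
# LMHead: logits over the vocabulary; with tie_embedding the output projection
# reuses the embedding matrix, so gradients flow back into embedding.weight.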
def test_lm_head_tie_weights_and_shapes():
    torch.manual_seed(2)
    vocab_size = 50
    d_model = 32
    batch_size = 2
    seq_len = 4

    embedding = nn.Embedding(vocab_size, d_model)
    lm_tied = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=embedding)
    lm_untied = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=None)

    hidden = torch.randn(batch_size, seq_len, d_model)

    # Shapes
    logits_tied = lm_tied(hidden)
    logits_untied = lm_untied(hidden)
    assert logits_tied.shape == (batch_size, seq_len, vocab_size)
    assert logits_untied.shape == (batch_size, seq_len, vocab_size)

    # Weight tying: projection weight should be the same object as embedding.weight
    assert lm_tied.proj.weight is embedding.weight

    # Grad flows through tied weights
    loss = logits_tied.sum()
    loss.backward()
    assert embedding.weight.grad is not None
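
# ProjectionHead: accepts both pooled 2D [batch, d_model] and sequence 3D
# [batch, seq, d_model] inputs and projects the last dimension to proj_dim.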
def test_projection_head_2d_and_3d_behavior_and_grad():
    torch.manual_seed(3)
    d_model = 40
    proj_dim = 16
    batch_size = 2
    seq_len = 5

    head = ProjectionHead(d_model=d_model, proj_dim=proj_dim, hidden_dim=64, dropout=0.0)

    # 2D input
    vec = torch.randn(batch_size, d_model, requires_grad=True)
    out2 = head(vec)
    assert out2.shape == (batch_size, proj_dim)

    # 3D input
    seq = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
    out3 = head(seq)
    assert out3.shape == (batch_size, seq_len, proj_dim)

    # Grad flow
    loss = out3.sum()
    loss.backward()
    grads = [p.grad for p in head.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)