from typing import Any, Dict, cast

import torch

from src.models.decoder import TransformerDecoder


def test_step_equivalence_with_greedy_decode():
    """step() with an incremental KV cache must reproduce greedy_decode() exactly."""
    torch.manual_seed(7)
    vocab_size = 25
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 64
    batch_size = 2
    src_len = 6
    max_tgt = 6

    decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.0,
        max_len=max_tgt,
        pad_token_id=0,
    )
    memory = torch.randn(batch_size, src_len, d_model)

    # 1) Get greedy sequence from naive greedy_decode
    greedy = decoder.greedy_decode(memory, max_len=max_tgt, start_token_id=1, end_token_id=None)

    # 2) Reproduce the same sequence with step() using cache
    cache: Dict[str, Any] = {"past_length": 0}
    generated = torch.full((batch_size, 1), 1, dtype=torch.long)
    for _ in range(max_tgt - 1):
        last_token = generated[:, -1:].to(memory.device)
        logits, cache = decoder.step(cast(torch.LongTensor, last_token), memory, cache=cache)
        next_token = logits.argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)

    # Compare shapes & that sequences are identical
    assert generated.shape == greedy.shape
    assert torch.equal(generated, greedy)


def test_step_cache_growth_and_shapes():
    """The cache grows by one position per step and keeps per-layer self/memory projections."""
    torch.manual_seed(9)
    vocab_size = 20
    d_model = 24
    num_layers = 3
    num_heads = 4
    d_ff = 64
    batch_size = 1
    src_len = 5
    steps = 4
    max_tgt = 8

    decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.0,
        max_len=max_tgt,
        pad_token_id=0,
    )
    memory = torch.randn(batch_size, src_len, d_model)

    cache: Dict[str, Any] = {"past_length": 0}
    last = torch.full((batch_size, 1), 1, dtype=torch.long)

    for step_idx in range(steps):
        logits, cache = decoder.step(cast(torch.LongTensor, last), memory, cache=cache)

        # check updated past_length
        assert cache["past_length"] == step_idx + 1

        # check cached per-layer keys exist and have expected shape (B, H, seq_len, d_k)
        for i in range(num_layers):
            k = cache.get(f"self_k_{i}")
            v = cache.get(f"self_v_{i}")
            assert k is not None and v is not None
            # seq_len should equal past_length
            assert k.shape[2] == cache["past_length"]
            # shapes match
            assert k.shape[0] == batch_size
            assert v.shape[0] == batch_size

        # advance last token for next loop
        last = logits.argmax(dim=-1, keepdim=True)

    # Also ensure memory projections are cached
    for i in range(num_layers):
        assert f"mem_k_{i}" in cache and f"mem_v_{i}" in cache
        mem_k = cache[f"mem_k_{i}"]
        assert mem_k.shape[0] == batch_size
        assert mem_k.shape[2] == src_len  # seq length of memory