"""Multitask model composition utilities. This module provides infrastructure for multi-task learning: - MultiTaskModel: Compose encoder/decoder with multiple task heads - Routing: forward(task_name, ...) dispatches to correct components - Loss computation: Built-in cross-entropy with ignore_index support Author: Oliver Perrin Date: 2025-10-23 """ from typing import Any, Dict, Optional import torch import torch.nn as nn import torch.nn.functional as F from .decoder import TransformerDecoder # Import your components from .encoder import TransformerEncoder from .heads import ClassificationHead, LMHead, TokenClassificationHead class MultiTaskModel(nn.Module): """ Compose encoder/decoder and task heads. Usage patterns: - Encoder-only classification: mt = MultiTaskModel(encoder=enc) mt.add_head("sentiment", ClassificationHead(...)) logits = mt.forward("sentiment", {"input_ids": src_ids}) - Seq2seq LM: mt = MultiTaskModel(encoder=enc, decoder=dec) mt.add_head("summarize", LMHead(...)) logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids}) Args: encoder: optional encoder backbone. decoder: optional decoder backbone. decoder_outputs_logits: set True when ``decoder.forward`` already returns vocabulary logits; set False if the decoder produces hidden states that must be projected by the LM head. """ def __init__( self, encoder: Optional[TransformerEncoder] = None, decoder: Optional[TransformerDecoder] = None, *, decoder_outputs_logits: bool = True, ): super().__init__() self.encoder = encoder self.decoder = decoder self.heads: Dict[str, nn.Module] = {} # When True, decoder.forward(...) is expected to return logits already projected to the vocabulary space. # When False, decoder outputs hidden states that must be passed through the registered LM head. self.decoder_outputs_logits = decoder_outputs_logits def add_head(self, name: str, module: nn.Module) -> None: """Register a head under a task name.""" if name in self.heads: raise ValueError(f"Head '{name}' already exists") self.heads[name] = module self.add_module(f"head_{name}", module) def remove_head(self, name: str) -> None: """Remove a registered head.""" if name not in self.heads: raise KeyError(name) del self._modules[f"head_{name}"] del self.heads[name] def forward( self, task: str, inputs: Dict[str, torch.Tensor], return_loss: bool = False, loss_kwargs: Optional[Dict[str, Any]] = None, ) -> Any: """ Route inputs to appropriate model components and head. 
        Args:
            task: registered head name.
            inputs: dictionary; common keys:
                - For encoder tasks: "input_ids" (B, S) or "embeddings" (B, S, d).
                - For seq2seq: "src_ids" (B, S) or "src_embeddings", and
                  "tgt_ids" (B, T) or "tgt_embeddings". When computing a
                  training loss, also pass "labels" (B, T) for the LM.
            return_loss: if True and labels are provided, returns (loss, logits).
            loss_kwargs: forwarded to compute_loss_for_head (e.g., ignore_index).

        Returns:
            logits (or (loss, logits) if return_loss is True).
        """
        if task not in self.heads:
            raise KeyError(f"Unknown task/head '{task}'")
        head = self.heads[task]

        # Unwrap for type checking if compiled
        check_head = head
        if hasattr(head, "_orig_mod"):
            check_head = head._orig_mod

        loss_kwargs = loss_kwargs or {}

        # Encoder-only heads expect encoder outputs
        if isinstance(check_head, (ClassificationHead, TokenClassificationHead)):
            if self.encoder is None:
                raise RuntimeError("Encoder is required for encoder-side heads")
            # Accept either input_ids or embeddings
            if "input_ids" in inputs:
                encoder_mask = None
                if "attention_mask" in inputs:
                    encoder_mask = self._expand_attention_mask(
                        inputs["attention_mask"], inputs["input_ids"].device
                    )
                enc_out = self.encoder(inputs["input_ids"], mask=encoder_mask)
            elif "embeddings" in inputs:
                encoder_mask = inputs.get("attention_mask")
                if encoder_mask is not None:
                    encoder_mask = self._expand_attention_mask(
                        encoder_mask, inputs["embeddings"].device
                    )
                enc_out = self.encoder(inputs["embeddings"], mask=encoder_mask)
            else:
                raise ValueError(
                    "inputs must contain 'input_ids' or 'embeddings' for encoder tasks"
                )

            # Pass attention_mask to the head if available (needed for mean
            # pooling to ignore padding).
            if isinstance(check_head, ClassificationHead):
                logits = head(enc_out, mask=inputs.get("attention_mask"))
            else:
                logits = head(enc_out)

            if return_loss:
                labels = inputs.get("labels", None)
                if labels is None:
                    raise ValueError("return_loss=True requires 'labels' in inputs")
                loss = self.compute_loss_for_head(check_head, logits, labels, **loss_kwargs)
                return loss, logits
            return logits

        # LM/seq2seq head: run encoder -> decoder -> LM head
        if isinstance(check_head, LMHead):
            if self.encoder is None or self.decoder is None:
                raise RuntimeError("Both encoder and decoder are required for LM-style heads")

            # Build encoder memory
            src_mask = inputs.get("src_mask")
            if src_mask is None:
                src_mask = inputs.get("attention_mask")
            encoder_mask = None
            reference_tensor = inputs.get("src_ids")
            if reference_tensor is None:
                reference_tensor = inputs.get("src_embeddings")
            if src_mask is not None and reference_tensor is not None:
                encoder_mask = self._expand_attention_mask(src_mask, reference_tensor.device)

            if "src_ids" in inputs:
                memory = self.encoder(inputs["src_ids"], mask=encoder_mask)
            elif "src_embeddings" in inputs:
                memory = self.encoder(inputs["src_embeddings"], mask=encoder_mask)
            else:
                raise ValueError(
                    "inputs must contain 'src_ids' or 'src_embeddings' for seq2seq tasks"
                )

            # Clone memory to prevent CUDA Graph buffer overwrites when passing it
            # between compiled graphs. This avoids the "accessing tensor output of
            # CUDAGraphs that has been overwritten" error.
            if isinstance(memory, torch.Tensor):
                memory = memory.clone()

            # If training / teacher forcing: expect tgt_ids (shifted by the caller)
            # or tgt_embeddings.
            if "tgt_ids" in inputs:
                decoder_inputs = inputs["tgt_ids"]
            elif "tgt_embeddings" in inputs:
                decoder_inputs = inputs["tgt_embeddings"]
            else:
                # At generation time you may call decoder.greedy_decode separately.
                # Here we don't attempt to generate when targets are not provided.
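                # If inference is needed here instead of an error, the encoder
                # memory built above can be reused by the decoder's generation
                # routine (e.g. the greedy_decode mentioned above); the exact
                # call depends on the decoder implementation in this package.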
                raise ValueError(
                    "Seq2seq tasks require 'tgt_ids' or 'tgt_embeddings' for training forward"
                )

            decoder_out = self.decoder(decoder_inputs, memory, memory_mask=src_mask)
            if self.decoder_outputs_logits:
                if not isinstance(decoder_out, torch.Tensor):
                    raise TypeError(
                        "Decoder is configured to return logits, but forward returned a non-tensor value."
                    )
                logits = decoder_out
            else:
                logits = head(decoder_out)

            if return_loss:
                labels = inputs.get("labels", None)
                if labels is None:
                    raise ValueError("return_loss=True requires 'labels' in inputs for seq2seq")
                loss = self.compute_loss_for_head(check_head, logits, labels, **loss_kwargs)
                return loss, logits
            return logits

        # Otherwise: unsupported head type
        raise RuntimeError(f"Unsupported head type: {type(check_head)}")

    def compute_loss_for_head(
        self,
        head: nn.Module,
        logits: torch.Tensor,
        labels: torch.Tensor,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """
        Default loss dispatch:

        - ClassificationHead: cross-entropy on (B, num_labels)
        - TokenClassificationHead: cross-entropy per token (flattened)
        - LMHead: cross-entropy per token (flattened), ignore_index supported

        Returns a scalar loss.
        """
        if isinstance(head, ClassificationHead):
            # logits: (B, num_labels), labels: (B,)
            return F.cross_entropy(logits, labels.long())

        if isinstance(head, TokenClassificationHead):
            # logits: (B, T, C), labels: (B, T)
            B, T, C = logits.shape
            return F.cross_entropy(
                logits.view(B * T, C), labels.view(B * T).long(), ignore_index=ignore_index
            )

        if isinstance(head, LMHead):
            # logits: (B, T, V), labels: (B, T)
            B, T, V = logits.shape
            return F.cross_entropy(
                logits.view(B * T, V), labels.view(B * T).long(), ignore_index=ignore_index
            )

        # Generic fallback: try cross-entropy on the final dim
        if logits.dim() == 2:
            return F.cross_entropy(logits, labels.long())

        # If we can't determine the loss, raise
        raise RuntimeError("Cannot compute loss for unknown head type")

    @staticmethod
    def _expand_attention_mask(
        mask: Optional[torch.Tensor], device: torch.device
    ) -> Optional[torch.Tensor]:
        if mask is None:
            return None
        bool_mask = mask.to(device=device, dtype=torch.bool)
        if bool_mask.dim() == 2:
            # (B, S) padding mask -> (B, S, S) pairwise mask
            return bool_mask.unsqueeze(1) & bool_mask.unsqueeze(2)
        if bool_mask.dim() in (3, 4):
            return bool_mask
        raise ValueError("Attention mask must be 2D, 3D, or 4D tensor")
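

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the public API). Running this module
# directly exercises two conventions used above, with random tensors and
# hand-picked shapes that are assumptions for demonstration only:
#   1) the LM loss convention in ``compute_loss_for_head``: (B, T, V) logits
#      are flattened and compared to (B, T) labels with ignore_index=-100;
#   2) ``_expand_attention_mask`` turning a (B, S) padding mask into a
#      (B, S, S) pairwise boolean mask.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # 1) LM-style loss: flattened cross-entropy with ignore_index.
    B, T, V = 2, 5, 11
    logits = torch.randn(B, T, V)
    labels = torch.randint(0, V, (B, T))
    labels[:, -2:] = -100  # pretend the last two positions are padding

    flat_loss = F.cross_entropy(
        logits.view(B * T, V), labels.view(B * T), ignore_index=-100
    )
    # Equivalent manual masking: mean cross-entropy over non-ignored positions.
    keep = labels.view(-1) != -100
    manual_loss = F.cross_entropy(logits.view(B * T, V)[keep], labels.view(-1)[keep])
    assert torch.allclose(flat_loss, manual_loss, atol=1e-6)
    print(f"LM loss with ignore_index: {flat_loss.item():.4f}")

    # 2) Padding-mask expansion: (B, S) -> (B, S, S), True only where both the
    #    query and key positions are real (non-padding) tokens.
    pad_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])
    expanded = MultiTaskModel._expand_attention_mask(pad_mask, torch.device("cpu"))
    assert expanded is not None and expanded.shape == (2, 5, 5)
    print(f"Expanded mask shape: {tuple(expanded.shape)}")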