| """Integration tests for the inference pipeline.""" | |
| from __future__ import annotations | |
| import sys | |
| import warnings | |
| from pathlib import Path | |
| from typing import cast | |
| import pytest | |
| import torch | |
| PROJECT_ROOT = Path(__file__).resolve().parents[2] | |
| if str(PROJECT_ROOT) not in sys.path: | |
| sys.path.insert(0, str(PROJECT_ROOT)) | |
from src.data.tokenization import Tokenizer, TokenizerConfig
from src.inference.pipeline import (
    EmotionPrediction,
    InferenceConfig,
    InferencePipeline,
    TopicPrediction,
)
from src.utils.labels import LabelMetadata

# Silence noisy DeprecationWarnings from underlying tokenizer bindings used in tests
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings(
    "ignore",
    message=r"builtin type SwigPy.*has no __module__ attribute",
    category=DeprecationWarning,
)
warnings.filterwarnings(
    "ignore",
    category=DeprecationWarning,
    module=r"importlib\._bootstrap",
)

pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning")


def _local_tokenizer_config() -> TokenizerConfig:
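    """Build a TokenizerConfig pointing at the repo-local artifacts/hf_tokenizer directory."""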
    root = Path(__file__).resolve().parents[2]
    hf_path = root / "artifacts" / "hf_tokenizer"
    return TokenizerConfig(pretrained_model_name=str(hf_path))


class DummyEncoder(torch.nn.Module):
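    """Stub encoder: returns all-zero hidden states of width 8 for any input batch."""
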
    def forward(
        self, input_ids: torch.Tensor, mask: torch.Tensor | None = None
    ) -> torch.Tensor:  # pragma: no cover - trivial
        batch, seq_len = input_ids.shape
        return torch.zeros(batch, seq_len, 8, device=input_ids.device)


class DummyDecoder(torch.nn.Module):
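    """Stub decoder that always greedy-decodes the fixed "dummy summary" token sequence."""
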
    def __init__(self, tokenizer: Tokenizer) -> None:
        super().__init__()
        tokens = tokenizer.tokenizer.encode("dummy summary", add_special_tokens=False)
        sequence = [tokenizer.bos_token_id, *tokens, tokenizer.eos_token_id]
        self.register_buffer("sequence", torch.tensor(sequence, dtype=torch.long))

    def greedy_decode(
        self,
        *,
        memory: torch.Tensor,
        max_len: int,
        start_token_id: int,
        end_token_id: int | None,
        device: torch.device,
        **kwargs: object,
    ) -> torch.Tensor:
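        """Return the stored sequence (moved to device, truncated to max_len) for every row in the batch."""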
        seq = cast(torch.Tensor, self.sequence).to(device)
        if seq.numel() > max_len:
            seq = seq[:max_len]
        batch = memory.size(0)
        return seq.unsqueeze(0).repeat(batch, 1)


class DummyModel(torch.nn.Module):
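    """Stub multi-task model exposing the dummy encoder/decoder and fixed classification logits."""
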
    def __init__(self, tokenizer: Tokenizer, metadata: LabelMetadata) -> None:
        super().__init__()
        self.encoder = DummyEncoder()
        self.decoder = DummyDecoder(tokenizer)
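        # Index 1 is the only positive emotion logit ("joy") and the largest topic logit
        # ("sports"), matching the labels asserted in the test below.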
        emotion_logits = torch.tensor([-2.0, 3.0, -1.0], dtype=torch.float32)
        topic_logits = torch.tensor([0.25, 2.5, 0.1], dtype=torch.float32)
        self.register_buffer("_emotion_logits", emotion_logits)
        self.register_buffer("_topic_logits", topic_logits)

    def forward(
        self, task: str, inputs: dict[str, torch.Tensor]
    ) -> torch.Tensor:  # pragma: no cover - simple dispatch
        batch = inputs["input_ids"].size(0)
        if task == "emotion":
            return cast(torch.Tensor, self._emotion_logits).unsqueeze(0).repeat(batch, 1)
        if task == "topic":
            return cast(torch.Tensor, self._topic_logits).unsqueeze(0).repeat(batch, 1)
        raise KeyError(task)


def _build_pipeline() -> InferencePipeline:
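    """Wire the dummy model and local tokenizer into an InferencePipeline with a short summary budget."""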
    tokenizer = Tokenizer(_local_tokenizer_config())
    metadata = LabelMetadata(emotion=["anger", "joy", "sadness"], topic=["news", "sports", "tech"])
    model = DummyModel(tokenizer, metadata)
    return InferencePipeline(
        model=model,
        tokenizer=tokenizer,
        emotion_labels=metadata.emotion,
        topic_labels=metadata.topic,
        config=InferenceConfig(summary_max_length=12, summary_formatting=False),
    )


def test_pipeline_predictions_across_tasks() -> None:
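    """Exercise summarize, emotion, topic, and batch_predict against the dummy pipeline."""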
    pipeline = _build_pipeline()
    text = "A quick unit test input."

    summaries = pipeline.summarize([text])
    assert summaries == ["dummy summary"], "Summaries should be decoded from dummy decoder sequence"

    emotions = pipeline.predict_emotions([text])
    assert len(emotions) == 1
    emotion = emotions[0]
    assert isinstance(emotion, EmotionPrediction)
    assert emotion.labels == ["joy"], "Only the positive logit should pass the threshold"

    topics = pipeline.predict_topics([text])
    assert len(topics) == 1
    topic = topics[0]
    assert isinstance(topic, TopicPrediction)
    assert topic.label == "sports"
    assert topic.confidence > 0.0

    combined = pipeline.batch_predict([text])
    assert combined["summaries"] == summaries
    combined_emotions = cast(list[EmotionPrediction], combined["emotion"])
    combined_topics = cast(list[TopicPrediction], combined["topic"])
    assert combined_emotions[0].labels == emotion.labels
    assert combined_topics[0].label == topic.label