| """ | |
| Inference pipeline factory for LexiMind. | |
| Assembles a complete inference pipeline from saved checkpoints, tokenizer | |
| artifacts, and label metadata. Handles model loading and configuration. | |
| Author: Oliver Perrin | |
| Date: December 2025 | |
| """ | |

from __future__ import annotations

from pathlib import Path
from typing import Tuple

import torch

from ..data.preprocessing import TextPreprocessor
from ..data.tokenization import Tokenizer, TokenizerConfig
from ..models.factory import build_multitask_model, load_model_config
from ..utils.io import load_state
from ..utils.labels import LabelMetadata, load_label_metadata
from .pipeline import InferenceConfig, InferencePipeline


def create_inference_pipeline(
    checkpoint_path: str | Path,
    labels_path: str | Path,
    *,
    tokenizer_config: TokenizerConfig | None = None,
    tokenizer_dir: str | Path | None = None,
    model_config_path: str | Path | None = None,
    device: str | torch.device = "cpu",
    summary_max_length: int | None = None,
) -> Tuple[InferencePipeline, LabelMetadata]:
| """Build an :class:`InferencePipeline` from saved model and label metadata.""" | |
    checkpoint = Path(checkpoint_path)
    if not checkpoint.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")

    labels = load_label_metadata(labels_path)

    # Resolve the tokenizer: an explicit config wins, then an explicit
    # tokenizer_dir, then the repo-local artifacts/hf_tokenizer directory.
    resolved_tokenizer_config = tokenizer_config
    if resolved_tokenizer_config is None:
        default_dir = Path(__file__).resolve().parent.parent.parent / "artifacts" / "hf_tokenizer"
        local_tokenizer_dir = Path(tokenizer_dir) if tokenizer_dir is not None else default_dir
        if local_tokenizer_dir.exists():
            resolved_tokenizer_config = TokenizerConfig(
                pretrained_model_name=str(local_tokenizer_dir)
            )
        else:
            raise ValueError(
                "No tokenizer configuration provided and default tokenizer directory "
                f"'{local_tokenizer_dir}' not found. Provide the tokenizer_config "
                "parameter or set tokenizer_dir."
            )
    tokenizer = Tokenizer(resolved_tokenizer_config)

    # Default to the base config because the published checkpoints were trained
    # with the 12-layer FLAN-T5-base alignment (vocab 32128, rel pos bias).
    if model_config_path is None:
        model_config_path = (
            Path(__file__).resolve().parent.parent.parent / "configs" / "model" / "base.yaml"
        )
    model_config = load_model_config(model_config_path)

    model = build_multitask_model(
        tokenizer,
        num_emotions=labels.emotion_size,
        num_topics=labels.topic_size,
        config=model_config,
        load_pretrained=False,
    )
    # The factory returned a freshly initialized model (load_pretrained=False),
    # so the trained weights must be loaded from the checkpoint here.
    load_state(model, str(checkpoint))

    device_str = str(device) if isinstance(device, torch.device) else device
    if summary_max_length is not None:
        pipeline_config = InferenceConfig(summary_max_length=summary_max_length, device=device_str)
    else:
        pipeline_config = InferenceConfig(device=device_str)

    pipeline = InferencePipeline(
        model=model,
        tokenizer=tokenizer,
        config=pipeline_config,
        emotion_labels=labels.emotion,
        topic_labels=labels.topic,
        device=device,
        preprocessor=TextPreprocessor(tokenizer=tokenizer, lowercase=tokenizer.config.lower),
    )
    return pipeline, labels
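

# Example usage (a minimal sketch; the checkpoint and label paths below are
# illustrative assumptions, not fixed artifact locations in this repo):
#
#     pipeline, labels = create_inference_pipeline(
#         "artifacts/checkpoints/best.pt",
#         "artifacts/labels.json",
#         device="cuda" if torch.cuda.is_available() else "cpu",
#         summary_max_length=128,
#     )
#     print(f"Loaded {labels.emotion_size} emotions, {labels.topic_size} topics")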