"""
Inference pipeline factory for LexiMind.
Assembles a complete inference pipeline from saved checkpoints, tokenizer
artifacts, and label metadata. Handles model loading and configuration.
Author: Oliver Perrin
Date: December 2025
"""
from __future__ import annotations

from pathlib import Path
from typing import Tuple

import torch

from ..data.preprocessing import TextPreprocessor
from ..data.tokenization import Tokenizer, TokenizerConfig
from ..models.factory import build_multitask_model, load_model_config
from ..utils.io import load_state
from ..utils.labels import LabelMetadata, load_label_metadata
from .pipeline import InferenceConfig, InferencePipeline


def create_inference_pipeline(
    checkpoint_path: str | Path,
    labels_path: str | Path,
    *,
    tokenizer_config: TokenizerConfig | None = None,
    tokenizer_dir: str | Path | None = None,
    model_config_path: str | Path | None = None,
    device: str | torch.device = "cpu",
    summary_max_length: int | None = None,
) -> Tuple[InferencePipeline, LabelMetadata]:
    """Build an :class:`InferencePipeline` from a saved checkpoint and label metadata.

    The tokenizer, model config, and device fall back to the repository defaults
    (``artifacts/hf_tokenizer``, ``configs/model/base.yaml``, CPU) when not
    supplied explicitly.

    Returns:
        Tuple of the assembled pipeline and the loaded label metadata.
    """
    checkpoint = Path(checkpoint_path)
    if not checkpoint.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
    labels = load_label_metadata(labels_path)

    # Resolve the tokenizer: prefer an explicit config, otherwise fall back to
    # saved tokenizer artifacts (a user-supplied directory or the repo default).
    resolved_tokenizer_config = tokenizer_config
    if resolved_tokenizer_config is None:
        default_dir = Path(__file__).resolve().parent.parent.parent / "artifacts" / "hf_tokenizer"
        local_tokenizer_dir = Path(tokenizer_dir) if tokenizer_dir is not None else default_dir
        if local_tokenizer_dir.exists():
            resolved_tokenizer_config = TokenizerConfig(
                pretrained_model_name=str(local_tokenizer_dir)
            )
        else:
            raise ValueError(
                "No tokenizer configuration provided and default tokenizer directory "
                f"'{local_tokenizer_dir}' not found. Please provide tokenizer_config or set tokenizer_dir."
            )
    tokenizer = Tokenizer(resolved_tokenizer_config)
    # Default to the base config because the published checkpoints were trained
    # with the 12-layer FLAN-T5-base alignment (vocab 32128, rel pos bias).
    if model_config_path is None:
        model_config_path = (
            Path(__file__).resolve().parent.parent.parent / "configs" / "model" / "base.yaml"
        )
    model_config = load_model_config(model_config_path)

    model = build_multitask_model(
        tokenizer,
        num_emotions=labels.emotion_size,
        num_topics=labels.topic_size,
        config=model_config,
        load_pretrained=False,
    )
    # Load the checkpoint weights; the factory does not tie them, so they are
    # loaded separately from the saved state dict.
    load_state(model, str(checkpoint))
    device_str = str(device) if isinstance(device, torch.device) else device

    if summary_max_length is not None:
        pipeline_config = InferenceConfig(summary_max_length=summary_max_length, device=device_str)
    else:
        pipeline_config = InferenceConfig(device=device_str)

    pipeline = InferencePipeline(
        model=model,
        tokenizer=tokenizer,
        config=pipeline_config,
        emotion_labels=labels.emotion,
        topic_labels=labels.topic,
        device=device,
        preprocessor=TextPreprocessor(tokenizer=tokenizer, lowercase=tokenizer.config.lower),
    )
    return pipeline, labels
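

# Minimal usage sketch (not part of the factory itself): the checkpoint and
# label paths below are placeholders that assume the repository's default
# artifact layout; adjust them to wherever your trained artifacts live.
if __name__ == "__main__":
    demo_checkpoint = "artifacts/checkpoints/multitask_best.pt"  # assumed path
    demo_labels = "artifacts/labels.json"  # assumed path
    demo_pipeline, demo_label_meta = create_inference_pipeline(
        demo_checkpoint,
        demo_labels,
        device="cuda" if torch.cuda.is_available() else "cpu",
        summary_max_length=128,
    )
    print("Pipeline ready.")
    print("Emotion labels:", demo_label_meta.emotion)
    print("Topic labels:", demo_label_meta.topic)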