File size: 3,507 Bytes
590a604
 
 
 
 
 
 
 
 
ee1a8a3
1fbc47b
 
 
 
 
 
 
86b2059
1fbc47b
a18e93d
1fbc47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a18e93d
 
 
1fbc47b
 
 
 
 
 
 
a18e93d
d9dbe7c
 
86b2059
a18e93d
 
 
 
1fbc47b
 
 
 
 
 
86b2059
1fbc47b
a18e93d
86b2059
1fbc47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86b2059
1fbc47b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Inference pipeline factory for LexiMind.

Assembles a complete inference pipeline from saved checkpoints, tokenizer
artifacts, and label metadata. Handles model loading and configuration.

Author: Oliver Perrin
Date: December 2025
"""

from __future__ import annotations

from pathlib import Path
from typing import Tuple

import torch

from ..data.preprocessing import TextPreprocessor
from ..data.tokenization import Tokenizer, TokenizerConfig
from ..models.factory import build_multitask_model, load_model_config
from ..utils.io import load_state
from ..utils.labels import LabelMetadata, load_label_metadata
from .pipeline import InferenceConfig, InferencePipeline


def create_inference_pipeline(
    checkpoint_path: str | Path,
    labels_path: str | Path,
    *,
    tokenizer_config: TokenizerConfig | None = None,
    tokenizer_dir: str | Path | None = None,
    model_config_path: str | Path | None = None,
    device: str | torch.device = "cpu",
    summary_max_length: int | None = None,
) -> Tuple[InferencePipeline, LabelMetadata]:
    """Assemble an :class:`InferencePipeline` from saved artifacts.

    Loads label metadata, resolves a tokenizer (explicit config, explicit
    directory, or the repo-default ``artifacts/hf_tokenizer``), builds the
    multitask model from a YAML config, and restores checkpoint weights.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
        ValueError: if no tokenizer configuration can be resolved.
    """

    checkpoint = Path(checkpoint_path)
    if not checkpoint.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")

    labels = load_label_metadata(labels_path)

    # Repo root is three levels above this module (src/<pkg>/inference/...).
    repo_root = Path(__file__).resolve().parent.parent.parent

    resolved_tokenizer_config = tokenizer_config
    if resolved_tokenizer_config is None:
        # Fall back to an on-disk tokenizer: the caller-supplied directory,
        # or the packaged default under artifacts/.
        candidate = (
            Path(tokenizer_dir)
            if tokenizer_dir is not None
            else repo_root / "artifacts" / "hf_tokenizer"
        )
        if not candidate.exists():
            raise ValueError(
                "No tokenizer configuration provided and default tokenizer directory "
                f"'{candidate}' not found. Please provide tokenizer_config parameter or set tokenizer_dir."
            )
        resolved_tokenizer_config = TokenizerConfig(pretrained_model_name=str(candidate))

    tokenizer = Tokenizer(resolved_tokenizer_config)

    # Default to the base config because the published checkpoints were trained
    # with the 12-layer FLAN-T5-base alignment (vocab 32128, rel pos bias).
    if model_config_path is None:
        model_config_path = repo_root / "configs" / "model" / "base.yaml"

    model = build_multitask_model(
        tokenizer,
        num_emotions=labels.emotion_size,
        num_topics=labels.topic_size,
        config=load_model_config(model_config_path),
        load_pretrained=False,
    )

    # Load checkpoint - weights will load separately since factory doesn't tie them
    load_state(model, str(checkpoint))

    # InferenceConfig defaults govern summary_max_length unless overridden.
    device_str = str(device) if isinstance(device, torch.device) else device
    config_kwargs: dict[str, object] = {"device": device_str}
    if summary_max_length is not None:
        config_kwargs["summary_max_length"] = summary_max_length
    pipeline_config = InferenceConfig(**config_kwargs)

    pipeline = InferencePipeline(
        model=model,
        tokenizer=tokenizer,
        config=pipeline_config,
        emotion_labels=labels.emotion,
        topic_labels=labels.topic,
        device=device,
        preprocessor=TextPreprocessor(tokenizer=tokenizer, lowercase=tokenizer.config.lower),
    )
    return pipeline, labels