HuaminChen committed on
Commit e21cde3 · verified · 1 Parent(s): 397f86b

Upload multi-modal-embed-large final model

README.md ADDED
@@ -0,0 +1,176 @@
1
+ ---
2
+ license: apache-2.0
3
+ library_name: pytorch
4
+ pipeline_tag: sentence-similarity
5
+ tags:
6
+ - sentence-transformers
7
+ - multimodal
8
+ - embeddings
9
+ - retrieval
10
+ - image-text
11
+ - audio-text
12
+ - text-image-audio
13
+ - tri-encoder
14
+ - semantic-router
15
+ - pytorch
16
+ model-index:
17
+ - name: multi-modal-embed-large
18
+ results:
19
+ - task:
20
+ type: sentence-similarity
21
+ dataset:
22
+ name: Internal cached validation set
23
+ type: cached_retrieval_validation
24
+ metrics:
25
+ - name: Eval loss
26
+ type: eval_loss
27
+ value: 0.389702
28
+ - name: Eval top1
29
+ type: eval_top1
30
+ value: 0.861707
31
+ ---
32
+
33
+ # multi-modal-embed-large
34
+
35
+ `multi-modal-embed-large` is the large production multimodal embedding model from the [llm-semantic-router](https://huggingface.co/llm-semantic-router) project.
36
+
37
+ It is designed for routing, retrieval, and cross-modal matching across text, image, and audio rather than for generative chat. The model uses a tri-encoder architecture with separate text, image, and audio towers projected into one shared embedding space.
38
+
39
+ ## Purpose
40
+
41
+ This release provides a large multimodal embedding model for production systems where inputs may arrive as text, images (including screenshots), or audio. It is built for semantic routing, multimodal retrieval, and cross-modal similarity.
42
+
43
+ ## What Is In This Repository
44
+
45
+ This repository contains the minimum artifacts needed to load and run the exported model:
46
+
47
+ - `model.pt`: trained weights for the final exported model
48
+ - `config.json`: model configuration and encoder names
49
+ - `src/hf_st_mm/...`: the Python source package used to construct and run the tri-encoder
50
+ - `README.md`: this model card, including usage examples and validation summary
51
+
52
+ This is not a generic Hugging Face Transformers checkpoint with a built-in auto-class loader. It is a packaged custom PyTorch model export.
53
+
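+ If you only need these inference artifacts, the snapshot can be filtered at download time. A minimal sketch, using the standard `allow_patterns` argument of `huggingface_hub.snapshot_download`; the patterns simply mirror the file list above:
+
+ ```python
+ from huggingface_hub import snapshot_download
+
+ # Fetch only the weights, config, packaged source, and model card.
+ local_dir = snapshot_download(
+     repo_id="llm-semantic-router/multi-modal-embed-large",
+     allow_patterns=["model.pt", "config.json", "src/**", "README.md"],
+ )
+ print(local_dir)  # local path of the filtered snapshot
+ ```
+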
54
+ ## Advantages And Innovation
55
+
56
+ Most multimodal models are optimized for generation, captioning, or chat. This model is optimized for embeddings and operational use.
57
+
58
+ What is different here:
59
+
60
+ - map text, image, and audio into one shared semantic space
61
+ - support routing and retrieval instead of text generation
62
+ - preserve a strong multilingual text backbone
63
+ - use stronger modality-specific encoders instead of forcing every modality into one monolithic checkpoint
64
+ - support production training and evaluation on cached shard datasets
65
+
66
+ ## Model Overview
67
+
68
+ This release packages the large, routing-grade tri-encoder trained with this project's PyTorch server training stack.
69
+
70
+ Architecture:
71
+
72
+ - text encoder: `llm-semantic-router/mmbert-embed-32k-2d-matryoshka`
73
+ - image encoder: `google/siglip2-so400m-patch14-384`
74
+ - audio encoder: `openai/whisper-medium`
75
+ - shared embedding dimension: `768`
76
+ - max text length: `32768`
77
+
78
+ Training characteristics:
79
+
80
+ - objective: cached multiple negatives ranking loss (restated in the sketch after this list)
81
+ - training stack: PyTorch + Accelerate
82
+ - target hardware: AMD MI300X
83
+ - data pipeline: cached tensor shards with sequential shard loading and worker-local prefetch
84
+
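+ For reference, the objective above is the symmetric in-batch multiple negatives ranking loss, shipped in the packaged `src/hf_st_mm/model.py` as `multiple_negatives_ranking_loss`. The sketch below restates that formulation on a toy batch of unit-normalized embeddings:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def multiple_negatives_ranking_loss(anchor: torch.Tensor, positive: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
+     # Every other positive in the batch serves as an in-batch negative for a given anchor.
+     scores = anchor @ positive.T * scale                           # [batch, batch] similarity matrix
+     labels = torch.arange(scores.shape[0], device=scores.device)   # matching pairs lie on the diagonal
+     return 0.5 * (F.cross_entropy(scores, labels) + F.cross_entropy(scores.T, labels))
+
+ # Toy usage on random, already L2-normalized embeddings.
+ anchor = F.normalize(torch.randn(4, 768), dim=-1)
+ positive = F.normalize(torch.randn(4, 768), dim=-1)
+ print(multiple_negatives_ranking_loss(anchor, positive).item())
+ ```
+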
85
+ ## How To Use It
86
+
87
+ ### Installation
88
+
89
+ ```bash
90
+ pip install torch sentence-transformers transformers accelerate safetensors pillow librosa soundfile huggingface_hub
91
+ ```
92
+
93
+ ### Python Usage
94
+
95
+ The simplest way to use the model is to download the repository snapshot, load the packaged source code, and then encode one or more modality-tagged items.
96
+
97
+ ```python
98
+ import json
99
+ import os
100
+ import sys
101
+
102
+ import torch
103
+ from huggingface_hub import snapshot_download
104
+
105
+ repo_id = "llm-semantic-router/multi-modal-embed-large"
106
+ local_dir = snapshot_download(repo_id=repo_id)
107
+
108
+ sys.path.insert(0, os.path.join(local_dir, "src"))
109
+
110
+ from hf_st_mm.data import PairItem
111
+ from hf_st_mm.model import MultiModalSentenceEmbedder
112
+
113
+ with open(os.path.join(local_dir, "config.json"), "r", encoding="utf-8") as handle:
114
+ cfg = json.load(handle)
115
+
116
+ model = MultiModalSentenceEmbedder(
117
+ text_encoder_name=cfg["model"]["text_encoder_name"],
118
+ image_encoder_name=cfg["model"]["image_encoder_name"],
119
+ audio_encoder_name=cfg["model"]["audio_encoder_name"],
120
+ embedding_dim=int(cfg["model"]["embedding_dim"]),
121
+ max_text_length=int(cfg["model"]["max_text_length"]),
122
+ )
123
+ state_dict = torch.load(os.path.join(local_dir, "model.pt"), map_location="cpu")
124
+ model.load_state_dict(state_dict)
125
+ model.eval()
126
+
127
+ items = [
128
+ PairItem(modality="text", value="route this request to the billing team"),
129
+ PairItem(modality="image", value="/path/to/screenshot.png"),
130
+ PairItem(modality="audio", value="/path/to/call.wav"),
131
+ ]
132
+
133
+ with torch.no_grad():
134
+ embeddings = model.encode_items(items)
135
+
136
+ print(embeddings.shape)  # torch.Size([3, 768])
137
+
138
+ import torch.nn.functional as F
139
+
140
+ query = PairItem(modality="text", value="refund request for wrong charge")
141
+ candidate = PairItem(modality="audio", value="/path/to/refund_call.wav")
142
+
143
+ with torch.no_grad():
144
+ embs = model.encode_items([query, candidate])
145
+
146
+ similarity = F.cosine_similarity(embs[0:1], embs[1:2]).item()
147
+ print(f"similarity={similarity:.4f}")
148
+ ```
149
+
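+ Because every modality lands in the same L2-normalized space, ranking a query against a pool of mixed-modality candidates reduces to a matrix product. A minimal sketch, reusing the `model` and `PairItem` objects from the snippet above (the file paths are placeholders):
+
+ ```python
+ candidates = [
+     PairItem(modality="text", value="invoice shows a duplicate charge"),
+     PairItem(modality="image", value="/path/to/billing_screenshot.png"),
+     PairItem(modality="audio", value="/path/to/support_call.wav"),
+ ]
+ query = PairItem(modality="text", value="customer was billed twice")
+
+ with torch.no_grad():
+     query_emb = model.encode_items([query])      # [1, 768]
+     cand_embs = model.encode_items(candidates)   # [num_candidates, 768]
+
+ # encode_items returns unit-normalized vectors, so dot product equals cosine similarity.
+ scores = (query_emb @ cand_embs.T).squeeze(0)
+ best = int(scores.argmax())
+ print(f"best candidate #{best} score={scores[best]:.4f}")
+ ```
+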
150
+ ## Validation Snapshot
151
+
152
+ At upload time, the final export was evaluated with the repository's tri-encoder evaluator.
153
+
154
+ - `eval_loss`: `0.389702`
155
+ - `eval_top1`: `0.861707`
156
+
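+ `eval_top1` is a retrieval-style metric: for each query embedding, the matching positive must be the nearest candidate. The snippet below is an illustrative restatement of that idea, not the repository's `scripts/evaluate_tri_encoder.py` evaluator itself:
+
+ ```python
+ import torch
+
+ def top1_accuracy(query_embs: torch.Tensor, positive_embs: torch.Tensor) -> float:
+     # Aligned [N, dim] batches of unit-normalized query and positive embeddings.
+     scores = query_embs @ positive_embs.T   # [N, N] cosine similarities
+     predicted = scores.argmax(dim=-1)       # nearest positive per query
+     target = torch.arange(scores.shape[0])  # the true positive shares the row index
+     return (predicted == target).float().mean().item()
+ ```
+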
157
+ ## Practical Notes
158
+
159
+ - Text inputs can be provided as raw strings or tokenized features (see the sketch after this list).
160
+ - Image and audio inputs can be provided as file paths.
161
+ - Cached tensor payloads are supported by the training stack, but the simplest inference path is to use file paths or raw text.
162
+ - This release is intended for production retrieval and routing use cases rather than for instruction-following or caption generation.
163
+
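+ The first note above means text items may carry a dict of token tensors instead of a raw string. A minimal sketch, reusing `model` and `PairItem` from the usage example and assuming each tensor is a 1-D sequence (the layout the packaged `_encode_text` pads with `pad_sequence`):
+
+ ```python
+ # Tokenize once with the model's own text tokenizer, then pass per-item 1-D tensors.
+ features = model.text_model.tokenize(["route this request to the billing team"])
+ tokenized = {key: value.squeeze(0) for key, value in features.items() if torch.is_tensor(value)}
+
+ with torch.no_grad():
+     emb = model.encode_items([PairItem(modality="text", value=tokenized)])
+ print(emb.shape)  # torch.Size([1, 768])
+ ```
+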
164
+ ## Limitations
165
+
166
+ - This is a custom tri-encoder export, not a standard Transformers auto-class package.
167
+ - Inference currently relies on the packaged `hf_st_mm` source code.
168
+ - The validation metrics reported here come from the repository's cached retrieval validation path, not from a public benchmark leaderboard.
169
+
170
+ ## Training Code
171
+
172
+ Training and evaluation code live in the server training project that produced this checkpoint.
173
+
174
+ - trainer: `scripts/train_st_multimodal.py`
175
+ - evaluator: `scripts/evaluate_tri_encoder.py`
176
+ - model: `src/hf_st_mm/model.py`
config.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "seed": 42,
3
+ "output_dir": "/scratch/hf_st_mm_outputs/server_datacenter_8gpu_tri_encoder",
4
+ "model": {
5
+ "text_encoder_name": "llm-semantic-router/mmbert-embed-32k-2d-matryoshka",
6
+ "image_encoder_name": "google/siglip2-so400m-patch14-384",
7
+ "audio_encoder_name": "openai/whisper-medium",
8
+ "embedding_dim": 768,
9
+ "max_text_length": 32768
10
+ },
11
+ "training": {
12
+ "epochs": 10,
13
+ "batch_size": 12,
14
+ "grad_accum_steps": 8,
15
+ "num_workers": 4,
16
+ "prefetch_factor": 4,
17
+ "shard_prefetch": 2,
18
+ "shard_cache_limit": 4,
19
+ "sequential_shard_loading": true,
20
+ "shuffle": false,
21
+ "modality_homogeneous_batches": false,
22
+ "learning_rate": 1e-05,
23
+ "weight_decay": 0.01,
24
+ "warmup_ratio": 0.1,
25
+ "max_grad_norm": 1.0,
26
+ "mixed_precision": "bf16",
27
+ "log_every": 10,
28
+ "save_every": 2000,
29
+ "hard_negative_ratio": 0.5
30
+ },
31
+ "loss": {
32
+ "type": "cached_mnrl",
33
+ "scale": 20.0
34
+ },
35
+ "data": {
36
+ "cache_dir": "/scratch/2dmse-data/server_full_datacenter_cache/train"
37
+ },
38
+ "validation": {
39
+ "cache_dir": "/scratch/2dmse-data/server_full_datacenter_cache/val",
40
+ "num_workers": 2,
41
+ "shard_prefetch": 1,
42
+ "shard_cache_limit": 2
43
+ }
44
+ }
model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5fe61d4864fffb703f53860234a657a2f51f71e393e2dc1b7f635b284cb48c4
3
+ size 6393990436
src/hf_st_mm/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Standalone HF Sentence-Transformers multimodal training package."""
src/hf_st_mm/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (200 Bytes)
src/hf_st_mm/__pycache__/data.cpython-312.pyc ADDED
Binary file (45.7 kB)
src/hf_st_mm/__pycache__/model.cpython-312.pyc ADDED
Binary file (14.7 kB)
src/hf_st_mm/data.py ADDED
@@ -0,0 +1,863 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import queue
5
+ import random
6
+ import threading
7
+ from bisect import bisect_right
8
+ from collections import OrderedDict
9
+ from dataclasses import dataclass
10
+ from typing import Any, Dict, Iterable, List, Optional
11
+
12
+ import torch
13
+ from datasets import Dataset, Features, IterableDataset, Value
14
+
15
+
16
+ SUPPORTED_MODALITIES = {"text", "image", "audio"}
17
+
18
+
19
+ @dataclass
20
+ class PairItem:
21
+ modality: str
22
+ value: Any
23
+
24
+
25
+ @dataclass
26
+ class TrainRecord:
27
+ query: PairItem
28
+ positive: PairItem
29
+ negative: Optional[PairItem] = None
30
+
31
+
32
+ def _parse_item(obj: Any, prefix: str) -> PairItem:
33
+ if isinstance(obj, dict):
34
+ modality = obj.get("type")
35
+ value = obj.get("value")
36
+ else:
37
+ modality = None
38
+ value = None
39
+
40
+ if not modality or not value:
41
+ raise ValueError(f"{prefix} must include type/value")
42
+ if modality not in SUPPORTED_MODALITIES:
43
+ raise ValueError(f"Unsupported modality '{modality}' in {prefix}")
44
+ return PairItem(modality=modality, value=value)
45
+
46
+
47
+ def parse_record(raw: Dict[str, Any]) -> TrainRecord:
48
+ if "query" in raw and "positive" in raw:
49
+ query = _parse_item(raw["query"], "query")
50
+ positive = _parse_item(raw["positive"], "positive")
51
+ negative = _parse_item(raw["negative"], "negative") if raw.get("negative") else None
52
+ return TrainRecord(query=query, positive=positive, negative=negative)
53
+
54
+ # Compatibility with common pair formats in existing repos
55
+ if "texts_a" in raw and "texts_b" in raw:
56
+ query = PairItem("text", raw["texts_a"])
57
+ positive = PairItem("text", raw["texts_b"])
58
+ return TrainRecord(query=query, positive=positive)
59
+
60
+ if "image_path" in raw and "caption" in raw:
61
+ query = PairItem("image", raw["image_path"])
62
+ positive = PairItem("text", raw["caption"])
63
+ return TrainRecord(query=query, positive=positive)
64
+
65
+ if "audio_path" in raw and "caption" in raw:
66
+ query = PairItem("audio", raw["audio_path"])
67
+ positive = PairItem("text", raw["caption"])
68
+ return TrainRecord(query=query, positive=positive)
69
+
70
+ raise ValueError("Record does not match supported schemas")
71
+
72
+
73
+ class JsonlManifestDataset:
74
+ def __init__(
75
+ self,
76
+ manifest_path: str,
77
+ image_root: Optional[str] = None,
78
+ audio_root: Optional[str] = None,
79
+ allow_missing_negative: bool = True,
80
+ ) -> None:
81
+ self.manifest_path = manifest_path
82
+ self.image_root = image_root
83
+ self.audio_root = audio_root
84
+ self.allow_missing_negative = allow_missing_negative
85
+ self.records = list(
86
+ iter_manifest_records(
87
+ manifest_path=self.manifest_path,
88
+ image_root=self.image_root,
89
+ audio_root=self.audio_root,
90
+ allow_missing_negative=self.allow_missing_negative,
91
+ )
92
+ )
93
+ if not self.records:
94
+ raise ValueError(f"No records loaded from {self.manifest_path}")
95
+
96
+ def __len__(self) -> int:
97
+ return len(self.records)
98
+
99
+ def __getitem__(self, idx: int) -> TrainRecord:
100
+ return self.records[idx]
101
+
102
+
103
+ class CachedShardDataset:
104
+ def __init__(self, cache_dir: str, shard_cache_limit: int = 2, prefetch_shards: int = 0) -> None:
105
+ self.cache_dir = cache_dir
106
+ self.shard_cache_limit = max(int(shard_cache_limit), 1)
107
+ self.prefetch_shards = max(int(prefetch_shards), 0)
108
+ self.metadata = self._load_metadata()
109
+ self.shard_files = self._discover_shards()
110
+ self.shard_sizes = self._resolve_shard_sizes()
111
+ self.shard_offsets = self._build_offsets(self.shard_sizes)
112
+ self.total_rows = sum(self.shard_sizes)
113
+ self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()
114
+ self._init_runtime_state()
115
+
116
+ def _init_runtime_state(self) -> None:
117
+ self._cache_lock = threading.Lock()
118
+ self._prefetch_queue = None
119
+ self._prefetch_thread = None
120
+ self._prefetch_stop = threading.Event()
121
+ self._prefetch_requested: set[int] = set()
122
+ self._prefetch_hits = 0
123
+ self._prefetch_misses = 0
124
+
125
+ def __getstate__(self):
126
+ state = self.__dict__.copy()
127
+ state["_shard_cache"] = OrderedDict(state.get("_shard_cache", OrderedDict()))
128
+ state["_cache_lock"] = None
129
+ state["_prefetch_queue"] = None
130
+ state["_prefetch_thread"] = None
131
+ state["_prefetch_stop"] = None
132
+ state["_prefetch_requested"] = set()
133
+ return state
134
+
135
+ def __setstate__(self, state):
136
+ self.__dict__.update(state)
137
+ self._shard_cache = OrderedDict(self._shard_cache)
138
+ self._init_runtime_state()
139
+
140
+ def _load_metadata(self) -> Dict[str, Any]:
141
+ metadata_path = os.path.join(self.cache_dir, "metadata.json")
142
+ if not os.path.exists(metadata_path):
143
+ return {}
144
+ with open(metadata_path, "r", encoding="utf-8") as handle:
145
+ return json.load(handle)
146
+
147
+ def _discover_shards(self) -> List[str]:
148
+ if not os.path.isdir(self.cache_dir):
149
+ raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
150
+ shards: List[str] = []
151
+ for name in sorted(os.listdir(self.cache_dir)):
152
+ if not (name.startswith("shard_") and name.endswith(".pt")):
153
+ continue
154
+ shard_path = os.path.join(self.cache_dir, name)
155
+ shards.append(shard_path)
156
+ if not shards:
157
+ raise ValueError(f"No cache shards found under {self.cache_dir}")
158
+ return shards
159
+
160
+ @staticmethod
161
+ def _build_offsets(shard_sizes: List[int]) -> List[int]:
162
+ offsets: List[int] = []
163
+ running_total = 0
164
+ for shard_size in shard_sizes:
165
+ running_total += shard_size
166
+ offsets.append(running_total)
167
+ return offsets
168
+
169
+ def _resolve_shard_sizes(self) -> List[int]:
170
+ num_shards = len(self.shard_files)
171
+ metadata_num_shards = self.metadata.get("num_shards")
172
+ metadata_num_records = self.metadata.get("num_records")
173
+ shard_size = self.metadata.get("shard_size")
174
+
175
+ if (
176
+ isinstance(metadata_num_shards, int)
177
+ and isinstance(metadata_num_records, int)
178
+ and isinstance(shard_size, int)
179
+ and metadata_num_shards == num_shards
180
+ and metadata_num_records > 0
181
+ and shard_size > 0
182
+ ):
183
+ shard_sizes = [shard_size] * num_shards
184
+ full_rows_before_last = shard_size * max(num_shards - 1, 0)
185
+ shard_sizes[-1] = metadata_num_records - full_rows_before_last
186
+ if shard_sizes[-1] <= 0:
187
+ raise ValueError(f"Invalid metadata in {self.cache_dir}: last shard size computed as {shard_sizes[-1]}")
188
+ return shard_sizes
189
+
190
+ shard_sizes: List[int] = []
191
+ for shard_path in self.shard_files:
192
+ payload = torch.load(shard_path, map_location="cpu", weights_only=False)
193
+ records = payload.get("records")
194
+ if not isinstance(records, list):
195
+ raise ValueError(f"Invalid shard format in {shard_path}")
196
+ shard_sizes.append(len(records))
197
+ return shard_sizes
198
+
199
+ def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
200
+ with self._cache_lock:
201
+ self._shard_cache[shard_idx] = records
202
+ self._shard_cache.move_to_end(shard_idx)
203
+ while len(self._shard_cache) > self.shard_cache_limit:
204
+ self._shard_cache.popitem(last=False)
205
+
206
+ def _ensure_prefetch_thread(self) -> None:
207
+ if self.prefetch_shards <= 0:
208
+ return
209
+ if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
210
+ return
211
+
212
+ self._prefetch_stop.clear()
213
+ self._prefetch_queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
214
+ self._prefetch_thread = threading.Thread(
215
+ target=self._prefetch_worker,
216
+ daemon=True,
217
+ name=f"cached-shard-prefetch-{os.getpid()}",
218
+ )
219
+ self._prefetch_thread.start()
220
+
221
+ def _prefetch_worker(self) -> None:
222
+ while not self._prefetch_stop.is_set():
223
+ try:
224
+ shard_idx = self._prefetch_queue.get(timeout=0.1)
225
+ except queue.Empty:
226
+ continue
227
+
228
+ if shard_idx is None:
229
+ continue
230
+
231
+ try:
232
+ with self._cache_lock:
233
+ if shard_idx in self._shard_cache:
234
+ self._prefetch_hits += 1
235
+ continue
236
+ payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
237
+ records = payload["records"]
238
+ self._store_shard(shard_idx, records)
239
+ self._prefetch_hits += 1
240
+ finally:
241
+ with self._cache_lock:
242
+ self._prefetch_requested.discard(shard_idx)
243
+
244
+ def _schedule_prefetch(self, shard_idx: int) -> None:
245
+ if self.prefetch_shards <= 0:
246
+ return
247
+
248
+ self._ensure_prefetch_thread()
249
+ if self._prefetch_queue is None:
250
+ return
251
+
252
+ for next_idx in range(shard_idx + 1, min(len(self.shard_files), shard_idx + 1 + self.prefetch_shards)):
253
+ with self._cache_lock:
254
+ if next_idx in self._shard_cache or next_idx in self._prefetch_requested:
255
+ continue
256
+ self._prefetch_requested.add(next_idx)
257
+ try:
258
+ self._prefetch_queue.put_nowait(next_idx)
259
+ except queue.Full:
260
+ with self._cache_lock:
261
+ self._prefetch_requested.discard(next_idx)
262
+ break
263
+
264
+ def _load_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
265
+ cached = None
266
+ with self._cache_lock:
267
+ cached = self._shard_cache.get(shard_idx)
268
+ if cached is not None:
269
+ self._shard_cache.move_to_end(shard_idx)
270
+ if cached is not None:
271
+ self._schedule_prefetch(shard_idx)
272
+ return cached
273
+
274
+ self._prefetch_misses += 1
275
+ payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
276
+ records = payload["records"]
277
+ self._store_shard(shard_idx, records)
278
+ with self._cache_lock:
279
+ self._prefetch_requested.discard(shard_idx)
280
+ self._schedule_prefetch(shard_idx)
281
+ return records
282
+
283
+ @staticmethod
284
+ def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional[PairItem]:
285
+ if raw is None:
286
+ return None
287
+ modality = raw["type"]
288
+ if modality == "text" and "tokens" in raw:
289
+ value = raw["tokens"]
290
+ elif modality == "text":
291
+ value = raw["value"]
292
+ elif "tensor" in raw:
293
+ value = raw["tensor"]
294
+ else:
295
+ value = raw.get("value")
296
+ return PairItem(modality=modality, value=value)
297
+
298
+ def __len__(self) -> int:
299
+ return self.total_rows
300
+
301
+ def __getitem__(self, idx: int) -> TrainRecord:
302
+ if idx < 0 or idx >= self.total_rows:
303
+ raise IndexError(idx)
304
+ shard_idx = bisect_right(self.shard_offsets, idx)
305
+ shard_start = 0 if shard_idx == 0 else self.shard_offsets[shard_idx - 1]
306
+ local_idx = idx - shard_start
307
+ raw = self._load_shard(shard_idx)[local_idx]
308
+ return TrainRecord(
309
+ query=self._deserialize_item(raw["query"]),
310
+ positive=self._deserialize_item(raw["positive"]),
311
+ negative=self._deserialize_item(raw.get("negative")),
312
+ )
313
+
314
+ def get_prefetch_stats(self) -> Dict[str, int]:
315
+ with self._cache_lock:
316
+ return {
317
+ "cache_size": len(self._shard_cache),
318
+ "cache_limit": self.shard_cache_limit,
319
+ "prefetch_shards": self.prefetch_shards,
320
+ "prefetch_hits": self._prefetch_hits,
321
+ "prefetch_misses": self._prefetch_misses,
322
+ "prefetch_pending": len(self._prefetch_requested),
323
+ }
324
+
325
+ def close(self) -> None:
326
+ self._prefetch_stop.set()
327
+ if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
328
+ self._prefetch_thread.join(timeout=1.0)
329
+ self._prefetch_thread = None
330
+ self._prefetch_queue = None
331
+
332
+ def __del__(self):
333
+ self.close()
334
+
335
+
336
+ class SequentialShardDataset:
337
+ def __init__(
338
+ self,
339
+ cache_dir: str,
340
+ shuffle: bool = True,
341
+ rank: int = 0,
342
+ world_size: int = 1,
343
+ prefetch_shards: int = 2,
344
+ shard_cache_limit: int = 4,
345
+ ) -> None:
346
+ self.cache_dir = cache_dir
347
+ self.shuffle = shuffle
348
+ self.rank = rank
349
+ self.world_size = max(world_size, 1)
350
+ self.prefetch_shards = max(int(prefetch_shards), 0)
351
+ self.shard_cache_limit = max(int(shard_cache_limit), 1)
352
+
353
+ self.metadata = self._load_metadata()
354
+ self.shard_files = self._discover_shards()
355
+ self.shard_sizes = self._resolve_shard_sizes()
356
+ self.total_rows = sum(self.shard_sizes)
357
+ self.target_shard_size = int(self.metadata.get("shard_size") or max(self.shard_sizes))
358
+
359
+ self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()
360
+ self._cache_lock = threading.Lock()
361
+ self._prefetch_queue = None
362
+ self._prefetch_thread = None
363
+ self._prefetch_stop = threading.Event()
364
+ self._prefetch_requested: set[int] = set()
365
+ self._prefetch_hits = 0
366
+ self._prefetch_misses = 0
367
+
368
+ self._all_shard_indices = list(range(len(self.shard_files)))
369
+ self._local_shard_indices: List[int] = []
370
+ self.current_local_shard_pos = -1
371
+ self.current_records: Optional[List[Dict[str, Any]]] = None
372
+
373
+ def _load_metadata(self) -> Dict[str, Any]:
374
+ metadata_path = os.path.join(self.cache_dir, "metadata.json")
375
+ if not os.path.exists(metadata_path):
376
+ return {}
377
+ with open(metadata_path, "r", encoding="utf-8") as handle:
378
+ return json.load(handle)
379
+
380
+ def _discover_shards(self) -> List[str]:
381
+ if not os.path.isdir(self.cache_dir):
382
+ raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
383
+ shards: List[str] = []
384
+ for name in sorted(os.listdir(self.cache_dir)):
385
+ if not (name.startswith("shard_") and name.endswith(".pt")):
386
+ continue
387
+ shards.append(os.path.join(self.cache_dir, name))
388
+ if not shards:
389
+ raise ValueError(f"No cache shards found under {self.cache_dir}")
390
+ return shards
391
+
392
+ def _resolve_shard_sizes(self) -> List[int]:
393
+ num_shards = len(self.shard_files)
394
+ metadata_num_shards = self.metadata.get("num_shards")
395
+ metadata_num_records = self.metadata.get("num_records")
396
+ shard_size = self.metadata.get("shard_size")
397
+
398
+ if (
399
+ isinstance(metadata_num_shards, int)
400
+ and isinstance(metadata_num_records, int)
401
+ and isinstance(shard_size, int)
402
+ and metadata_num_shards == num_shards
403
+ and metadata_num_records > 0
404
+ and shard_size > 0
405
+ ):
406
+ shard_sizes = [shard_size] * num_shards
407
+ shard_sizes[-1] = metadata_num_records - shard_size * max(num_shards - 1, 0)
408
+ return shard_sizes
409
+
410
+ shard_sizes: List[int] = []
411
+ for shard_path in self.shard_files:
412
+ payload = torch.load(shard_path, map_location="cpu", weights_only=False)
413
+ records = payload.get("records")
414
+ if not isinstance(records, list):
415
+ raise ValueError(f"Invalid shard format in {shard_path}")
416
+ shard_sizes.append(len(records))
417
+ return shard_sizes
418
+
419
+ @staticmethod
420
+ def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional[PairItem]:
421
+ if raw is None:
422
+ return None
423
+ modality = raw["type"]
424
+ if modality == "text" and "tokens" in raw:
425
+ value = raw["tokens"]
426
+ elif modality == "text":
427
+ value = raw["value"]
428
+ elif "tensor" in raw:
429
+ value = raw["tensor"]
430
+ else:
431
+ value = raw.get("value")
432
+ return PairItem(modality=modality, value=value)
433
+
434
+ def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
435
+ with self._cache_lock:
436
+ self._shard_cache[shard_idx] = records
437
+ self._shard_cache.move_to_end(shard_idx)
438
+ while len(self._shard_cache) > self.shard_cache_limit:
439
+ self._shard_cache.popitem(last=False)
440
+
441
+ def _ensure_prefetch_thread(self) -> None:
442
+ if self.prefetch_shards <= 0:
443
+ return
444
+ if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
445
+ return
446
+ self._prefetch_stop.clear()
447
+ self._prefetch_queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
448
+ self._prefetch_thread = threading.Thread(
449
+ target=self._prefetch_worker,
450
+ daemon=True,
451
+ name=f"sequential-shard-prefetch-{os.getpid()}",
452
+ )
453
+ self._prefetch_thread.start()
454
+
455
+ def _prefetch_worker(self) -> None:
456
+ while not self._prefetch_stop.is_set():
457
+ try:
458
+ shard_idx = self._prefetch_queue.get(timeout=0.1)
459
+ except queue.Empty:
460
+ continue
461
+ if shard_idx is None:
462
+ continue
463
+ try:
464
+ with self._cache_lock:
465
+ if shard_idx in self._shard_cache:
466
+ self._prefetch_hits += 1
467
+ continue
468
+ payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
469
+ self._store_shard(shard_idx, payload["records"])
470
+ self._prefetch_hits += 1
471
+ finally:
472
+ with self._cache_lock:
473
+ self._prefetch_requested.discard(shard_idx)
474
+
475
+ def _stop_prefetch_thread(self) -> None:
476
+ self._prefetch_stop.set()
477
+ if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
478
+ self._prefetch_thread.join(timeout=1.0)
479
+ self._prefetch_thread = None
480
+ self._prefetch_queue = None
481
+
482
+ def _schedule_prefetch_from_position(self, local_pos: int) -> None:
483
+ if self.prefetch_shards <= 0:
484
+ return
485
+ self._ensure_prefetch_thread()
486
+ if self._prefetch_queue is None:
487
+ return
488
+ for next_pos in range(local_pos + 1, min(len(self._local_shard_indices), local_pos + 1 + self.prefetch_shards)):
489
+ shard_idx = self._local_shard_indices[next_pos]
490
+ with self._cache_lock:
491
+ if shard_idx in self._shard_cache or shard_idx in self._prefetch_requested:
492
+ continue
493
+ self._prefetch_requested.add(shard_idx)
494
+ try:
495
+ self._prefetch_queue.put_nowait(shard_idx)
496
+ except queue.Full:
497
+ with self._cache_lock:
498
+ self._prefetch_requested.discard(shard_idx)
499
+ break
500
+
501
+ def _build_local_shard_order(self, epoch: int) -> List[int]:
502
+ shard_indices = list(self._all_shard_indices)
503
+ if self.shuffle:
504
+ random.Random(42 + epoch).shuffle(shard_indices)
505
+ local_shards = shard_indices[self.rank::self.world_size]
506
+ max_shards = math.ceil(len(shard_indices) / self.world_size)
507
+ if not local_shards:
508
+ raise ValueError(f"Rank {self.rank} received no shards from {self.cache_dir}")
509
+ while len(local_shards) < max_shards:
510
+ local_shards.append(local_shards[len(local_shards) % len(local_shards)])
511
+ return local_shards
512
+
513
+ def _load_records_for_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
514
+ cached = None
515
+ with self._cache_lock:
516
+ cached = self._shard_cache.get(shard_idx)
517
+ if cached is not None:
518
+ self._shard_cache.move_to_end(shard_idx)
519
+ if cached is not None:
520
+ return cached
521
+
522
+ self._prefetch_misses += 1
523
+ payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
524
+ records = payload["records"]
525
+ self._store_shard(shard_idx, records)
526
+ with self._cache_lock:
527
+ self._prefetch_requested.discard(shard_idx)
528
+ return records
529
+
530
+ def _pad_records(self, records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
531
+ if len(records) >= self.target_shard_size:
532
+ return records
533
+ repeat = math.ceil(self.target_shard_size / len(records))
534
+ return (records * repeat)[: self.target_shard_size]
535
+
536
+ def reset(self, epoch: int) -> bool:
537
+ self._stop_prefetch_thread()
538
+ self._local_shard_indices = self._build_local_shard_order(epoch)
539
+ self.current_local_shard_pos = -1
540
+ self.current_records = None
541
+ with self._cache_lock:
542
+ self._prefetch_requested.clear()
543
+ if self.prefetch_shards > 0:
544
+ self._ensure_prefetch_thread()
545
+ return self.next_shard()
546
+
547
+ def next_shard(self) -> bool:
548
+ self.current_local_shard_pos += 1
549
+ if self.current_local_shard_pos >= len(self._local_shard_indices):
550
+ self.current_records = None
551
+ return False
552
+ shard_idx = self._local_shard_indices[self.current_local_shard_pos]
553
+ records = self._load_records_for_shard(shard_idx)
554
+ self.current_records = self._pad_records(records)
555
+ self._schedule_prefetch_from_position(self.current_local_shard_pos)
556
+ return True
557
+
558
+ def __len__(self) -> int:
559
+ return len(self.current_records or [])
560
+
561
+ def __getitem__(self, idx: int) -> TrainRecord:
562
+ if self.current_records is None:
563
+ raise IndexError(idx)
564
+ raw = self.current_records[idx]
565
+ return TrainRecord(
566
+ query=self._deserialize_item(raw["query"]),
567
+ positive=self._deserialize_item(raw["positive"]),
568
+ negative=self._deserialize_item(raw.get("negative")),
569
+ )
570
+
571
+ def estimated_num_batches(self, batch_size: int, drop_last: bool) -> int:
572
+ shard_batches = self.target_shard_size // batch_size if drop_last else math.ceil(self.target_shard_size / batch_size)
573
+ return shard_batches * max(len(self._build_local_shard_order(0)), 1)
574
+
575
+ def get_prefetch_stats(self) -> Dict[str, int]:
576
+ with self._cache_lock:
577
+ return {
578
+ "cache_size": len(self._shard_cache),
579
+ "cache_limit": self.shard_cache_limit,
580
+ "prefetch_shards": self.prefetch_shards,
581
+ "prefetch_hits": self._prefetch_hits,
582
+ "prefetch_misses": self._prefetch_misses,
583
+ "prefetch_pending": len(self._prefetch_requested),
584
+ "local_shards": len(self._local_shard_indices),
585
+ "target_shard_size": self.target_shard_size,
586
+ }
587
+
588
+ def close(self) -> None:
589
+ self._stop_prefetch_thread()
590
+
591
+ def __del__(self):
592
+ self.close()
593
+
594
+
595
+ def _process_shard() -> tuple[int, int]:
596
+ rank = int(os.environ.get("ACCELERATE_PROCESS_INDEX") or os.environ.get("RANK") or 0)
597
+ world_size = int(os.environ.get("WORLD_SIZE") or os.environ.get("ACCELERATE_NUM_PROCESSES") or 1)
598
+ worker_info = torch.utils.data.get_worker_info()
599
+ if worker_info is None:
600
+ return rank, max(world_size, 1)
601
+
602
+ total_shards = max(world_size, 1) * worker_info.num_workers
603
+ shard_id = rank * worker_info.num_workers + worker_info.id
604
+ return shard_id, max(total_shards, 1)
605
+
606
+
607
+ def iter_sentence_transformers_rows(
608
+ manifest_path: str,
609
+ image_root: Optional[str],
610
+ audio_root: Optional[str],
611
+ allow_missing_negative: bool,
612
+ allowed_modalities: Optional[List[str]],
613
+ query_modalities: Optional[List[str]],
614
+ positive_modalities: Optional[List[str]],
615
+ negative_modalities: Optional[List[str]],
616
+ use_negative_column: bool,
617
+ ):
618
+ allowed = set(allowed_modalities or [])
619
+ allowed_query = set(query_modalities or [])
620
+ allowed_positive = set(positive_modalities or [])
621
+ allowed_negative = set(negative_modalities or [])
622
+ shard_id, total_shards = _process_shard()
623
+ matched_index = 0
624
+
625
+ for record in iter_manifest_records(
626
+ manifest_path=manifest_path,
627
+ image_root=image_root,
628
+ audio_root=audio_root,
629
+ allow_missing_negative=allow_missing_negative,
630
+ ):
631
+ if not record_matches_filters(
632
+ record,
633
+ allowed=allowed,
634
+ allowed_query=allowed_query,
635
+ allowed_positive=allowed_positive,
636
+ allowed_negative=allowed_negative,
637
+ ):
638
+ continue
639
+
640
+ if matched_index % total_shards == shard_id:
641
+ yield record_to_sentence_transformers_row(record, include_negative=use_negative_column)
642
+ matched_index += 1
643
+
644
+
645
+ def collate_records(batch: List[TrainRecord]) -> Dict[str, List[PairItem]]:
646
+ return {
647
+ "query": [r.query for r in batch],
648
+ "positive": [r.positive for r in batch],
649
+ "negative": [r.negative for r in batch],
650
+ }
651
+
652
+
653
+ def sentence_transformers_input(item: PairItem) -> Any:
654
+ payload: Dict[str, Any] = {}
655
+ if item.modality == "text":
656
+ payload["text"] = item.value
657
+ return payload
658
+ if item.modality == "image":
659
+ payload["image"] = item.value
660
+ return payload
661
+ if item.modality == "audio":
662
+ payload["audio"] = item.value
663
+ return payload
664
+ return item.value
665
+
666
+
667
+ def resolve_media(item: PairItem, image_root: Optional[str], audio_root: Optional[str]) -> PairItem:
668
+ if item.modality == "image" and image_root and not os.path.isabs(item.value):
669
+ return PairItem(item.modality, os.path.join(image_root, item.value))
670
+ if item.modality == "audio" and audio_root and not os.path.isabs(item.value):
671
+ return PairItem(item.modality, os.path.join(audio_root, item.value))
672
+ return item
673
+
674
+
675
+ def iter_manifest_records(
676
+ manifest_path: str,
677
+ image_root: Optional[str] = None,
678
+ audio_root: Optional[str] = None,
679
+ allow_missing_negative: bool = True,
680
+ ) -> Iterable[TrainRecord]:
681
+ if not os.path.exists(manifest_path):
682
+ raise FileNotFoundError(f"Manifest not found: {manifest_path}")
683
+
684
+ with open(manifest_path, "r", encoding="utf-8") as handle:
685
+ for line_no, line in enumerate(handle, start=1):
686
+ line = line.strip()
687
+ if not line:
688
+ continue
689
+ raw = json.loads(line)
690
+ record = parse_record(raw)
691
+ record = TrainRecord(
692
+ query=resolve_media(record.query, image_root, audio_root),
693
+ positive=resolve_media(record.positive, image_root, audio_root),
694
+ negative=resolve_media(record.negative, image_root, audio_root) if record.negative else None,
695
+ )
696
+ if record.negative is None and not allow_missing_negative:
697
+ raise ValueError(f"Missing negative at line {line_no}")
698
+ yield record
699
+
700
+
701
+ def record_matches_filters(
702
+ record: TrainRecord,
703
+ allowed: set[str],
704
+ allowed_query: set[str],
705
+ allowed_positive: set[str],
706
+ allowed_negative: set[str],
707
+ ) -> bool:
708
+ record_modalities = {record.query.modality, record.positive.modality}
709
+ if record.negative is not None:
710
+ record_modalities.add(record.negative.modality)
711
+ if allowed and not record_modalities.issubset(allowed):
712
+ return False
713
+ if allowed_query and record.query.modality not in allowed_query:
714
+ return False
715
+ if allowed_positive and record.positive.modality not in allowed_positive:
716
+ return False
717
+ if record.negative is not None and allowed_negative and record.negative.modality not in allowed_negative:
718
+ return False
719
+ return True
720
+
721
+
722
+ def record_to_sentence_transformers_row(record: TrainRecord, include_negative: bool) -> Dict[str, Any]:
723
+ row = {
724
+ "query": sentence_transformers_input(record.query),
725
+ "positive": sentence_transformers_input(record.positive),
726
+ }
727
+ if include_negative and record.negative is not None:
728
+ row["negative_0"] = sentence_transformers_input(record.negative)
729
+ return row
730
+
731
+
732
+ def summarize_manifest_records(
733
+ manifest_path: str,
734
+ image_root: Optional[str] = None,
735
+ audio_root: Optional[str] = None,
736
+ allow_missing_negative: bool = True,
737
+ allowed_modalities: Optional[List[str]] = None,
738
+ query_modalities: Optional[List[str]] = None,
739
+ positive_modalities: Optional[List[str]] = None,
740
+ negative_modalities: Optional[List[str]] = None,
741
+ max_records: Optional[int] = None,
742
+ ) -> Dict[str, Any]:
743
+ modalities = set()
744
+ negatives_present = 0
745
+ negatives_missing = 0
746
+ skipped_rows = 0
747
+ num_rows = 0
748
+ allowed = set(allowed_modalities or [])
749
+ allowed_query = set(query_modalities or [])
750
+ allowed_positive = set(positive_modalities or [])
751
+ allowed_negative = set(negative_modalities or [])
752
+
753
+ for record in iter_manifest_records(
754
+ manifest_path=manifest_path,
755
+ image_root=image_root,
756
+ audio_root=audio_root,
757
+ allow_missing_negative=allow_missing_negative,
758
+ ):
759
+ if not record_matches_filters(
760
+ record,
761
+ allowed=allowed,
762
+ allowed_query=allowed_query,
763
+ allowed_positive=allowed_positive,
764
+ allowed_negative=allowed_negative,
765
+ ):
766
+ skipped_rows += 1
767
+ continue
768
+
769
+ modalities.add(record.query.modality)
770
+ modalities.add(record.positive.modality)
771
+ if record.negative is not None:
772
+ modalities.add(record.negative.modality)
773
+ negatives_present += 1
774
+ else:
775
+ negatives_missing += 1
776
+ num_rows += 1
777
+ if max_records is not None and num_rows >= max_records:
778
+ break
779
+
780
+ if num_rows == 0:
781
+ raise ValueError(f"No records loaded from {manifest_path}")
782
+
783
+ return {
784
+ "modalities": sorted(modalities),
785
+ "num_rows": num_rows,
786
+ "has_uniform_negatives": negatives_present > 0 and negatives_missing == 0,
787
+ "num_negatives_present": negatives_present,
788
+ "num_negatives_missing": negatives_missing,
789
+ "skipped_rows": skipped_rows,
790
+ }
791
+
792
+
793
+ def manifest_to_sentence_transformers_dataset(
794
+ manifest_path: str,
795
+ image_root: Optional[str] = None,
796
+ audio_root: Optional[str] = None,
797
+ allow_missing_negative: bool = True,
798
+ allowed_modalities: Optional[List[str]] = None,
799
+ query_modalities: Optional[List[str]] = None,
800
+ positive_modalities: Optional[List[str]] = None,
801
+ negative_modalities: Optional[List[str]] = None,
802
+ as_iterable: bool = False,
803
+ max_records: Optional[int] = None,
804
+ ) -> tuple[Dataset | IterableDataset, Dict[str, Any]]:
805
+ info = summarize_manifest_records(
806
+ manifest_path=manifest_path,
807
+ image_root=image_root,
808
+ audio_root=audio_root,
809
+ allow_missing_negative=allow_missing_negative,
810
+ allowed_modalities=allowed_modalities,
811
+ query_modalities=query_modalities,
812
+ positive_modalities=positive_modalities,
813
+ negative_modalities=negative_modalities,
814
+ max_records=max_records,
815
+ )
816
+
817
+ dataset_out: Dataset | IterableDataset
818
+ if as_iterable:
819
+ column_names = ["query", "positive"]
820
+ if info["has_uniform_negatives"]:
821
+ column_names.append("negative_0")
822
+ dataset_out = IterableDataset.from_generator(
823
+ iter_sentence_transformers_rows,
824
+ features=Features({key: Value("null") for key in column_names}),
825
+ gen_kwargs={
826
+ "manifest_path": manifest_path,
827
+ "image_root": image_root,
828
+ "audio_root": audio_root,
829
+ "allow_missing_negative": allow_missing_negative,
830
+ "allowed_modalities": allowed_modalities,
831
+ "query_modalities": query_modalities,
832
+ "positive_modalities": positive_modalities,
833
+ "negative_modalities": negative_modalities,
834
+ "use_negative_column": info["has_uniform_negatives"],
835
+ },
836
+ )
837
+ else:
838
+ dataset = JsonlManifestDataset(
839
+ manifest_path=manifest_path,
840
+ image_root=image_root,
841
+ audio_root=audio_root,
842
+ allow_missing_negative=allow_missing_negative,
843
+ )
844
+ allowed = set(allowed_modalities or [])
845
+ allowed_query = set(query_modalities or [])
846
+ allowed_positive = set(positive_modalities or [])
847
+ allowed_negative = set(negative_modalities or [])
848
+ rows: List[Dict[str, Any]] = []
849
+ for record in dataset.records:
850
+ if not record_matches_filters(
851
+ record,
852
+ allowed=allowed,
853
+ allowed_query=allowed_query,
854
+ allowed_positive=allowed_positive,
855
+ allowed_negative=allowed_negative,
856
+ ):
857
+ continue
858
+ rows.append(record_to_sentence_transformers_row(record, include_negative=info["has_uniform_negatives"]))
859
+ if max_records is not None and len(rows) >= max_records:
860
+ break
861
+ dataset_out = Dataset.from_list(rows)
862
+
863
+ return dataset_out, info
src/hf_st_mm/model.py ADDED
@@ -0,0 +1,191 @@
1
+ from collections import defaultdict
2
+ from typing import Any, Dict, List
3
+
4
+ import librosa
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ from PIL import Image
10
+ from sentence_transformers import SentenceTransformer
11
+ from transformers import AutoModel, AutoProcessor, WhisperFeatureExtractor, WhisperModel
12
+
13
+ from .data import PairItem
14
+
15
+
16
+ class MultiModalSentenceEmbedder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ text_encoder_name: str,
20
+ image_encoder_name: str,
21
+ audio_encoder_name: str,
22
+ embedding_dim: int,
23
+ max_text_length: int,
24
+ ) -> None:
25
+ super().__init__()
26
+ self.text_model = SentenceTransformer(text_encoder_name)
27
+ self.text_model.max_seq_length = max_text_length
28
+
29
+ self.image_model = AutoModel.from_pretrained(image_encoder_name, trust_remote_code=True)
30
+ self.image_processor = AutoProcessor.from_pretrained(image_encoder_name, trust_remote_code=True)
31
+
32
+ whisper = WhisperModel.from_pretrained(audio_encoder_name)
33
+ self.audio_model = whisper.encoder
34
+ self.audio_processor = WhisperFeatureExtractor.from_pretrained(audio_encoder_name)
35
+
36
+ text_dim = self.text_model.get_sentence_embedding_dimension()
37
+ image_dim = self._get_vision_dim(self.image_model)
38
+ audio_dim = whisper.config.d_model
39
+
40
+ self.text_proj = nn.Linear(text_dim, embedding_dim) if text_dim != embedding_dim else nn.Identity()
41
+ self.image_proj = nn.Linear(image_dim, embedding_dim) if image_dim != embedding_dim else nn.Identity()
42
+ self.audio_proj = nn.Linear(audio_dim, embedding_dim) if audio_dim != embedding_dim else nn.Identity()
43
+
44
+ @staticmethod
45
+ def _get_vision_dim(model: nn.Module) -> int:
46
+ if hasattr(model, "vision_model") and hasattr(model.config, "vision_config"):
47
+ return int(model.config.vision_config.hidden_size)
48
+ if hasattr(model.config, "hidden_size"):
49
+ return int(model.config.hidden_size)
50
+ raise ValueError("Could not infer image hidden size")
51
+
52
+ def _encode_text(self, texts: List[Any]) -> torch.Tensor:
53
+ device = next(self.parameters()).device
54
+ normalized: List[torch.Tensor | None] = [None] * len(texts)
55
+
56
+ dict_positions = [idx for idx, item in enumerate(texts) if isinstance(item, dict)]
57
+ if dict_positions:
58
+ pad_values = {
59
+ "input_ids": 0,
60
+ "attention_mask": 0,
61
+ "token_type_ids": 0,
62
+ }
63
+ dict_items = [texts[idx] for idx in dict_positions]
64
+ features = {
65
+ key: pad_sequence(
66
+ [item[key].detach().cpu() for item in dict_items],
67
+ batch_first=True,
68
+ padding_value=pad_values.get(key, 0),
69
+ ).to(device)
70
+ for key in dict_items[0].keys()
71
+ }
72
+ out = self.text_model(features)
73
+ emb = F.normalize(self.text_proj(out["sentence_embedding"]), p=2, dim=-1)
74
+ for loc, row in zip(dict_positions, emb):
75
+ normalized[loc] = row
76
+
77
+ raw_positions = [idx for idx, item in enumerate(texts) if not isinstance(item, dict)]
78
+ if raw_positions:
79
+ raw_texts = [texts[idx] for idx in raw_positions]
80
+ features = self.text_model.tokenize(raw_texts)
81
+ features = {
82
+ k: (v.to(device) if hasattr(v, "to") else v)
83
+ for k, v in features.items()
84
+ }
85
+ out = self.text_model(features)
86
+ emb = F.normalize(self.text_proj(out["sentence_embedding"]), p=2, dim=-1)
87
+ for loc, row in zip(raw_positions, emb):
88
+ normalized[loc] = row
89
+
90
+ return torch.stack([row for row in normalized if row is not None], dim=0)
91
+
92
+ def _encode_image_paths(self, paths: List[str]) -> torch.Tensor:
93
+ images = [Image.open(path).convert("RGB") for path in paths]
94
+ proc = self.image_processor(images=images, return_tensors="pt")
95
+ device = next(self.parameters()).device
96
+ proc = {k: v.to(device) for k, v in proc.items()}
97
+ return self._encode_image_pixel_values(proc["pixel_values"])
98
+
99
+ def _encode_image_pixel_values(self, pixel_values: torch.Tensor) -> torch.Tensor:
100
+ device = next(self.parameters()).device
101
+ proc = {"pixel_values": pixel_values.to(device)}
102
+ if hasattr(self.image_model, "vision_model"):
103
+ out = self.image_model.vision_model(**proc, output_hidden_states=False)
104
+ hidden = out.last_hidden_state
105
+ else:
106
+ out = self.image_model(**proc, output_hidden_states=False)
107
+ hidden = out.last_hidden_state
108
+ pooled = hidden[:, 1:].mean(dim=1) if hidden.shape[1] > 1 else hidden.mean(dim=1)
109
+ emb = self.image_proj(pooled)
110
+ return F.normalize(emb, p=2, dim=-1)
111
+
112
+ def _encode_audio_paths(self, paths: List[str]) -> torch.Tensor:
113
+ waves = [librosa.load(path, sr=16000, mono=True)[0] for path in paths]
114
+ proc = self.audio_processor(waves, sampling_rate=16000, return_tensors="pt")
115
+ return self._encode_audio_features(proc["input_features"])
116
+
117
+ def _encode_audio_features(self, input_features: torch.Tensor) -> torch.Tensor:
118
+ device = next(self.parameters()).device
119
+ input_features = input_features.to(device)
120
+ input_features = input_features.to(self.audio_model.conv1.weight.dtype)
121
+ out = self.audio_model(input_features=input_features, output_hidden_states=False)
122
+ pooled = out.last_hidden_state.mean(dim=1)
123
+ emb = self.audio_proj(pooled)
124
+ return F.normalize(emb, p=2, dim=-1)
125
+
126
+ @staticmethod
127
+ def _stack_tensor_values(values: List[Any]) -> torch.Tensor:
128
+ tensors = []
129
+ for value in values:
130
+ if not torch.is_tensor(value):
131
+ raise TypeError("Expected tensor payload in cached item")
132
+ tensor = value.detach().cpu()
133
+ if tensor.dim() > 0 and tensor.shape[0] == 1:
134
+ tensor = tensor.squeeze(0)
135
+ tensors.append(tensor)
136
+ return torch.stack(tensors, dim=0)
137
+
138
+ def encode_items(self, items: List[PairItem]) -> torch.Tensor:
139
+ grouped = defaultdict(list)
140
+ for idx, item in enumerate(items):
141
+ grouped[item.modality].append((idx, item.value))
142
+
143
+ device = next(self.parameters()).device
144
+ out = [None] * len(items)
145
+
146
+ if grouped["text"]:
147
+ idxs, vals = zip(*grouped["text"])
148
+ embs = self._encode_text(list(vals))
149
+ for loc, emb in zip(idxs, embs):
150
+ out[loc] = emb
151
+
152
+ if grouped["image"]:
153
+ idxs, vals = zip(*grouped["image"])
154
+ tensor_pairs = [(idx, val) for idx, val in zip(idxs, vals) if torch.is_tensor(val)]
155
+ path_pairs = [(idx, val) for idx, val in zip(idxs, vals) if not torch.is_tensor(val)]
156
+ if path_pairs:
157
+ p_idxs, p_vals = zip(*path_pairs)
158
+ embs = self._encode_image_paths(list(p_vals))
159
+ for loc, emb in zip(p_idxs, embs):
160
+ out[loc] = emb
161
+ if tensor_pairs:
162
+ t_idxs, t_vals = zip(*tensor_pairs)
163
+ embs = self._encode_image_pixel_values(self._stack_tensor_values(list(t_vals)))
164
+ for loc, emb in zip(t_idxs, embs):
165
+ out[loc] = emb
166
+
167
+ if grouped["audio"]:
168
+ idxs, vals = zip(*grouped["audio"])
169
+ tensor_pairs = [(idx, val) for idx, val in zip(idxs, vals) if torch.is_tensor(val)]
170
+ path_pairs = [(idx, val) for idx, val in zip(idxs, vals) if not torch.is_tensor(val)]
171
+ if path_pairs:
172
+ p_idxs, p_vals = zip(*path_pairs)
173
+ embs = self._encode_audio_paths(list(p_vals))
174
+ for loc, emb in zip(p_idxs, embs):
175
+ out[loc] = emb
176
+ if tensor_pairs:
177
+ t_idxs, t_vals = zip(*tensor_pairs)
178
+ embs = self._encode_audio_features(self._stack_tensor_values(list(t_vals)))
179
+ for loc, emb in zip(t_idxs, embs):
180
+ out[loc] = emb
181
+
182
+ stacked = torch.stack(out, dim=0).to(device=device, dtype=torch.float32)
183
+ return F.normalize(stacked, p=2, dim=-1)
184
+
185
+
186
+ def multiple_negatives_ranking_loss(anchor: torch.Tensor, positive: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
187
+ scores = torch.matmul(anchor, positive.T) * scale
188
+ labels = torch.arange(scores.shape[0], device=scores.device)
189
+ loss_a = torch.nn.functional.cross_entropy(scores, labels)
190
+ loss_b = torch.nn.functional.cross_entropy(scores.T, labels)
191
+ return (loss_a + loss_b) * 0.5