multimodalart (HF Staff) committed
Commit 2e043a8 · verified · 1 Parent(s): fa93b65

Upload mistral_text_encoding_core.py

Files changed (1):
  1. mistral_text_encoding_core.py +121 -0
mistral_text_encoding_core.py ADDED
@@ -0,0 +1,121 @@
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from transformers import AutoProcessor, Mistral3ForConditionalGeneration
+
+
+ def format_text_input(prompts: List[str], system_message: Optional[str] = None):
+     # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
+     # when truncation is enabled. The processor counts [IMG] tokens and fails
+     # if the count changes after truncation.
+     cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
+
+     # Build one chat conversation (system turn + user turn) per prompt
+     return [
+         [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": system_message}],
+             },
+             {"role": "user", "content": [{"type": "text", "text": prompt}]},
+         ]
+         for prompt in cleaned_txt
+     ]
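+ # Illustrative sketch of the returned structure for a single, hypothetical
+ # prompt (not taken from the file above):
+ # [
+ #     {"role": "system", "content": [{"type": "text", "text": system_message}]},
+ #     {"role": "user", "content": [{"type": "text", "text": "a cat on a mat"}]},
+ # ]
+ # This is the conversation format consumed by apply_chat_template below.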
+
+
+ def get_mistral_3_small_prompt_embeds(
+     text_encoder: Mistral3ForConditionalGeneration,
+     tokenizer: AutoProcessor,
+     prompt: Union[str, List[str]],
+     max_sequence_length: int = 512,
+     system_message: str = (
+         "You are an AI that reasons about image descriptions. You give "
+         "structured responses focusing on object relationships, object "
+         "attribution and actions without speculation."
+     ),
+     hidden_states_layers: Tuple[int, ...] = (10, 20, 30),
+ ):
+     prompt = [prompt] if isinstance(prompt, str) else prompt
+
+     # Format the prompts as chat conversations
+     messages_batch = format_text_input(prompts=prompt, system_message=system_message)
+
+     # Tokenize all conversations at once, padded/truncated to a fixed length
+     inputs = tokenizer.apply_chat_template(
+         messages_batch,
+         add_generation_prompt=False,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=max_sequence_length,
+     )
+
+     # Move to the text encoder's device
+     input_ids = inputs["input_ids"].to(text_encoder.device)
+     attention_mask = inputs["attention_mask"].to(text_encoder.device)
+
+     # Forward pass through the model, collecting all hidden states
+     with torch.inference_mode():
+         output = text_encoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             output_hidden_states=True,
+             use_cache=False,
+         )
+
+     # Only use outputs from intermediate layers and stack them: (B, 3, L, D)
+     out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
+     out = out.to(dtype=text_encoder.dtype, device=text_encoder.device)
+
+     # Concatenate the selected layers per token: (B, 3, L, D) -> (B, L, 3 * D)
+     batch_size, num_layers, seq_len, hidden_dim = out.shape
+     prompt_embeds = out.permute(0, 2, 1, 3).reshape(
+         batch_size, seq_len, num_layers * hidden_dim
+     )
+
+     return prompt_embeds
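+ # Shape sketch: each output.hidden_states[k] is (B, L, D), so stacking the
+ # three selected layers gives (B, 3, L, D), and the permute/reshape above
+ # concatenates them per token into (B, L, 3 * D). Assuming a hidden size of
+ # 5120 for Mistral Small (an assumption; confirm against the checkpoint's
+ # config), that is 512 tokens x 15360 features per prompt.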
+
+
+ def prepare_text_ids(
+     x: torch.Tensor,  # (B, L, D) or (L, D)
+     t_coord: Optional[torch.Tensor] = None,
+ ):
+     # Promote an unbatched (L, D) input to (1, L, D), as the annotation allows
+     if x.ndim == 2:
+         x = x.unsqueeze(0)
+     B, L, _ = x.shape
+     out_ids = []
+
+     # Assign every token a (t, h, w, l) coordinate: text tokens occupy a
+     # single t/h/w position and vary only along the sequence axis l
+     for i in range(B):
+         t = torch.arange(1) if t_coord is None else t_coord[i]
+         h = torch.arange(1)
+         w = torch.arange(1)
+         l = torch.arange(L)
+
+         coords = torch.cartesian_prod(t, h, w, l)
+         out_ids.append(coords)
+
+     return torch.stack(out_ids)
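+ # Worked example (illustrative, assuming t_coord=None): with B=1 and L=4 the
+ # returned IDs have shape (1, 4, 4) with rows
+ # (0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 0, 2), (0, 0, 0, 3),
+ # i.e. text positions vary only along l, presumably leaving the t/h/w axes
+ # free for image-latent coordinates downstream.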
+
+
+ def encode_prompt(
+     text_encoder: Mistral3ForConditionalGeneration,
+     tokenizer: AutoProcessor,
+     prompt: Union[str, List[str]],
+     num_images_per_prompt: int = 1,
+     prompt_embeds: Optional[torch.Tensor] = None,
+     max_sequence_length: int = 512,
+ ):
+     if prompt is None:
+         prompt = ""
+
+     prompt = [prompt] if isinstance(prompt, str) else prompt
+
+     # Compute embeddings unless they were precomputed and passed in
+     if prompt_embeds is None:
+         prompt_embeds = get_mistral_3_small_prompt_embeds(
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             prompt=prompt,
+             max_sequence_length=max_sequence_length,
+         )
+
+     # Duplicate the embeddings once per image generated for each prompt
+     batch_size, seq_len, _ = prompt_embeds.shape
+     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+     prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+     text_ids = prepare_text_ids(prompt_embeds)
+     text_ids = text_ids.to(text_encoder.device)
+     return prompt_embeds, text_ids
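A minimal usage sketch of the uploaded helpers. The checkpoint id, dtype, and device settings below are assumptions for illustration, not part of the commit; substitute whatever text encoder the surrounding pipeline actually loads.

import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration

from mistral_text_encoding_core import encode_prompt

# Hypothetical checkpoint id -- use the pipeline's actual text encoder
model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoProcessor.from_pretrained(model_id)

prompt_embeds, text_ids = encode_prompt(
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    prompt="an astronaut riding a horse",
    num_images_per_prompt=1,
)
print(prompt_embeds.shape)  # (1, 512, 3 * hidden_size)
print(text_ids.shape)       # (1, 512, 4)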