from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration


def format_text_input(prompts: List[str], system_message: Optional[str] = None):
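    """Wrap each prompt in a system/user message pair for the chat template."""
    # "[IMG]" placeholders are stripped so this text-only encoding path never
    # hands image slots to the tokenizer.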
    cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
    return [
        [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ]
        for prompt in cleaned_txt
    ]


def get_mistral_3_small_prompt_embeds(
    text_encoder: Mistral3ForConditionalGeneration,
    tokenizer: AutoProcessor,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 512,
    system_message: str = (
        "You are an AI that reasons about image descriptions. You give "
        "structured responses focusing on object relationships, object "
        "attribution and actions without speculation."
    ),
    hidden_states_layers: Tuple[int, ...] = (10, 20, 30),
):
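    """Encode prompts with a Mistral-3 text encoder.

    Hidden states from the selected intermediate layers are concatenated along
    the feature axis, giving embeddings of shape
    (batch_size, max_sequence_length, len(hidden_states_layers) * hidden_dim).
    """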
    prompt = [prompt] if isinstance(prompt, str) else prompt

    messages_batch = format_text_input(prompts=prompt, system_message=system_message)

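    # Tokenize via the chat template, padding/truncating every prompt to a
    # fixed max_sequence_length so the batch stays rectangular.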
    inputs = tokenizer.apply_chat_template(
        messages_batch,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_sequence_length,
    )

    input_ids = inputs["input_ids"].to(text_encoder.device)
    attention_mask = inputs["attention_mask"].to(text_encoder.device)

    with torch.inference_mode():
        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

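    # Stack the chosen layers: (batch, num_layers, seq_len, hidden_dim).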
    out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
    out = out.to(dtype=text_encoder.dtype, device=text_encoder.device)

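    # Fold the layer axis into the feature axis:
    # (batch, num_layers, seq, hidden) -> (batch, seq, num_layers * hidden).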
    batch_size, num_channels, seq_len, hidden_dim = out.shape
    prompt_embeds = out.permute(0, 2, 1, 3).reshape(
        batch_size, seq_len, num_channels * hidden_dim
    )

    return prompt_embeds


def prepare_text_ids(
    x: torch.Tensor,
    t_coord: Optional[torch.Tensor] = None,
):
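    """Build (t, h, w, l) coordinate ids for a batch of token sequences.

    Only the sequence coordinate l varies across a prompt's tokens; t defaults
    to 0 unless t_coord is given, and h and w are fixed at 0, so each batch
    entry gets ids of shape (L, 4).
    """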
    B, L, _ = x.shape
    out_ids = []

    for i in range(B):
        t = torch.arange(1) if t_coord is None else t_coord[i]
        h = torch.arange(1)
        w = torch.arange(1)
        l = torch.arange(L)

        coords = torch.cartesian_prod(t, h, w, l)
        out_ids.append(coords)

    return torch.stack(out_ids)


def encode_prompt(
    text_encoder: Mistral3ForConditionalGeneration,
    tokenizer: AutoProcessor,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 512,
):
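    """Return prompt embeddings plus matching text position ids.

    Precomputed prompt_embeds are reused when provided; otherwise the prompts
    are encoded from scratch, and the result is duplicated
    num_images_per_prompt times along the batch dimension.
    """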
    if prompt is None:
        prompt = ""

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt_embeds is None:
        prompt_embeds = get_mistral_3_small_prompt_embeds(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            max_sequence_length=max_sequence_length,
        )

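    # Repeat each prompt's embeddings once per generated image, then flatten
    # the prompt and image axes into a single batch dimension.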
    batch_size, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    text_ids = prepare_text_ids(prompt_embeds)
    text_ids = text_ids.to(text_encoder.device)
    return prompt_embeds, text_ids
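

# Example usage: a minimal sketch, not a tested recipe. The checkpoint id is an
# assumption; any checkpoint compatible with Mistral3ForConditionalGeneration
# should work, and dtype/device placement should be adapted to your hardware.
if __name__ == "__main__":
    model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"  # assumed checkpoint
    tokenizer = AutoProcessor.from_pretrained(model_id)
    text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )

    prompt_embeds, text_ids = encode_prompt(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        prompt="A red cube resting on a wooden table",
        num_images_per_prompt=2,
    )
    print(prompt_embeds.shape)  # (2, 512, 3 * hidden_dim)
    print(text_ids.shape)  # (2, 512, 4)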