from fastapi import FastAPI, Query
from transformers import Mistral3ForConditionalGeneration, AutoProcessor
from typing import Union, Optional, List, Sequence
import torch
app = FastAPI()
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
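# NOTE (assumption): the processor is loaded from the 3.1 (2503) repo below on the
# premise that the 2506 checkpoint shares its tokenizer/processor assets; swap in
# the 2506 repo id if it ships its own processor config.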
processor_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
tokenizer = AutoProcessor.from_pretrained(processor_id)
def format_text_input(prompts: List[str], system_message: Optional[str] = None):
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
# when truncation is enabled. The processor counts [IMG] tokens and fails
# if the count changes after truncation.
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
return [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
for prompt in cleaned_txt
]
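# Illustrative call (values are made up for the example):
#   format_text_input(["a red cube on a table"], system_message="Describe images.")
# returns one chat per prompt:
#   [[{"role": "system", "content": [...]}, {"role": "user", "content": [...]}]]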
def _get_mistral_3_small_prompt_embeds(
text_encoder: Mistral3ForConditionalGeneration,
tokenizer: AutoProcessor,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
    system_message: str = (
        "You are an AI that reasons about image descriptions. You give structured "
        "responses focusing on object relationships, object attribution and actions "
        "without speculation."
    ),
    hidden_states_layers: Sequence[int] = (10, 20, 30),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
# Format input messages
messages_batch = format_text_input(prompts=prompt, system_message=system_message)
# Process all messages at once
inputs = tokenizer.apply_chat_template(
messages_batch,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
# Move to device
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
# Forward pass through the model
with torch.inference_mode():
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
    # Only use outputs from intermediate layers and stack them. hidden_states[0]
    # is the embedding output, so index k here selects decoder layer k.
    out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
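# Rough shape sketch, assuming the 24B model's hidden size of 5120: the three
# stacked layers give 3 * 5120 = 15360 features per token, so for a single prompt
#   embeds = _get_mistral_3_small_prompt_embeds(text_encoder, tokenizer, "a cat")
#   embeds.shape  # -> torch.Size([1, 512, 15360]) with the default max_sequence_length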
def _prepare_text_ids(
    x: torch.Tensor,  # (B, L, D) or (L, D)
    t_coord: Optional[torch.Tensor] = None,
):
    # Promote an unbatched (L, D) input to (1, L, D) so the loop below works.
    if x.ndim == 2:
        x = x.unsqueeze(0)
    B, L, _ = x.shape
out_ids = []
    for i in range(B):
        # Text tokens occupy a single (t, h, w) cell; only the sequence index l
        # varies. t can be overridden per sample via t_coord.
        t = torch.arange(1) if t_coord is None else t_coord[i]
        h = torch.arange(1)
        w = torch.arange(1)
        l = torch.arange(L)
        coords = torch.cartesian_prod(t, h, w, l)  # (L, 4) rows of (t, h, w, l)
        out_ids.append(coords)
return torch.stack(out_ids)
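# Illustrative output: for a (2, 512, 15360) embedding batch the ids have shape
# (2, 512, 4); every row is a (t, h, w, l) coordinate with t = h = w = 0 and l the
# token index, presumably mirroring the position-id layout of packed image latents.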
def encode_prompt(
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
):
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds = _get_mistral_3_small_prompt_embeds(
text_encoder=text_encoder,
tokenizer=tokenizer,
prompt=prompt,
device=device,
max_sequence_length=max_sequence_length,
)
batch_size, seq_len, _ = prompt_embeds.shape
    # Duplicate embeddings once per requested image, then flatten into the batch dim.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
text_ids = _prepare_text_ids(prompt_embeds)
text_ids = text_ids.to(device)
return prompt_embeds, text_ids
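# Minimal usage sketch (shapes assume the defaults above and a 5120-dim hidden state):
#   embeds, ids = encode_prompt("a red cube on a table", device=device)
#   embeds.shape  # -> (1, 512, 15360)
#   ids.shape     # -> (1, 512, 4)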
@app.get("/")
def read_root():
return {"message": "API is live. Use the /predict endpoint."}
@app.get("/predict")
def predict(prompt: str = Query(...)):
prompt_embeds, text_ids = encode_prompt(
prompt=prompt,
device=device,
)
return {
"response": {
"prompt_embeds": prompt_embeds.cpu().tolist(),
"text_ids": text_ids.cpu().tolist()
}
    }
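# Local entry point, a minimal sketch for running outside the hosted Space; the
# host/port below are assumptions, not part of the original deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)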