Update main.py
main.py CHANGED
@@ -66,12 +66,13 @@ attribution and actions without speculation.""",
     attention_mask = inputs["attention_mask"].to(device)
 
     # Forward pass through the model
-
-
-
-
-
-
+    with torch.inference_mode():
+        output = text_encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+            use_cache=False,
+        )
 
     # Only use outputs from intermediate layers and stack them
     out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
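The change wraps the encoder forward pass in torch.inference_mode(), which disables autograd tracking outright (a stricter, slightly faster variant of torch.no_grad()), requests output_hidden_states=True so every layer's activations are available, and passes use_cache=False since no generation happens here. A minimal self-contained sketch of the same pattern, assuming a generic Hugging Face encoder; bert-base-uncased and the hidden_states_layers values below are illustrative stand-ins, not the model this Space actually loads:

import torch
from transformers import AutoModel, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in encoder; the Space's real text_encoder is defined elsewhere in main.py.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text_encoder = AutoModel.from_pretrained("bert-base-uncased").to(device).eval()

# Assumed example: which intermediate layers to keep.
hidden_states_layers = [4, 8, 11]

inputs = tokenizer("a short test prompt", return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

# inference_mode disables gradient tracking entirely, saving memory and time
# on a serve-only code path.
with torch.inference_mode():
    output = text_encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        use_cache=False,  # no generation here, so KV caching is useless
    )

# output.hidden_states is a tuple of (num_layers + 1) tensors, each of shape
# (batch, seq_len, hidden); stacking the selected layers along dim=1 yields
# (batch, len(hidden_states_layers), seq_len, hidden).
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
print(out.shape)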
@@ -136,8 +137,13 @@ def read_root():
 
 @app.get("/predict")
 def predict(prompt: str = Query(...)):
-
+    prompt_embeds, text_ids = encode_prompt(
         prompt=prompt,
         device=device,
     )
-    return {
+    return {
+        "response": {
+            "prompt_embeds": prompt_embeds.cpu().tolist(),
+            "text_ids": text_ids.cpu().tolist()
+        }
+    }
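The second hunk fills in the /predict handler: encode_prompt now returns both prompt_embeds and text_ids, and the tensors are converted with .cpu().tolist() so FastAPI can JSON-encode them. A runnable sketch of the same endpoint shape; encode_prompt is stubbed here with placeholder shapes and is not the Space's real implementation:

import torch
from fastapi import FastAPI, Query

app = FastAPI()
device = "cpu"  # assumed for the sketch

def encode_prompt(prompt: str, device: str):
    # Placeholder for the real encoder call; shapes are illustrative only.
    prompt_embeds = torch.zeros(1, 4, 8, device=device)
    text_ids = torch.zeros(1, 4, dtype=torch.long, device=device)
    return prompt_embeds, text_ids

@app.get("/predict")
def predict(prompt: str = Query(...)):
    prompt_embeds, text_ids = encode_prompt(
        prompt=prompt,
        device=device,
    )
    # .cpu().tolist() moves tensors off the accelerator and turns them into
    # nested Python lists, which FastAPI can serialize to JSON.
    return {
        "response": {
            "prompt_embeds": prompt_embeds.cpu().tolist(),
            "text_ids": text_ids.cpu().tolist()
        }
    }

Served with uvicorn main:app, a request like curl 'http://localhost:8000/predict?prompt=hello' returns the nested JSON payload. Note that .tolist() on large embedding tensors produces sizeable responses; a service under load might prefer a binary encoding instead.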