ariG23498 HF Staff committed on
Commit
e554b85
·
verified ·
1 Parent(s): ad732ce

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +14 -8
main.py CHANGED
@@ -66,12 +66,13 @@ attribution and actions without speculation.""",
66
  attention_mask = inputs["attention_mask"].to(device)
67
 
68
  # Forward pass through the model
69
- output = text_encoder(
70
- input_ids=input_ids,
71
- attention_mask=attention_mask,
72
- output_hidden_states=True,
73
- use_cache=False,
74
- )
 
75
 
76
  # Only use outputs from intermediate layers and stack them
77
  out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
@@ -136,8 +137,13 @@ def read_root():
136
 
137
  @app.get("/predict")
138
  def predict(prompt: str = Query(...)):
139
- output = encode_prompt(
140
  prompt=prompt,
141
  device=device,
142
  )
143
- return {"response": output}
 
 
 
 
 
 
66
  attention_mask = inputs["attention_mask"].to(device)
67
 
68
  # Forward pass through the model
69
+ with torch.inference_mode():
70
+ output = text_encoder(
71
+ input_ids=input_ids,
72
+ attention_mask=attention_mask,
73
+ output_hidden_states=True,
74
+ use_cache=False,
75
+ )
76
 
77
  # Only use outputs from intermediate layers and stack them
78
  out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
 
137
 
138
  @app.get("/predict")
139
  def predict(prompt: str = Query(...)):
140
+ prompt_embeds, text_ids = encode_prompt(
141
  prompt=prompt,
142
  device=device,
143
  )
144
+ return {
145
+ "response": {
146
+ "prompt_embeds": prompt_embeds.cpu().tolist(),
147
+ "text_ids": text_ids.cpu().tolist()
148
+ }
149
+ }