yuto0o
oom 回避
5bac7bb
raw
history blame
2.85 kB
import torch
from rest_framework.response import Response
from rest_framework.views import APIView
from ml_api.model_loader import get_model
class ChatView(APIView):
def post(self, request):
user_input = request.data.get("text", "")
# ここで呼び出す(初回のみロードが走る)
model, tokenizer = get_model()
# 1. 会話フォーマットの作成
messages = [
{
"role": "system",
"content": "あなたは親切でフレンドリーなAIアシスタントです。自然な日本語で簡潔に返事をしてください。",
},
{"role": "user", "content": user_input},
]
# 2. プロンプトへの変換
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
# 3. 生成
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=128,
do_sample=True,
temperature=0.7,
top_p=0.9,
)
# 4. デコード
generated_ids = [
output_ids[len(input_ids) :]
for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
]
response_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[
0
]
return Response({"result": response_text})
# class ChatView(APIView):
# def post(self, request):
# input_text = request.data.get("text", "")
# # 簡易的なプロンプトエンジニアリング
# # モデルに「会話」であることを認識させるフォーマット
# prompt = f"ユーザー: {input_text}\nシステム: "
# app_config = apps.get_app_config("ml_api")
# tokenizer = app_config.tokenizer
# model = app_config.model
# # トークン化
# inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
# # 生成
# with torch.no_grad():
# output_ids = model.generate(
# inputs["input_ids"],
# max_new_tokens=50, # 返信の長さ
# do_sample=True,
# temperature=0.7, # 創造性(高いほどランダム)
# pad_token_id=tokenizer.pad_token_id,
# eos_token_id=tokenizer.eos_token_id,
# )
# # デコード
# output = tokenizer.decode(output_ids.tolist()[0])
# # プロンプト部分を除去して返信部分だけ抽出
# response_text = output.split("システム: ")[-1].strip()
# return Response({"result": response_text})