Spaces:
Sleeping
Sleeping
File size: 4,850 Bytes
4dace6a f900124 4dace6a e14e91c 4dace6a e14e91c 4dace6a e14e91c 4dace6a e14e91c 4dace6a e14e91c 4dace6a e14e91c 4dace6a e14e91c 4dace6a e14e91c 4dace6a e14e91c 4dace6a f900124 4dace6a e14e91c 4dace6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from __future__ import annotations

import base64
import hmac
import json
import logging
import os
from pathlib import Path
from typing import Any

import dotenv
import gradio as gr
import requests
from gradio.components.chatbot import (
    FileDataDict,
    FileMessageDict,
    NormalizedMessageContent,
    NormalizedMessageDict,
    TextMessageDict,
)
from gradio.components.multimodal_textbox import MultimodalValue
# Module logger; output handlers/format are configured via logging.basicConfig
# elsewhere in this module (records propagate to the root logger).
logger = logging.getLogger(__name__)

# OpenRouter chat-completions endpoint and the model used for every request.
API_URL = "https://openrouter.ai/api/v1/chat/completions"
MODEL = "google/gemini-2.5-flash-lite-preview-09-2025"

# System prompt, read once at import time. The path is relative to the current
# working directory. Decode explicitly as UTF-8 so non-ASCII text is not
# mangled on hosts whose locale encoding is something else.
SYSTEM_PROMPT = Path("system-prompt.md").read_text(encoding="utf-8").strip()

# Accepted audio file extensions (lower-case, without the leading dot).
AUDIO_FORMATS = frozenset({"wav", "mp3", "m4a", "flac"})
def _is_secret_password(candidate: str) -> bool:
    """Return True if ``candidate`` equals the ``PASSWORD`` environment variable.

    The comparison is constant-time because ``candidate`` is typed by the user.
    An unset or empty ``PASSWORD`` never matches; it no longer raises KeyError.
    """
    password = os.environ.get("PASSWORD")
    if not password:
        return False
    # Compare bytes: hmac.compare_digest rejects non-ASCII ``str`` arguments.
    return hmac.compare_digest(candidate.encode("utf-8"), password.encode("utf-8"))


def chat_fn(user_msg: MultimodalValue, history: list[NormalizedMessageDict], api_key: str | None) -> str:
    """Send the conversation plus the new user message to OpenRouter and return the reply.

    Args:
        user_msg: Value of the multimodal textbox. It holds an optional ``"text"``
            string and an optional ``"files"`` list of local file paths.
        history: Previous chat turns, in chronological order, as passed by
            ``gr.ChatInterface``.
        api_key: Contents of the API-key textbox. This is either an OpenRouter
            key or the shared secret stored in ``PASSWORD``. In the latter case
            the server's ``OPENROUTER_API_KEY`` is used instead.

    Returns:
        The assistant's reply text, stripped. Returns ``"Boh!"`` when no key was entered.

    Raises:
        requests.HTTPError: If OpenRouter answers with an error status.
        requests.RequestException: On network failures or the 30 s timeout.
        ValueError: If an attached file has an unsupported audio extension.
        KeyError: If the password matches but ``OPENROUTER_API_KEY`` is unset.
    """
    # ``history`` is sent to the model as-is, before the new message, so it is
    # already chronological. Log it unreversed so the output matches its label.
    logger.info("History (oldest first):\n%s", json.dumps(history, indent=2))
    logger.info("User message:\n%s", json.dumps(user_msg, indent=2))

    # Determine API key. The shared password unlocks the server's own key.
    if not api_key:
        return "Boh!"
    if _is_secret_password(api_key):
        api_key = os.environ["OPENROUTER_API_KEY"]

    # Build the new user turn. Tolerate a missing/None text or files value.
    user_content: list[NormalizedMessageContent] = []
    text = (user_msg.get("text") or "").strip()
    if text:
        user_content.append(TextMessageDict(type="text", text=text))
    for path in user_msg.get("files") or []:
        user_content.append(FileMessageDict(type="file", file=FileDataDict(path=path)))

    # Full conversation: system prompt, previous turns, then the new user turn.
    messages: list[NormalizedMessageDict] = [
        NormalizedMessageDict(role="system", content=[TextMessageDict(type="text", text=SYSTEM_PROMPT)]),
        *history,
        NormalizedMessageDict(role="user", content=user_content),
    ]

    # Call the model API.
    payload = {
        "model": MODEL,
        "messages": history_to_messages(messages),
        "max_tokens": 4096,
        "temperature": 0.2,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
    try:
        logger.info("Response:\n%s", json.dumps(response.json(), indent=2))
    except ValueError:
        # A non-JSON body (e.g. an HTML error page) is logged raw, so that the
        # HTTP error raised below surfaces rather than a JSON decode error.
        logger.info("Response (HTTP %s, not JSON):\n%s", response.status_code, response.text)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"].strip()
def history_to_messages(history: list[NormalizedMessageDict]) -> list[dict[str, Any]]:
    """
    Convert Gradio chat history into the message list sent to OpenRouter.

    Each message keeps only its ``role`` and ``content``. Inside ``content``,
    a file entry such as

        {"type": "file", "file": {"path": "file.wav"}}

    is replaced by inline audio:

        {
            "type": "input_audio",
            "input_audio": {"data": "<base64-encoded-audio>", "format": "wav"},
        }

    Any other content entry (e.g. text) is passed through unchanged.

    Raises:
        ValueError: If a file's extension is not in ``AUDIO_FORMATS``.
    """

    def convert_part(part: NormalizedMessageContent) -> dict[str, Any]:
        # Non-file parts need no conversion.
        if part["type"] != "file":
            return part  # pyright: ignore[reportReturnType]
        audio_path = Path(part["file"]["path"])
        # ".WAV" -> "wav": bare, lower-case extension, comparable to AUDIO_FORMATS.
        audio_format = audio_path.suffix.lstrip(".").lower()
        if audio_format not in AUDIO_FORMATS:
            raise ValueError(f"Unsupported file format: {audio_format}")
        encoded_audio = file_to_base64(audio_path)
        return {
            "type": "input_audio",
            "input_audio": {"data": encoded_audio, "format": audio_format},
        }

    converted_messages: list[dict[str, Any]] = []
    for message in history:
        role = message["role"]
        parts = list(map(convert_part, message["content"]))
        converted_messages.append({"role": role, "content": parts})
    return converted_messages
def file_to_base64(path: str | Path) -> str:
with open(path, "rb") as f:
return base64.b64encode(f.read()).decode("utf-8")
# Set up logging: INFO level with a timestamped format.
logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s: %(message)s")
# Route warnings.warn() output through the logging system.
logging.captureWarnings(True)

# Load environment variables from a .env file. This includes PASSWORD and
# OPENROUTER_API_KEY, which chat_fn reads at request time.
dotenv.load_dotenv()

# Chat UI. `demo` is kept as a top-level variable to allow live reloading.
demo = gr.ChatInterface(
    chat_fn,
    multimodal=True,
    chatbot=gr.Chatbot(placeholder="Ready!"),
    textbox=gr.MultimodalTextbox(
        placeholder="Your message",
        file_count="single",
        file_types=["audio"],
        sources=["microphone"],  # only the microphone source is enabled
    ),
    # Passed to chat_fn after (message, history), as its `api_key` argument.
    additional_inputs=[
        gr.Textbox(type="password", label="Openrouter API Key"),
    ],
    additional_inputs_accordion=gr.Accordion("Options", open=True),
    title="Mamma AI",
    description="Parla con la mamma più saggia del mondo! Puoi inviare messaggi di testo o audio.\n\nPrima di usare la chat, inserisci un'API key di [Openrouter](https://openrouter.ai/) (oppure la password segreta).",
    autofocus=True,
    # Each row is [message, value for the additional API-key input].
    examples=[
        ["È meglio lavare i piatti a mano o in lavastoviglie?", None],
        ["Aiuto! Ho sporcato la camicia di vino!", None],
    ],
)
# Listen on all network interfaces, not just localhost.
demo.launch(server_name="0.0.0.0")
|