/*
 * Granite-4.0 Tool Studio: a chat UI in which an in-browser LLM can call
 * user-editable JavaScript "tools" that are persisted in IndexedDB.
 *
 * NOTE(review): this copy of the file looks damaged by markup stripping.
 * Generic type arguments are missing (e.g. `Promise` with no type argument,
 * `useRef>({})`, `Record;`), a `<tool_call>` literal appears to be missing in
 * several places, and the JSX returned by `App` has lost its element tags,
 * leaving only text and `{...}` expressions. It will not type-check or compile
 * as shown. Restore those parts from version control before relying on it.
 * Several imports (icons, UI components, IDBPDatabase) and some useLLM fields
 * (isLoading, error, progress) are not referenced in the visible code;
 * presumably they were used by the stripped JSX and type arguments.
 */
import React, {
  useState,
  useEffect,
  useCallback,
  useRef,
  useMemo,
} from "react";
import { openDB, type IDBPDatabase } from "idb";
import { Play, Plus, RotateCcw, Wrench } from "lucide-react";
import { useLLM } from "./hooks/useLLM";
import type { Tool } from "./components/ToolItem";
import {
  extractFunctionAndRenderer,
  generateSchemaFromCode,
  getErrorMessage,
  isMobileOrTablet,
} from "./utils";
import { DB_NAME, STORE_NAME, SETTINGS_STORE_NAME } from "./constants/db";
import { DEFAULT_TOOLS, TEMPLATE } from "./tools";
import ToolResultRenderer from "./components/ToolResultRenderer";
import ToolCallIndicator from "./components/ToolCallIndicator";
import ToolItem from "./components/ToolItem";
import ResultBlock from "./components/ResultBlock";
import ExamplePrompts from "./components/ExamplePrompts";
import { LoadingScreen } from "./components/LoadingScreen";

/**
 * UI-side record of one executed tool call, stored on a `tool` message.
 * `call` is the display string for the call.
 * On success, `result` is set, plus optionally the tool's `renderer` source and
 * the resolved `input` arguments.
 * On failure, only `error` is set.
 */
interface RenderInfo {
  call: string;
  result?: any;
  renderer?: string;
  input?: Record;
  error?: string;
}

/** A plain chat turn: system, user, or assistant text. */
interface BaseMessage {
  role: "system" | "user" | "assistant";
  content: string;
}

/**
 * The outcome of running the tool calls from one assistant turn.
 * `content` is the JSON-serialised list of results that goes back into the
 * transcript given to the model.
 * `renderInfo` carries per-call details for rendering.
 */
interface ToolMessage {
  role: "tool";
  content: string;
  renderInfo: RenderInfo[]; // Rich data for the UI
}

type Message = BaseMessage | ToolMessage;

/**
 * Opens the app's IndexedDB database at schema version 1, creating it on first
 * use.
 * It ensures two object stores exist:
 * - STORE_NAME holds tools, keyed by an auto-incremented `id`.
 * - SETTINGS_STORE_NAME holds key/value settings, keyed by `key`.
 *
 * NOTE(review): every caller opens a fresh connection and none is closed in
 * the visible code. Consider caching the promise returned by `openDB`.
 */
async function getDB(): Promise {
  return openDB(DB_NAME, 1, {
    upgrade(db) {
      if (!db.objectStoreNames.contains(STORE_NAME)) {
        db.createObjectStore(STORE_NAME, {
          keyPath: "id",
          autoIncrement: true,
        });
      }
      if (!db.objectStoreNames.contains(SETTINGS_STORE_NAME)) {
        db.createObjectStore(SETTINGS_STORE_NAME, { keyPath: "key" });
      }
    },
  });
}

/**
 * Tool-calling chat studio.
 * - Drives an LLM via `useLLM`.
 * - Keeps a list of user-editable tools persisted in IndexedDB.
 * - Runs a chat loop in which tool calls found in the model's replies are
 *   executed in the browser and their results fed back to the model.
 */
const App: React.FC = () => {
  const [messages, setMessages] = useState([]);
  const [tools, setTools] = useState([]);
  const [input, setInput] = useState("");
  const [isGenerating, setIsGenerating] = useState(false);
  const isMobile = useMemo(isMobileOrTablet, []);
  // Initial model id is "350M" on mobile/tablet and "1B" otherwise.
  // A stored "selectedModelId" setting may replace it (see the effect below).
  const [selectedModel, setSelectedModel] = useState(
    isMobile ? "350M" : "1B",
  );
  const [isModelDropdownOpen, setIsModelDropdownOpen] = useState(false);
  const [isToolsPanelVisible, setIsToolsPanelVisible] = useState(false);
  const chatContainerRef = useRef(null);
  // Pending debounced-save timers, keyed by tool id.
  const debounceTimers = useRef>({});
  const toolsContainerRef = useRef(null);
  const inputRef = useRef(null);
  const {
    isLoading,
    isReady,
    error,
    progress,
    loadModel,
    generateResponse,
    clearPastKeyValues,
  } = useLLM(selectedModel);

  /**
   * Loads all tools from IndexedDB into state.
   * If the store is empty, it is first seeded with DEFAULT_TOOLS, using each
   * entry's index as its id.
   * Every tool is put into state expanded (`isCollapsed: false`), whatever
   * value was stored.
   */
  const loadTools = useCallback(async (): Promise => {
    const db = await getDB();
    const allTools: Tool[] = await db.getAll(STORE_NAME);
    if (allTools.length === 0) {
      const defaultTools: Tool[] = Object.entries(DEFAULT_TOOLS).map(
        ([name, code], id) => ({
          id,
          name,
          code,
          enabled: true,
          isCollapsed: false,
        }),
      );
      // Write all defaults inside a single readwrite transaction.
      const tx = db.transaction(STORE_NAME, "readwrite");
      await Promise.all(defaultTools.map((tool) => tx.store.put(tool)));
      await tx.done;
      setTools(defaultTools);
    } else {
      setTools(allTools.map((t) => ({ ...t, isCollapsed: false })));
    }
  }, []);

  // Load tools on mount. loadTools has an empty dependency list, so it is stable.
  // NOTE(review): a rejection from loadTools is not caught here.
  useEffect(() => {
    loadTools();
  }, [loadTools]);

  // Keep the chat scrolled to the bottom whenever messages change.
  useEffect(() => {
    if (chatContainerRef.current) {
      chatContainerRef.current.scrollTop =
        chatContainerRef.current.scrollHeight;
    }
  }, [messages]);

  /** Writes one tool record to the tools store (insert or replace by id). */
  const updateToolInDB = async (tool: Tool): Promise => {
    const db = await getDB();
    await db.put(STORE_NAME, tool);
  };

  /**
   * Schedules `tool` to be written to IndexedDB after 300 ms.
   * Any save still pending for the same id is cancelled first.
   * Tools without an id are not saved.
   * NOTE(review): the promise from updateToolInDB is neither awaited nor
   * caught, so save failures are unreported.
   */
  const saveToolDebounced = (tool: Tool): void => {
    if (tool.id !== undefined && debounceTimers.current[tool.id]) {
      clearTimeout(debounceTimers.current[tool.id]);
    }
    if (tool.id !== undefined) {
      debounceTimers.current[tool.id] = setTimeout(() => {
        updateToolInDB(tool);
      }, 300);
    }
  };

  /**
   * Empties the chat and calls useLLM's `clearPastKeyValues`.
   * That call presumably resets the model's cached conversation state; confirm
   * in useLLM.
   */
  const clearChat = useCallback(() => {
    setMessages([]);
    clearPastKeyValues();
  }, [clearPastKeyValues]);

  /**
   * Creates a new tool from TEMPLATE and stores it; IndexedDB assigns the id.
   * It then appends the tool to state, scrolls the tools panel to the bottom,
   * and clears the chat.
   */
  const addTool = async (): Promise => {
    const newTool: Omit = {
      name: "new_tool",
      code: TEMPLATE,
      enabled: true,
      isCollapsed: false,
    };
    const db = await getDB();
    const id = await db.add(STORE_NAME, newTool);
    setTools((prev) => {
      const updated = [...prev, { ...newTool, id: id as number }];
      // The timeout presumably lets the scroll happen after the new item renders.
      // NOTE(review): scheduling a side effect inside a state updater is
      // discouraged, because updaters may be invoked twice in StrictMode.
      setTimeout(() => {
        if (toolsContainerRef.current) {
          toolsContainerRef.current.scrollTop =
            toolsContainerRef.current.scrollHeight;
        }
      }, 0);
      return updated;
    });
    clearChat();
  };

  /**
   * Deletes a tool:
   * - cancels any pending debounced save for it,
   * - deletes it from IndexedDB,
   * - removes it from state,
   * - clears the chat.
   *
   * NOTE(review): `tools` here is the value captured when this handler was
   * created, and it is read after two awaits. A functional update
   * (`setTools((prev) => prev.filter(...))`) would avoid overwriting changes
   * made in the meantime.
   * The cleared timer's entry is also left behind in `debounceTimers`.
   */
  const deleteTool = async (id: number): Promise => {
    if (debounceTimers.current[id]) {
      clearTimeout(debounceTimers.current[id]);
    }
    const db = await getDB();
    await db.delete(STORE_NAME, id);
    setTools(tools.filter((tool) => tool.id !== id));
    clearChat();
  };

  /** Flips a tool's `enabled` flag in state and schedules a debounced save. */
  const toggleToolEnabled = (id: number): void => {
    let changedTool: Tool | undefined;
    const newTools = tools.map((tool) => {
      if (tool.id === id) {
        changedTool = { ...tool, enabled: !tool.enabled };
        return changedTool;
      }
      return tool;
    });
    setTools(newTools);
    if (changedTool) saveToolDebounced(changedTool);
  };

  /** Flips a tool's `isCollapsed` flag in state; this handler does not save it. */
  const toggleToolCollapsed = (id: number): void => {
    setTools(
      tools.map((tool) =>
        tool.id === id ? { ...tool, isCollapsed: !tool.isCollapsed } : tool,
      ),
    );
  };

  /** Sets a tool's `isCollapsed` to false in state; this handler does not save it. */
  const expandTool = (id: number): void => {
    setTools(
      tools.map((tool) =>
        tool.id === id ? { ...tool, isCollapsed: false } : tool,
      ),
    );
  };

  /**
   * Replaces a tool's source code and re-derives its `name` from the schema of
   * the function part of the new code. Then schedules a debounced save.
   * NOTE(review): this runs on every edit. If extractFunctionAndRenderer or
   * generateSchemaFromCode can throw on incomplete code, the edit is lost;
   * confirm both tolerate partial input.
   */
  const handleToolCodeChange = (id: number, newCode: string): void => {
    let changedTool: Tool | undefined;
    const newTools = tools.map((tool) => {
      if (tool.id === id) {
        const { functionCode } = extractFunctionAndRenderer(newCode);
        const schema = generateSchemaFromCode(functionCode);
        changedTool = { ...tool, code: newCode, name: schema.name };
        return changedTool;
      }
      return tool;
    });
    setTools(newTools);
    if (changedTool) saveToolDebounced(changedTool);
  };

  /** A tool call as parsed from model output. */
  interface ToolCallPayload {
    name: string;
    arguments?: Record;
  }

  /**
   * Extracts JSON tool-call payloads that end with a closing `</tool_call>`
   * tag from model output.
   * Each captured payload must parse as JSON and have a string `name`;
   * `arguments` defaults to `{}`. Malformed payloads are skipped silently.
   *
   * NOTE(review): as written, the regex has no opening-tag anchor. The capture
   * therefore runs from the end of the previous match (or the start of `text`)
   * up to the closing tag. Any text before the JSON makes JSON.parse fail, and
   * that call is dropped. An opening `<tool_call>` appears to have been lost
   * from this pattern; confirm.
   */
  const extractToolCalls = (text: string): ToolCallPayload[] => {
    const matches = Array.from(
      text.matchAll(/([\s\S]*?)<\/tool_call>/g),
    );
    const toolCalls: ToolCallPayload[] = [];
    for (const match of matches) {
      try {
        const parsed = JSON.parse(match[1].trim());
        if (parsed && typeof parsed.name === "string") {
          toolCalls.push({
            name: parsed.name,
            arguments: parsed.arguments ?? {},
          });
        }
      } catch {
        // ignore malformed tool call payloads
      }
    }
    return toolCalls;
  };

  /**
   * Runs one tool call against the enabled tool with a matching name.
   *
   * Arguments are passed positionally, in the order of the parameters in the
   * tool's generated schema. For each parameter:
   * - a value supplied in the call is used first;
   * - otherwise the schema's `default`, if present;
   * - otherwise `undefined`, if the parameter is optional;
   * - otherwise (required and missing) an error is thrown.
   * Extra arguments that are not in the schema are not passed to the function,
   * but they are still reported in `input`.
   *
   * The tool's function body is compiled with the AsyncFunction constructor
   * and awaited.
   * Returns the JSON-serialised result, the tool's renderer source (if any),
   * and the resolved arguments.
   *
   * @throws when:
   *   - the tool is missing or disabled,
   *   - a required argument is absent,
   *   - the function body cannot be located,
   *   - the tool itself throws.
   *
   * NOTE(review): compiling stored tool code with the AsyncFunction
   * constructor is equivalent to eval. It is only safe while tool code is
   * authored by the local user.
   * Also, when a tool returns `undefined`, `JSON.stringify` returns
   * `undefined` rather than a string, despite the declared return type.
   */
  const executeToolCall = async (
    toolCall: ToolCallPayload,
  ): Promise<{
    serializedResult: string;
    rendererCode?: string;
    input: Record;
  }> => {
    const toolToUse = tools.find((t) => t.name === toolCall.name && t.enabled);
    if (!toolToUse)
      throw new Error(`Tool '${toolCall.name}' not found or is disabled.`);
    const { functionCode, rendererCode } = extractFunctionAndRenderer(
      toolToUse.code,
    );
    const schema = generateSchemaFromCode(functionCode);
    const properties = schema.parameters?.properties ?? {};
    const paramNames = Object.keys(properties);
    const requiredParams = schema.parameters?.required ?? [];
    const callArgs = toolCall.arguments ?? {};
    // Positional values for the compiled function, in schema order.
    const finalArgs: any[] = [];
    // Name -> value map reported back to the UI. It has no prototype,
    // presumably so that a key like "__proto__" is stored as a plain property
    // rather than triggering the inherited setter.
    const resolvedArgs: Record = Object.create(null);
    for (const paramName of paramNames) {
      const propertyConfig = properties[paramName] ?? {};
      if (Object.prototype.hasOwnProperty.call(callArgs, paramName)) {
        const value = callArgs[paramName];
        finalArgs.push(value);
        resolvedArgs[paramName] = value;
      } else if (
        Object.prototype.hasOwnProperty.call(propertyConfig, "default")
      ) {
        const value = propertyConfig.default;
        finalArgs.push(value);
        resolvedArgs[paramName] = value;
      } else if (!requiredParams.includes(paramName)) {
        finalArgs.push(undefined);
        resolvedArgs[paramName] = undefined;
      } else {
        throw new Error(`Missing required argument: ${paramName}`);
      }
    }
    // Keep extra (non-schema) arguments visible in the UI.
    for (const extraKey of Object.keys(callArgs)) {
      if (!Object.prototype.hasOwnProperty.call(resolvedArgs, extraKey)) {
        resolvedArgs[extraKey] = callArgs[extraKey];
      }
    }
    // The capture runs greedily from the first "{" after `function ...` to the
    // last "}" in functionCode.
    // NOTE(review): because the pattern uses `[^{]+`, a "{" inside the
    // parameter list (a destructured or object-default parameter) would be
    // mis-parsed as the start of the body.
    const bodyMatch = functionCode.match(/function[^{]+\{([\s\S]*)\}/);
    if (!bodyMatch) {
      throw new Error(
        "Could not parse function body. Ensure it's a standard `function` declaration.",
      );
    }
    const body = bodyMatch[1];
    // AsyncFunction is not a global, so take the constructor from an instance.
    const AsyncFunction = Object.getPrototypeOf(
      async function () {},
    ).constructor;
    const func = new AsyncFunction(...paramNames, body);
    const result = await func(...finalArgs);
    return {
      serializedResult: JSON.stringify(result),
      rendererCode,
      input: resolvedArgs,
    };
  };

  /**
   * Executes tool calls one after another, in order.
   * Each call's success or failure is recorded as a RenderInfo entry instead of
   * being thrown. An empty list yields a single error entry.
   * If the serialised result does not parse back as JSON (e.g. it is
   * `undefined`), the raw value is stored instead.
   */
  const executeToolCalls = async (
    toolCalls: ToolCallPayload[],
  ): Promise => {
    if (toolCalls.length === 0) {
      return [{ call: "", error: "No valid tool calls found." }];
    }
    const results: RenderInfo[] = [];
    for (const toolCall of toolCalls) {
      const callDisplay = `${JSON.stringify(toolCall)}`;
      try {
        const { serializedResult, rendererCode, input } =
          await executeToolCall(toolCall);
        let parsedResult: unknown;
        try {
          parsedResult = JSON.parse(serializedResult);
        } catch {
          parsedResult = serializedResult;
        }
        results.push({
          call: callDisplay,
          result: parsedResult,
          renderer: rendererCode,
          input,
        });
      } catch (error) {
        results.push({
          call: callDisplay,
          error: getErrorMessage(error),
        });
      }
    }
    return results;
  };

  /**
   * Sends the current input as a user message, then runs the generation loop:
   * 1. Stream an assistant reply into a placeholder message.
   * 2. If the reply contains tool calls, execute them, append a `tool`
   *    message, and generate again.
   * 3. Stop at the first reply that contains no tool calls.
   * Errors are shown to the user as an extra assistant message.
   *
   * NOTE(review):
   * - This loop is duplicated almost line-for-line in handleExampleClick;
   *   consider a shared helper.
   * - The loop has no iteration cap.
   * - `currentMessages` is handed to setMessages and then mutated with `push`;
   *   React state should be treated as immutable.
   */
  const handleSendMessage = async (): Promise => {
    if (!input.trim() || !isReady) return;
    const userMessage: Message = { role: "user", content: input };
    // Local transcript, extended in place by the loop. It is never reassigned,
    // so `const` would suffice.
    let currentMessages: Message[] = [...messages, userMessage];
    setMessages(currentMessages);
    setInput("");
    setIsGenerating(true);
    try {
      // NOTE(review): schemas are generated from the full `tool.code` here.
      // executeToolCall and handleToolCodeChange pass the extracted
      // `functionCode` instead. Confirm generateSchemaFromCode handles code
      // that includes a renderer.
      const toolSchemas = tools
        .filter((tool) => tool.enabled)
        .map((tool) => generateSchemaFromCode(tool.code));
      while (true) {
        const messagesForGeneration = [...currentMessages];
        // Empty assistant placeholder that the token callback fills in.
        setMessages([...currentMessages, { role: "assistant", content: "" }]);
        let accumulatedContent = "";
        const response = await generateResponse(
          messagesForGeneration,
          toolSchemas,
          (token: string) => {
            accumulatedContent += token;
            // Replace the last (placeholder) message with the text so far.
            setMessages((current) => {
              const updated = [...current];
              updated[updated.length - 1] = {
                role: "assistant",
                content: accumulatedContent,
              };
              return updated;
            });
          },
        );
        currentMessages.push({ role: "assistant", content: response });
        const toolCalls = extractToolCalls(response);
        if (toolCalls.length > 0) {
          const toolResults = await executeToolCalls(toolCalls);
          // `content` is what enters the transcript for the next
          // generateResponse call: only the results, with null for failed calls.
          // `renderInfo` keeps the full details for the UI.
          const toolMessage: ToolMessage = {
            role: "tool",
            content: JSON.stringify(toolResults.map((r) => r.result ?? null)),
            renderInfo: toolResults,
          };
          currentMessages.push(toolMessage);
          setMessages([...currentMessages]);
          continue;
        } else {
          setMessages(currentMessages);
          break;
        }
      }
    } catch (error) {
      const errorMessage = getErrorMessage(error);
      setMessages([
        ...currentMessages,
        {
          role: "assistant",
          content: `Error generating response: ${errorMessage}`,
        },
      ]);
    } finally {
      setIsGenerating(false);
      // Deferred, presumably so focus lands after the input is re-enabled.
      setTimeout(() => inputRef.current?.focus(), 0);
    }
  };

  /**
   * Calls useLLM's `loadModel`, which is bound to `selectedModel` through the
   * hook's argument. Failures are logged, not rethrown.
   * NOTE(review): `selectedModel` is listed as a dependency but is not read in
   * the body.
   */
  const loadSelectedModel = useCallback(async (): Promise => {
    try {
      await loadModel();
    } catch (error) {
      console.error("Failed to load model:", error);
    }
  }, [selectedModel, loadModel]);

  /** Saves the chosen model id under the "selectedModelId" setting. Failures are logged. */
  const saveSelectedModel = useCallback(async (modelId: string) => {
    try {
      const db = await getDB();
      await db.put(SETTINGS_STORE_NAME, {
        key: "selectedModelId",
        value: modelId,
      });
    } catch (error) {
      console.error("Failed to save selected model ID:", error);
    }
  }, []);

  /** Restores a stored, truthy "selectedModelId" setting into state. Failures are logged. */
  const loadSelectedModelId = useCallback(async (): Promise => {
    try {
      const db = await getDB();
      const stored = await db.get(SETTINGS_STORE_NAME, "selectedModelId");
      if (stored && stored.value) {
        setSelectedModel(stored.value);
      }
    } catch (error) {
      console.error("Failed to load selected model ID:", error);
    }
  }, []);

  // Restore the saved model choice on mount.
  useEffect(() => {
    loadSelectedModelId();
  }, [loadSelectedModelId]);

  /** Switches to `modelId`, closes the model dropdown, and saves the choice. */
  const handleModelSelect = async (modelId: string) => {
    setSelectedModel(modelId);
    setIsModelDropdownOpen(false);
    await saveSelectedModel(modelId);
  };

  /**
   * Sends an example prompt as a user message and runs the same
   * generate / execute-tools loop as handleSendMessage.
   * Does nothing if the model is not ready or a generation is in progress.
   *
   * NOTE(review): `setInput(messageText)` is followed by `setInput("")` before
   * any await, so the later call wins and the first has no visible effect.
   */
  const handleExampleClick = async (messageText: string): Promise => {
    if (!isReady || isGenerating) return;
    setInput(messageText);
    const userMessage: Message = { role: "user", content: messageText };
    const currentMessages: Message[] = [...messages, userMessage];
    setMessages(currentMessages);
    setInput("");
    setIsGenerating(true);
    try {
      const toolSchemas = tools
        .filter((tool) => tool.enabled)
        .map((tool) => generateSchemaFromCode(tool.code));
      while (true) {
        const messagesForGeneration = [...currentMessages];
        // Empty assistant placeholder that the token callback fills in.
        setMessages([...currentMessages, { role: "assistant", content: "" }]);
        let accumulatedContent = "";
        const response = await generateResponse(
          messagesForGeneration,
          toolSchemas,
          (token: string) => {
            accumulatedContent += token;
            // Replace the last (placeholder) message with the text so far.
            setMessages((current) => {
              const updated = [...current];
              updated[updated.length - 1] = {
                role: "assistant",
                content: accumulatedContent,
              };
              return updated;
            });
          },
        );
        currentMessages.push({ role: "assistant", content: response });
        const toolCalls = extractToolCalls(response);
        if (toolCalls.length > 0) {
          const toolResults = await executeToolCalls(toolCalls);
          const toolMessage: ToolMessage = {
            role: "tool",
            content: JSON.stringify(toolResults.map((r) => r.result ?? null)),
            renderInfo: toolResults,
          };
          currentMessages.push(toolMessage);
          setMessages([...currentMessages]);
          continue;
        } else {
          setMessages(currentMessages);
          break;
        }
      }
    } catch (error) {
      const errorMessage = getErrorMessage(error);
      setMessages([
        ...currentMessages,
        {
          role: "assistant",
          content: `Error generating response: ${errorMessage}`,
        },
      ]);
    } finally {
      setIsGenerating(false);
      setTimeout(() => inputRef.current?.focus(), 0);
    }
  };

  // NOTE(review): the JSX below appears to have lost its element tags in this
  // copy; only text and `{...}` expressions remain, so it will not compile as
  // shown. Restore it from version control.
  return (
{!isReady ? ( ) : (
IBM Granite

Granite-4.0 Tool Studio

{messages.length === 0 && isReady ? ( ) : ( messages.map((msg, index) => { const key = `${msg.role}-${index}`; if (msg.role === "user") { return (

{msg.content}

); } if (msg.role === "assistant") { const isToolCall = msg.content.includes(""); if (isToolCall) { const nextMessage = messages[index + 1]; const isCompleted = nextMessage?.role === "tool"; const hasError = isCompleted && (nextMessage as ToolMessage).renderInfo.some( (info) => !!info.error, ); return (
); } return (

{msg.content}

); } if (msg.role === "tool") { const visibleToolResults = msg.renderInfo.filter( (info) => info.error || (info.result != null && info.renderer), ); if (visibleToolResults.length === 0) return null; return (
{visibleToolResults.map((info, idx) => (
{info.call}
{info.error ? ( ) : ( )}
))}
); } return null; }) )}
setInput(e.target.value)} onKeyDown={(e) => e.key === "Enter" && !isGenerating && isReady && handleSendMessage() } disabled={isGenerating || !isReady} className="flex-grow bg-transparent px-5 py-3 text-lg text-white placeholder:text-[#a6c8ff]/70 focus:outline-none disabled:opacity-40" placeholder={ isReady ? "Type your message here..." : "Load a Granite model to enable chat" } />
{isToolsPanelVisible && (
Tool Workspace

Tools

{tools.map((tool) => ( toggleToolEnabled(tool.id)} onToggleCollapsed={() => toggleToolCollapsed(tool.id)} onExpand={() => expandTool(tool.id)} onDelete={() => deleteTool(tool.id)} onCodeChange={(newCode) => handleToolCodeChange(tool.id, newCode) } /> ))}
)}
)}
  );
};

export default App;