import gradio as gr
from transformers import pipeline
import torch
import spaces

# Initialize the model pipeline
model_id = "facebook/MobileLLM-R1-950M"
pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)
@spaces.GPU  # assumes a ZeroGPU Space: `spaces` is imported above, and the decorator allocates a GPU for each call
def respond(message, history):
    # Build the prompt from the (user, assistant) tuples in history
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"
    # Add the current message
    prompt += f"User: {message}\nAssistant: "

    # Tokenize the prompt and generate the full reply in one pass
    inputs = pipe.tokenizer(prompt, return_tensors="pt").to(pipe.model.device)
    with torch.no_grad():
        outputs = pipe.model.generate(
            **inputs,
            max_new_tokens=10000,
            temperature=0.7,
            do_sample=True,
            pad_token_id=pipe.tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens, skipping the input tokens
    generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]

    # Yield progressively longer decodes so the UI updates token by token;
    # decoding the running prefix (rather than one token at a time) keeps
    # tokenizer whitespace and multi-token characters intact
    for i in range(1, len(generated_tokens) + 1):
        yield pipe.tokenizer.decode(generated_tokens[:i], skip_special_tokens=True)
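# ---------------------------------------------------------------------------
# A minimal sketch of true token-by-token streaming, as an alternative to the
# generate-then-yield loop in respond() above. transformers'
# TextIteratorStreamer pushes decoded text to an iterator while generate()
# runs in a background thread. It reuses the `pipe` object defined above; the
# function name `respond_streaming` and the max_new_tokens value are
# illustrative, and on a ZeroGPU Space this function would also need the
# @spaces.GPU decorator.
# ---------------------------------------------------------------------------
from threading import Thread

from transformers import TextIteratorStreamer


def respond_streaming(message, history):
    prompt = f"User: {message}\nAssistant: "
    inputs = pipe.tokenizer(prompt, return_tensors="pt").to(pipe.model.device)
    streamer = TextIteratorStreamer(
        pipe.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # Run generation in a worker thread so the streamer can be consumed here
    thread = Thread(
        target=pipe.model.generate,
        kwargs=dict(**inputs, max_new_tokens=512, streamer=streamer),
    )
    thread.start()
    response_text = ""
    for new_text in streamer:
        response_text += new_text
        yield response_text
    thread.join()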
# Create the chat interface
demo = gr.ChatInterface(
    fn=respond,
    title="MobileLLM Chat",
    description="Chat with Meta MobileLLM-R1-950M",
    examples=[
        "Write a Python function that returns the square of a number.",
        "Compute: 1-2+3-4+5- ... +99-100.",
        "Write a C++ program that prints 'Hello, World!'.",
        "Explain how recursion works in programming.",
        "What is the difference between a list and a tuple in Python?",
    ],
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch(share=True)