| | |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
| | from peft import PeftModel |
| |
|
def load_model(model_path="final_model_continue"):
    """Load the fine-tuned model"""
    print("🔧 Loading model...")

    # 4-bit NF4 quantization with double quantization; compute in bfloat16.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Base model is sharded automatically across available devices.
    base = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        quantization_config=quant_cfg,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # Attach the LoRA adapter saved at model_path; the tokenizer is loaded
    # from the same directory so any added/special tokens stay in sync.
    tuned = PeftModel.from_pretrained(base, model_path)
    tok = AutoTokenizer.from_pretrained(model_path)

    # Llama tokenizers ship without a pad token; reuse EOS for padding.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    print("✅ Model loading completed!")
    return tuned, tok
| |
|
def generate_response(model, tokenizer, prompt, max_length=200):
    """Generate financial advice response.

    Args:
        model: Causal LM (PEFT-wrapped) exposing ``generate`` and ``device``.
        tokenizer: Tokenizer paired with ``model``.
        prompt: Instruction-formatted prompt string.
        max_length: Maximum number of NEW tokens to sample (passed to
            ``max_new_tokens``, not total sequence length).

    Returns:
        The generated continuation with the prompt prefix stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # BUG FIX: tokenizer tensors default to CPU while device_map="auto" may
    # place the model on GPU — move inputs to the model's device before
    # generation, otherwise generate() raises a device-mismatch error.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # NOTE(review): assumes the decoded text begins with the exact prompt
    # string (usually true after skip_special_tokens); slicing strips it.
    return response[len(prompt):]
| |
|
| | |
if __name__ == "__main__":
    # Load model + tokenizer once, then run a single demo prompt.
    ft_model, tok = load_model()

    demo_prompt = """### Instruction:
Please provide investment advice for investors regarding technology stocks.

### Input:
A technology company's revenue grew 20% this quarter, but profit margin decreased by 5%, mainly due to increased R&D investment. The company has major breakthroughs in AI.

### Response:"""

    result = generate_response(ft_model, tok, demo_prompt)
    print("🤖 AI Investment Advice:")
    print(result)
|