sft_llama3.2_3b
A LoRA adapter for meta-llama/Llama-3.2-3B-Instruct, fine-tuned to evaluate predicted clinical discharge diagnoses against ground-truth diagnoses. Given a predicted diagnosis list (JSON) and a ground-truth list, the model returns a structured JSON evaluation covering primary and top-5 correctness, missed diagnoses, and improvement suggestions.
This checkpoint (checkpoint-50) was selected because it reached the lowest eval loss before the model began overfitting.
Model Details
- Base model: meta-llama/Llama-3.2-3B-Instruct
- Fine-tuning method: LoRA (rank 8, all linear modules)
- Trainable parameters: 12.16M / 3.22B (0.377%)
- Adapter size: ~47 MB
- Precision: bfloat16
- Developed by: JinR
Intended Use
Task: evaluate predicted discharge diagnoses vs. ground truth with semantic matching, returning structured JSON.
Input format (user message): a prompt containing PREDICTED OUTPUT (JSON), GROUND TRUTH DIAGNOSES, and the requested output schema.
Output format:

```json
{
  "diagnosis_evaluation": {
    "primary_correct": true/false,
    "any_top5_correct": true/false,
    "missed_diagnoses": ["..."],
    "why_missed": "..."
  },
  "improvement_suggestions": "..."
}
```
Out of scope: direct clinical decision-making, patient-facing advice, or use without clinician review.
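The exact prompt template used during fine-tuning is not reproduced in this card, so the snippet below is only an illustrative sketch of how a user message with the PREDICTED OUTPUT and GROUND TRUTH DIAGNOSES sections might be assembled; the diagnosis strings and field names are invented for the example.

```python
import json

# Hypothetical inputs -- not taken from the training data.
predicted = {"discharge_diagnoses": ["Acute decompensated heart failure",
                                     "Community-acquired pneumonia"]}
ground_truth = ["Acute on chronic systolic heart failure",
                "Community-acquired pneumonia",
                "Stage 3 chronic kidney disease"]

# Assemble a prompt containing the predicted output (JSON), the ground-truth
# diagnoses, and a reminder of the requested output schema.
prompt = (
    "PREDICTED OUTPUT (JSON):\n"
    + json.dumps(predicted, indent=2)
    + "\n\nGROUND TRUTH DIAGNOSES:\n"
    + "\n".join(f"- {d}" for d in ground_truth)
    + "\n\nReturn a JSON object with diagnosis_evaluation "
      "(primary_correct, any_top5_correct, missed_diagnoses, why_missed) "
      "and improvement_suggestions."
)
```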
How to Use
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

# Load the base model in bfloat16 and attach the LoRA adapter.
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
model = PeftModel.from_pretrained(base, "JinR/sft_llama3.2_3b")

# Build the chat-formatted input and generate (greedy decoding).
messages = [{"role": "user", "content": "<your evaluation prompt here>"}]
inputs = tok.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
).to(model.device)
out = model.generate(inputs, max_new_tokens=512, do_sample=False)

# Decode only the newly generated tokens.
print(tok.decode(out[0][inputs.shape[-1]:], skip_special_tokens=True))
```
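Since downstream use depends on machine-readable output, it can help to parse the generation defensively. This post-processing step is a suggestion, not part of the card, and reuses the `tok`, `out`, and `inputs` objects from the snippet above.

```python
import json

raw = tok.decode(out[0][inputs.shape[-1]:], skip_special_tokens=True)
try:
    evaluation = json.loads(raw)
    print(evaluation["diagnosis_evaluation"]["primary_correct"])
    print(evaluation["improvement_suggestions"])
except (json.JSONDecodeError, KeyError):
    # The adapter is trained to emit JSON, but generation can still drift,
    # so fall back to the raw text for manual review.
    print("Unexpected output format:\n", raw)
```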
Training Data
- Size: 490 examples
- Format: ShareGPT / messages (role + content); an illustrative record shape is sketched after this list
- Split: 441 train / 49 validation (auto-split, val_size=0.1)
- Content: single-turn prompts asking the model to compare predicted vs. ground-truth discharge diagnoses and output a structured JSON evaluation
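For illustration only (the field values below are invented, not taken from the dataset), a single-turn training record in role/content messages form might look like:

```python
example = {
    "messages": [
        {
            "role": "user",
            "content": (
                "PREDICTED OUTPUT (JSON): {...}\n"
                "GROUND TRUTH DIAGNOSES: ...\n"
                "Return the evaluation JSON described in the schema."
            ),
        },
        {
            "role": "assistant",
            "content": (
                '{"diagnosis_evaluation": {"primary_correct": false, '
                '"any_top5_correct": true, "missed_diagnoses": ["..."], '
                '"why_missed": "..."}, "improvement_suggestions": "..."}'
            ),
        },
    ]
}
```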
Training Procedure
Hyperparameters
| Setting | Value |
|---|---|
| LoRA rank | 8 |
| LoRA target | all linear modules |
| cutoff_len | 2048 |
| packing | false |
| learning rate | 3e-4 |
| lr scheduler | cosine |
| warmup ratio | 0.1 |
| epochs | 5 |
| per-device batch size | 4 |
| gradient accumulation | 2 |
| effective batch size | 24 (4 × 2 × 3 GPUs) |
| precision | bf16 |
| optimizer | AdamW |
| gradient checkpointing | yes |
| attention | torch SDPA |
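The run itself was driven through LLaMA-Factory, so the exact training script is not shown here; as a rough PEFT equivalent of the LoRA settings in the table (alpha and dropout are not listed in this card, so defaults are assumed), one might write:

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                          # LoRA rank from the table above
    target_modules="all-linear",  # adapt all linear modules
    task_type="CAUSAL_LM",
    # lora_alpha / lora_dropout are not reported in this card; PEFT defaults apply.
)
```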
Infrastructure
- Hardware: 3× NVIDIA A100 (TACC Lonestar-6, node c301-002)
- Framework: LLaMA-Factory + transformers 5.2.0 + PEFT 0.18.1
- Runtime: 1 minute 46 seconds
Results
| Step | Epoch | train_loss | eval_loss |
|---|---|---|---|
| 10 | 0.54 | 1.186 | — |
| 25 | 1.32 | — | 0.6742 |
| 50 | 2.65 | 0.5099 | 0.6034 (best) |
| 75 | 3.97 | — | 0.6171 |
| 95 | 5.00 | 0.3155 | 0.6187 |
Training loss drops from 2.017 to 0.3155; eval loss bottoms out at step 50, confirming checkpoint-50 as the best-generalizing checkpoint.
Limitations
- Trained on only 490 examples, so expect only moderate generalization to out-of-distribution prompts.
- Mild overfitting observed after epoch ~2.6 (train/eval loss diverge).
- Not validated for real clinical use; outputs should be reviewed by a clinician.
- Inherits biases and limitations of the base Llama-3.2-3B-Instruct model.
Full Training Log
See TRAINING_LOG_sft_llama3.2_3b.md in this repo for full reproducibility details (environment, data statistics, exact commands, file sizes, caveats).
Framework versions
- PEFT 0.18.1