sft_llama3.2_3b

LoRA adapter fine-tuned on meta-llama/Llama-3.2-3B-Instruct for evaluating predicted clinical discharge diagnoses against ground-truth diagnoses. Given a predicted diagnosis list (JSON) and a ground-truth list, the model returns a structured JSON evaluation including primary/top-5 correctness, missed diagnoses, and improvement suggestions.

This checkpoint (checkpoint-50) was selected because it reached the lowest eval loss before the model began overfitting.

Model Details

  • Base model: meta-llama/Llama-3.2-3B-Instruct
  • Fine-tuning method: LoRA (rank 8, all linear modules)
  • Trainable parameters: 12.16M / 3.22B (0.377%)
  • Adapter size: ~47 MB
  • Precision: bfloat16
  • Developed by: JinR

Intended Use

Task: evaluate predicted discharge diagnoses vs. ground truth with semantic matching, returning structured JSON.

Input format (user message): a prompt containing PREDICTED OUTPUT (JSON), GROUND TRUTH DIAGNOSES, and the requested output schema.
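
An illustrative user message (the diagnosis content and exact section labels below are hypothetical; the training prompts follow the same overall structure):

Evaluate the predicted discharge diagnoses against the ground-truth diagnoses.

PREDICTED OUTPUT (JSON):
{"primary_diagnosis": "Acute decompensated heart failure",
 "top_5_diagnoses": ["Acute decompensated heart failure", "Atrial fibrillation",
                     "Community-acquired pneumonia", "Acute kidney injury", "Anemia"]}

GROUND TRUTH DIAGNOSES:
1. Acute on chronic systolic heart failure
2. Atrial fibrillation with rapid ventricular response
3. Stage 3 chronic kidney disease

Return your evaluation as JSON using the schema below.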

Output format:

{
  "diagnosis_evaluation": {
    "primary_correct": true/false,
    "any_top5_correct": true/false,
    "missed_diagnoses": ["..."],
    "why_missed": "..."
  },
  "improvement_suggestions": "..."
}
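
For the illustrative prompt above, a well-formed response looks like this (values are hypothetical):

{
  "diagnosis_evaluation": {
    "primary_correct": true,
    "any_top5_correct": true,
    "missed_diagnoses": ["Stage 3 chronic kidney disease"],
    "why_missed": "The prediction listed acute kidney injury but did not capture the chronic kidney disease documented at discharge."
  },
  "improvement_suggestions": "Distinguish acute from chronic renal disease and carry forward chronic comorbidities from the problem list."
}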

Out of scope: direct clinical decision-making, patient-facing advice, or use without clinician review.

How to Use

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

# Load the base model in bfloat16 and attach the LoRA adapter.
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
model = PeftModel.from_pretrained(base, "JinR/sft_llama3.2_3b")

# Build the chat prompt and generate deterministically.
messages = [{"role": "user", "content": "<your evaluation prompt here>"}]
inputs = tok.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(model.device)
out = model.generate(inputs, max_new_tokens=512, do_sample=False)
# Decode only the newly generated tokens.
print(tok.decode(out[0][inputs.shape[-1]:], skip_special_tokens=True))
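
Because the model returns a structured JSON evaluation, the reply can usually be parsed directly; a minimal sketch (the regex fallback is an assumption for cases where the model wraps the JSON in extra text):

import json
import re

reply = tok.decode(out[0][inputs.shape[-1]:], skip_special_tokens=True)
try:
    evaluation = json.loads(reply)
except json.JSONDecodeError:
    # Fall back to the first {...} span if the model added surrounding prose.
    match = re.search(r"\{.*\}", reply, re.DOTALL)
    evaluation = json.loads(match.group(0)) if match else None

if evaluation is not None:
    print(evaluation["diagnosis_evaluation"]["primary_correct"])
    print(evaluation["diagnosis_evaluation"]["missed_diagnoses"])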

Training Data

  • Size: 490 examples
  • Format: ShareGPT / messages (role + content)
  • Split: 441 train / 49 validation (auto-split, val_size=0.1)
  • Content: single-turn prompts asking the model to compare predicted vs. ground-truth discharge diagnoses and output a structured JSON evaluation
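
Each record is a single-turn conversation in the messages format; a schematic example (content abbreviated and hypothetical):

{
  "messages": [
    {"role": "user", "content": "Evaluate the predicted discharge diagnoses...\n\nPREDICTED OUTPUT (JSON): {...}\n\nGROUND TRUTH DIAGNOSES: ..."},
    {"role": "assistant", "content": "{\"diagnosis_evaluation\": {\"primary_correct\": true, ...}}"}
  ]
}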

Training Procedure

Hyperparameters

Setting Value
LoRA rank 8
LoRA target all linear modules
cutoff_len 2048
packing false
learning rate 3e-4
lr scheduler cosine
warmup ratio 0.1
epochs 5
per-device batch size 4
gradient accumulation 2
effective batch size 24 (4 × 2 × 3 GPUs)
precision bf16
optimizer AdamW
gradient checkpointing yes
attention torch SDPA
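
In PEFT terms, the adapter configuration above corresponds roughly to the following LoraConfig (a sketch; lora_alpha and lora_dropout are assumptions, since only the rank and target modules are documented here):

from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                          # LoRA rank from the table above
    lora_alpha=16,                # assumption: LLaMA-Factory's default of 2 x rank
    lora_dropout=0.0,             # assumption: default value, not documented in this card
    target_modules="all-linear",  # "all linear modules"
    task_type="CAUSAL_LM",
)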

Infrastructure

  • Hardware: 3× NVIDIA A100 (TACC Lonestar-6, node c301-002)
  • Framework: LLaMA-Factory + transformers 5.2.0 + PEFT 0.18.1
  • Runtime: 1 minute 46 seconds

Results

Step  Epoch  train_loss  eval_loss
10    0.54   1.186       –
25    1.32   –           0.6742
50    2.65   0.5099      0.6034 (best)
75    3.97   –           0.6171
95    5.00   0.3155      0.6187

Training loss drops from 2.017 to 0.3155; eval loss bottoms out at step 50, confirming checkpoint-50 as the best-generalizing checkpoint.

Limitations

  • Trained on only 490 examples; expect moderate generalization to out-of-distribution prompts.
  • Mild overfitting observed after epoch ~2.6 (train/eval loss diverge).
  • Not validated for real clinical use; outputs should be reviewed by a clinician.
  • Inherits biases and limitations of the base Llama-3.2-3B-Instruct model.

Full Training Log

See TRAINING_LOG_sft_llama3.2_3b.md in this repo for full reproducibility details (environment, data statistics, exact commands, file sizes, caveats).

Framework versions

  • PEFT 0.18.1