# gradio-test / app.py
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
import torch
# -------------------------
# Load dataset
# -------------------------
dataset = load_dataset("mavilov/convos", split="train")
# -------------------------
# Load model and tokenizer
# -------------------------
model_id = "swiss-ai/Apertus-8B-2509"
model_kwargs = {}
if torch.backends.mps.is_available():
    print("⚡ Using Apple MPS backend (Metal)")
    model_kwargs = {
        "dtype": torch.float16,  # named "torch_dtype" on older transformers releases
        "device_map": {"": "mps"},  # force load directly on MPS
        "offload_folder": "./offload",
        "low_cpu_mem_usage": True,  # avoid meta tensors
    }
elif torch.cuda.is_available():
    print("⚡ Using CUDA with bitsandbytes quantization")
    from transformers import BitsAndBytesConfig
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
    )
    model_kwargs["quantization_config"] = bnb_config
    model_kwargs["device_map"] = "auto"
else:
    print("⚠️ No GPU/MPS detected, running on CPU (very slow)")
    model_kwargs = {
        "dtype": torch.float32,
        "device_map": {"": "cpu"},
        "low_cpu_mem_usage": True,
    }
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # causal LM tokenizers often ship without a pad token
# Load model with the backend-specific kwargs chosen above
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    **model_kwargs,
)
model.config.use_cache = False  # the KV cache is useless during training and wastes memory
model.config.pretraining_tp = 1  # disable Llama-style tensor-parallel slicing if the config has it
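# Optional (an assumption, not part of the original script): gradient checkpointing
# trades extra compute for a large cut in activation memory, which helps an 8B
# model fit on a single device. Uncomment to enable:
# model.gradient_checkpointing_enable()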
# -------------------------
# Attach LoRA adapters
# -------------------------
lora_config = LoraConfig(
    r=16,  # rank of the low-rank update matrices
    lora_alpha=32,  # scaling factor (update is scaled by alpha / r)
    target_modules=["q_proj", "v_proj"],  # attention query/value projections
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
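# Show how small the trainable slice is: with r=16 on q_proj/v_proj only,
# LoRA typically trains well under 1% of the 8B parameters.
model.print_trainable_parameters()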
# -------------------------
# Preprocess / tokenize dataset
# -------------------------
def tokenize_fn(examples):
    # `examples` is a batch (dict of lists) because map() runs with batched=True
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        max_length=2048,
    )
    # For causal LM fine-tuning, the labels are the inputs themselves;
    # the collator later masks padding positions with -100
    tokenized["labels"] = [ids.copy() for ids in tokenized["input_ids"]]
    return tokenized
dataset = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)
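# Quick sanity check (illustrative, not in the original script): the mapped
# dataset should now expose only tokenized columns such as input_ids and labels.
print(f"Columns after tokenization: {dataset.column_names}")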
# -------------------------
# Data collator with dynamic padding
# -------------------------
# Pads input_ids/attention_mask to the longest sequence in each batch and
# pads labels with -100 so padded positions are ignored by the loss
data_collator = DataCollatorForSeq2Seq(tokenizer, padding="longest")
# -------------------------
# Training configuration
# -------------------------
training_args = SFTConfig(
    output_dir="./results",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective batch size: 2 x 4 = 8
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=10,
    report_to="tensorboard",
    bf16=False,  # bf16 is unavailable on MPS and pre-Ampere GPUs
)
# -------------------------
# Initialize trainer
# -------------------------
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,  # recent TRL releases take the tokenizer here
    data_collator=data_collator,
)
# -------------------------
# Start training
# -------------------------
trainer.train()
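# Persist the run (paths are illustrative). For a PEFT-wrapped model,
# save_model() writes only the small LoRA adapter, not the 8B base weights.
trainer.save_model("./results/final_adapter")
tokenizer.save_pretrained("./results/final_adapter")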