# Generalized Knowledge Distillation Trainer

[![model badge](https://img.shields.io/badge/All_models-GKD-blue)](https://huggingface.co/models?other=gkd,trl)

## Overview

Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.

The abstract from the paper is the following:

> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.

The key aspects of GKD are:

1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student model on its self-generated output sequences.
2. GKD allows flexibility in choosing different divergence measures between student and teacher models via the generalized Jensen-Shannon Divergence (JSD), which can be useful when the student lacks the capacity to fully mimic the teacher.
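
For reference, the generalized JSD from the paper can be written as below. The symbols are notation chosen here for illustration: P is the teacher's token distribution, Q is the student's, and M is their mixture weighted by `beta`.

$$
D_{\mathrm{JSD}(\beta)}(P \,\|\, Q) = \beta \, D_{\mathrm{KL}}(P \,\|\, M) + (1 - \beta) \, D_{\mathrm{KL}}(Q \,\|\, M), \qquad M = \beta P + (1 - \beta) Q
$$

The paper notes that the gradients of this divergence behave like those of the forward KL when `beta` is close to 0 and like those of the reverse KL when `beta` is close to 1.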

This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Lewis Tunstall](https://huggingface.co/lewtun).

## Usage tips

The [GKDTrainer](/docs/trl/v0.25.1/en/gkd_trainer#trl.GKDTrainer) is a wrapper around the [SFTTrainer](/docs/trl/v0.25.1/en/sft_trainer#trl.SFTTrainer) class that takes an additional teacher model argument. Its behavior is controlled by three parameters set via the [GKDConfig](/docs/trl/v0.25.1/en/gkd_trainer#trl.GKDConfig):

* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD, where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and receives token-specific feedback on these sequences from the teacher. For values strictly between 0 and 1, each batch is randomly assigned to one of the two modes, with the on-policy mode chosen with probability `lmbda`.
* `seq_kd`: controls whether to perform Sequence-Level KD (which can be viewed as supervised fine-tuning on teacher-generated output). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values between 0 and 1, it interpolates between the two.
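
To make the role of `beta` concrete, here is a minimal pure-Python sketch of the generalized JSD on a single next-token distribution. It is an illustration only, not the trainer's implementation: the helper names `kl_divergence` and `generalized_jsd` are hypothetical, and treating `beta=0.0` and `beta=1.0` as exact forward and reverse KL is an assumption made so that the endpoints match the description above.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities.

    Assumes q is strictly positive wherever p is positive.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def generalized_jsd(teacher_probs, student_probs, beta=0.5):
    """Generalized JSD between teacher and student distributions.

    Assumption for illustration: the endpoints are handled as special cases,
    returning forward KL at beta=0.0 and reverse KL at beta=1.0.
    """
    if beta == 0.0:
        return kl_divergence(teacher_probs, student_probs)  # forward KL
    if beta == 1.0:
        return kl_divergence(student_probs, teacher_probs)  # reverse KL
    # Mixture of the two distributions, weighted by beta
    mixture = [beta * t + (1.0 - beta) * s for t, s in zip(teacher_probs, student_probs)]
    kl_teacher = kl_divergence(teacher_probs, mixture)
    kl_student = kl_divergence(student_probs, mixture)
    return beta * kl_teacher + (1.0 - beta) * kl_student


# Toy distributions over a three-token vocabulary
teacher = [0.7, 0.2, 0.1]
student = [0.4, 0.4, 0.2]

for beta in (0.0, 0.5, 1.0):
    print(f"beta={beta}: {generalized_jsd(teacher, student, beta):.4f}")
```

At `beta=0.5` the value is symmetric in the teacher and student, which is the standard Jensen-Shannon divergence.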

The authors find that on-policy data (high `lmbda`) performs better and that the optimal `beta` varies depending on the task and evaluation method.
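
The per-batch choice governed by `lmbda` and `seq_kd` can be sketched as follows. This is a simplified illustration under stated assumptions, not the trainer's code: `choose_batch_source` is a hypothetical helper, the returned labels are made up for this example, and the sketch assumes that the on-policy draw takes precedence over `seq_kd` when both apply.

```python
import random


def choose_batch_source(lmbda, seq_kd=False, rng=random):
    """Return which completions a training batch would use.

    Hypothetical helper, not part of TRL. Assumes the on-policy draw takes
    precedence over `seq_kd` when both apply.
    """
    if rng.random() < lmbda:
        return "student_generated"  # on-policy: student samples, teacher scores
    if seq_kd:
        return "teacher_generated"  # sequence-level KD on teacher outputs
    return "dataset"  # supervised: completions come from the dataset


# Count how often each source is picked over 1000 batches with lmbda=0.5
rng = random.Random(0)
counts = {"student_generated": 0, "teacher_generated": 0, "dataset": 0}
for _ in range(1000):
    counts[choose_batch_source(lmbda=0.5, rng=rng)] += 1
print(counts)
```

With `lmbda=0.5` and `seq_kd=False`, roughly half of the batches are on-policy and the rest use the dataset completions.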

> [!WARNING]
> Make sure to set `attn_implementation="flash_attention_2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise, you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture.

The basic API is as follows:

```python
from datasets import Dataset
from trl import GKDConfig, GKDTrainer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

NUM_DUMMY_SAMPLES = 100

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The teacher model to calculate the KL divergence against
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hi, how are you?"},
                {"role": "assistant", "content": "I'm great thanks"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)
eval_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What colour is the sky?"},
                {"role": "assistant", "content": "The sky is blue"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)

training_args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
```

### Expected dataset type

Each example in the dataset should contain a `messages` field holding a conversation, i.e., a list of dictionaries with the following keys:

* `role`: either `system`, `assistant` or `user`
* `content`: the message content
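
A quick way to check this format before training is a small validation pass over the examples. The sketch below is illustrative only: `validate_example` and `ALLOWED_ROLES` are hypothetical names, not part of TRL, and the check assumes plain-text conversations where `content` is a string.

```python
ALLOWED_ROLES = {"system", "user", "assistant"}


def validate_example(example):
    """Raise ValueError if `example` does not follow the conversational format.

    Hypothetical helper for illustration; not part of TRL.
    """
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        raise ValueError("'messages' must be a non-empty list")
    for i, message in enumerate(messages):
        if not isinstance(message, dict):
            raise ValueError(f"message {i} must be a dict")
        if message.get("role") not in ALLOWED_ROLES:
            raise ValueError(f"message {i} has an invalid role: {message.get('role')!r}")
        # Assumption: text-only conversations, so content must be a string
        if not isinstance(message.get("content"), str):
            raise ValueError(f"message {i} must have string content")


example = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What colour is the sky?"},
        {"role": "assistant", "content": "The sky is blue"},
    ]
}
validate_example(example)  # passes silently for a well-formed example
```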

## GKDTrainer[[trl.GKDTrainer]]

#### trl.GKDTrainer[[trl.GKDTrainer]]

[Source](https://github.com/huggingface/trl/blob/v0.25.1/trl/trainer/gkd_trainer.py#L54)

Trainer for Generalized Knowledge Distillation (GKD) of language models.

For details on GKD, see the paper: [On-Policy Distillation of Language Models: Learning from Self-Generated
Mistakes](https://huggingface.co/papers/2306.13649).

**Parameters:**

model ([PreTrainedModel](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/model#transformers.PreTrainedModel) or `torch.nn.Module` or `str`, *optional*) : Model to be trained, or the string identifier of the model to be instantiated from a pretrained model.

teacher_model ([PreTrainedModel](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/model#transformers.PreTrainedModel) or `torch.nn.Module` or `str`, *optional*) : Teacher model for knowledge distillation, or the string identifier of the model to be instantiated from a pretrained model.

args ([GKDConfig](/docs/trl/v0.25.1/en/gkd_trainer#trl.GKDConfig), *optional*) : Training arguments.

data_collator (`DataCollator`, *optional*) : Data collator to batch samples from the dataset. It defaults to a `DataCollatorForChatML` using the `processing_class`.

train_dataset ([Dataset](https://huggingface.co/docs/datasets/v4.4.1/en/package_reference/main_classes#datasets.Dataset), *optional*) : Dataset for training.

eval_dataset ([Dataset](https://huggingface.co/docs/datasets/v4.4.1/en/package_reference/main_classes#datasets.Dataset) or `dict` of [Dataset](https://huggingface.co/docs/datasets/v4.4.1/en/package_reference/main_classes#datasets.Dataset), *optional*) : Dataset for evaluation.

processing_class ([PreTrainedTokenizerBase](https://huggingface.co/docs/transformers/v4.57.1/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase), [BaseImageProcessor](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/image_processor#transformers.BaseImageProcessor), [FeatureExtractionMixin](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/feature_extractor#transformers.FeatureExtractionMixin) or [ProcessorMixin](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/processors#transformers.ProcessorMixin), *optional*) : Class to process the data.

compute_metrics (`Callable`, *optional*) : Function to compute metrics at evaluation. Must take in an [EvalPrediction](https://huggingface.co/docs/transformers/v4.57.1/en/internal/trainer_utils#transformers.EvalPrediction) and return a dictionary string to float.

callbacks (`list` of [TrainerCallback](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/callback#transformers.TrainerCallback), *optional*) : Callbacks to use during training.

optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`) : Tuple containing the optimizer and the learning rate scheduler to use for training.

preprocess_logits_for_metrics (`Callable`, *optional*) : Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and return the logits to be used for metrics computation.

peft_config ([PeftConfig](https://huggingface.co/docs/peft/v0.18.0.rc0/en/package_reference/config#peft.PeftConfig), *optional*) : PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be wrapped with the specified PEFT adapter.

formatting_func (`Callable`, *optional*) : Function to format the dataset. Must take in an example and return an example.

#### train[[trl.GKDTrainer.train]]

[Source](https://github.com/huggingface/trl/blob/v0.25.1/transformers/trainer.py#L2213)

Main training entry point.

**Parameters:**

resume_from_checkpoint (`str` or `bool`, *optional*) : If a `str`, local path to a saved checkpoint as saved by a previous instance of `Trainer`. If a `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance of `Trainer`. If present, training will resume from the model/optimizer/scheduler states loaded here.

trial (`optuna.Trial` or `dict[str, Any]`, *optional*) : The trial run or the hyperparameter dictionary for hyperparameter search.

ignore_keys_for_eval (`list[str]`, *optional*) : A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.

kwargs (`dict[str, Any]`, *optional*) : Additional keyword arguments used to hide deprecated arguments.

#### save_model[[trl.GKDTrainer.save_model]]

[Source](https://github.com/huggingface/trl/blob/v0.25.1/transformers/trainer.py#L4177)

Will save the model, so you can reload it using `from_pretrained()`.

Will only save from the main process.

#### push_to_hub[[trl.GKDTrainer.push_to_hub]]

[Source](https://github.com/huggingface/trl/blob/v0.25.1/transformers/trainer.py#L5117)

Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.

**Parameters:**

commit_message (`str`, *optional*, defaults to `"End of training"`) : Message to commit while pushing.

blocking (`bool`, *optional*, defaults to `True`) : Whether the function should return only when the `git push` has finished.

token (`str`, *optional*, defaults to `None`) : Token with write permission to overwrite Trainer's original args.

revision (`str`, *optional*) : The git revision to commit from. Defaults to the head of the "main" branch.

kwargs (`dict[str, Any]`, *optional*) : Additional keyword arguments passed along to `~Trainer.create_model_card`.

**Returns:**

The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking the
progress of the commit if `blocking=False`.

## GKDConfig[[trl.GKDConfig]]

#### trl.GKDConfig[[trl.GKDConfig]]

[Source](https://github.com/huggingface/trl/blob/v0.25.1/trl/trainer/gkd_config.py#L24)

Configuration class for [GKDTrainer](/docs/trl/v0.25.1/en/gkd_trainer#trl.GKDTrainer).

This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
please refer to the [TrainingArguments](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/trainer#transformers.TrainingArguments) and [SFTConfig](/docs/trl/v0.25.1/en/sft_trainer#trl.SFTConfig) documentation.

**Parameters:**

temperature (`float`, *optional*, defaults to `0.9`) : Temperature for sampling. The higher the temperature, the more random the completions.

lmbda (`float`, *optional*, defaults to `0.5`) : Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy student-generated outputs).

beta (`float`, *optional*, defaults to `0.5`) : Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When beta is `0.0`, the loss is the forward KL divergence. When beta is `1.0`, the loss is the reverse (inverse) KL divergence.

max_new_tokens (`int`, *optional*, defaults to `128`) : Maximum number of tokens to generate per completion.

teacher_model_name_or_path (`str`, *optional*) : Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being trained.

teacher_model_init_kwargs (`dict[str, Any]`, *optional*) : Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model from a string.

disable_dropout (`bool`, *optional*, defaults to `True`) : Whether to disable dropout in the model.

seq_kd (`bool`, *optional*, defaults to `False`) : Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output).
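
As a sketch of how these parameters fit together, the snippet below collects them in a plain dictionary that could be unpacked into the config. The helper `check_gkd_kwargs` is hypothetical and only checks the documented `[0, 1]` range of `lmbda` and `beta`; the dictionary is kept free of TRL imports so the example runs on its own.

```python
def check_gkd_kwargs(kwargs):
    """Hypothetical sanity check: lmbda and beta are documented as values in [0, 1]."""
    for name in ("lmbda", "beta"):
        if not 0.0 <= kwargs[name] <= 1.0:
            raise ValueError(f"{name} must be in [0, 1], got {kwargs[name]}")
    return kwargs


gkd_kwargs = check_gkd_kwargs(
    {
        "output_dir": "gkd-model",
        "lmbda": 1.0,  # fully on-policy: every batch uses student generations
        "beta": 0.5,  # documented default
        "temperature": 0.9,  # documented default
        "max_new_tokens": 128,  # documented default
        "seq_kd": False,  # documented default
    }
)
# With TRL installed, these would be passed as: GKDConfig(**gkd_kwargs)
```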

