pypi trl 0.13.0
v0.13.0

4 days ago

Major and breaking changes

🐾 Process-supervised RM Trainer

We introduced a new trainer for training Process-supervised Reward Models (PRMs) in TRL. A PRM rewards the quality of intermediate steps, promoting structured reasoning rather than focusing solely on the final outcome. With this trainer, we introduce a new dataset type: stepwise supervision. It is a variant of the prompt-completion type in which the completion is divided into several intermediate steps, each associated with a label. Find out more in the stepwise-supervision section of the TRL documentation.
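
To make the stepwise-supervision format concrete, here is a minimal sketch of a single record. The field names (`prompt`, `completions`, `labels`) follow the TRL dataset-format documentation as we understand it, and the content is purely illustrative; check the stepwise-supervision section for the authoritative schema.

```python
# One stepwise-supervision record: the completion is split into steps,
# and each step carries its own correctness label.
# Field names are assumed from the TRL dataset-format docs.
example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
        "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.",
    ],
    "labels": [True, False],  # the first step is correct, the second is not
}

# Sanity check: there must be exactly one label per intermediate step.
assert len(example["completions"]) == len(example["labels"])

for step, label in zip(example["completions"], example["labels"]):
    print(f"[{'ok' if label else 'wrong'}] {step}")
```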

Here is an example of how to use the PRMTrainer to train a PRM on the Math Shepherd dataset:

# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")

training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10)
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()

For more information, check out the PRMTrainer documentation.

by @qgallouedec and @gaetanlop in #2127 and #2148

🔀 Add MergeModelCallback

Various works show that model merging can non-trivially improve performance, especially if the models share the same architecture. TRL now features a callback that merges the reference model with the current policy and optionally pushes the merged checkpoint to the Hub. The merge can run at the end of each step or epoch and/or at the end of training. The callback uses Arcee's mergekit library: https://github.com/arcee-ai/mergekit

from trl import DPOTrainer, MergeModelCallback
from trl.mergekit_utils import MergeConfig

config = MergeConfig()
merge_callback = MergeModelCallback(config)
trainer = DPOTrainer(..., callbacks=[merge_callback])

by @August-murr in #2282
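
As an illustration of what merging a policy with its reference model can mean, the sketch below linearly interpolates two sets of weights. This is only one possible merge method and is not taken from mergekit or from `MergeModelCallback`; the function name `linear_merge` and the toy weights are hypothetical, and the method the callback actually applies depends on its `MergeConfig`.

```python
def linear_merge(policy_weights, reference_weights, policy_weight=0.5):
    """Linearly interpolate two weight dicts (hypothetical helper, not mergekit).

    Each dict maps a parameter name to a flat list of floats.
    """
    if policy_weights.keys() != reference_weights.keys():
        raise ValueError("Both models must expose the same parameter names.")
    merged = {}
    for name, policy_param in policy_weights.items():
        ref_param = reference_weights[name]
        if len(policy_param) != len(ref_param):
            raise ValueError(f"Shape mismatch for parameter {name!r}.")
        # Weighted average of the two parameter vectors.
        merged[name] = [
            policy_weight * p + (1.0 - policy_weight) * r
            for p, r in zip(policy_param, ref_param)
        ]
    return merged


# Toy "models" with two parameters each.
policy = {"layer.weight": [1.0, 2.0], "layer.bias": [0.0]}
reference = {"layer.weight": [3.0, 4.0], "layer.bias": [1.0]}

merged = linear_merge(policy, reference, policy_weight=0.5)
print(merged)
```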

🔨 Support for tools for data utils

TRL preprocessing utilities now support tools, a first step toward agent fine-tuning.

from trl import apply_chat_template

def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0

example = apply_chat_template(example, tokenizer, tools=[get_current_temperature])

by @August-murr in #2455

🌋 Add support for LLaVA-Next in DPOTrainer

VLMs have their own specificities which require special treatment in the trainer. DPOTrainer now supports LLaVA-Next models natively.

from transformers import AutoModelForVision2Seq
from trl import DPOTrainer
model = AutoModelForVision2Seq.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
trainer = DPOTrainer(model=model, ...)

by @chenweize1998 in #2413

🕹️ CLI and TRLParser refactor

The TRL CLI has been refactored to be more user-friendly and easier to extend. We plan to extend support to all trainers soon.

(simplified output, for readability)

$ trl dpo --help
usage: trl dpo [-h] --dataset_name DATASET_NAME [--dataset_config DATASET_CONFIG] --output_dir OUTPUT_DIR [--loss_type {sigmoid,hinge,ipo}]

options:
  -h, --help            show this help message and exit
  --dataset_name DATASET_NAME, --dataset-name DATASET_NAME
  --dataset_config DATASET_CONFIG, --dataset-config DATASET_CONFIG
  --output_dir OUTPUT_DIR, --output-dir OUTPUT_DIR
                        The output directory where the model predictions and checkpoints will be written. (default: None)
  --loss_type {sigmoid,hinge,ipo}, --loss-type {sigmoid,hinge,ipo}

by @qgallouedec in #2380 and #2412

🤝 Mixture of judges

TRL features a new judge, AllTrueJudge, that unifies the decisions of multiple binary judges. It implements the Mixture of Judges described in the CGPO paper.

import random

from trl import AllTrueJudge, BaseBinaryJudge

class RandomBinaryJudge(BaseBinaryJudge):
    """
    Random binary judge, for testing purposes.
    """

    def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
        return [random.choice([0, 1, -1]) for _ in range(len(prompts))]


prompts = ["The capital of France is", "The biggest planet in the solar system is"]
completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
judgements = judge.judge(prompts=prompts, completions=completions)
print(judgements)  # e.g. [0, 1]; the result varies because the judges are random

by @gaetanlop in #2159
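
The aggregation idea can be sketched without TRL. The function below is a hypothetical stand-in, not the library's code: it assumes each judge returns 1 (pass), 0 (fail), or -1 (invalid) per sample, that a sample passes only if every judge passes it, and that any invalid verdict makes the combined verdict invalid. The last rule is our assumption about reasonable behavior.

```python
def all_true(judgements_per_judge):
    """Combine per-judge verdicts into one verdict per sample (illustrative only).

    judgements_per_judge: list of lists, one inner list per judge,
    each holding one verdict (1, 0, or -1) per sample.
    """
    combined = []
    for verdicts in zip(*judgements_per_judge):
        if any(v not in (0, 1, -1) for v in verdicts):
            raise ValueError(f"Unexpected verdict in {verdicts}; expected 0, 1, or -1.")
        if -1 in verdicts:
            combined.append(-1)  # assumed: an invalid verdict invalidates the sample
        elif all(v == 1 for v in verdicts):
            combined.append(1)   # every judge approved
        else:
            combined.append(0)   # at least one judge rejected
    return combined


judge_a = [1, 1, 0, 1]
judge_b = [1, 0, 0, -1]
result = all_true([judge_a, judge_b])
print(result)  # [1, 0, 0, -1]
```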

❄️ DPO trainer supports num_logits_to_keep to save memory

Save memory in the DPO trainer by computing logits only for the last num_logits_to_keep tokens instead of the whole sequence.

training_args = DPOConfig(..., use_num_logits_to_keep=True)

by @xyangk in #2129
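
A rough back-of-the-envelope calculation shows why this matters. The sketch assumes the option restricts the logits tensor to the last positions of each sequence (for example, the completion tokens) and that logits are stored as 32-bit floats; the batch size, sequence length, vocabulary size, and completion length are made-up illustrative numbers.

```python
def logits_bytes(batch_size, num_positions, vocab_size, bytes_per_value=4):
    """Size in bytes of a logits tensor of shape (batch, positions, vocab)."""
    return batch_size * num_positions * vocab_size * bytes_per_value


# Illustrative numbers only (not taken from the release notes).
batch_size, seq_len, vocab_size = 4, 4096, 150_000
completion_len = 256  # assumed number of trailing positions that are kept

full = logits_bytes(batch_size, seq_len, vocab_size)
kept = logits_bytes(batch_size, completion_len, vocab_size)

print(f"all positions : {full / 2**30:.2f} GiB")
print(f"last {completion_len} only : {kept / 2**30:.2f} GiB")
print(f"reduction     : {full // kept}x")
```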

🗺️ Implementation of DiscoPOP Loss

The DiscoPOP paper uses LLMs to discover more efficient offline preference optimization losses. In the paper, the proposed DiscoPOP loss (a log-ratio modulated loss) outperformed other optimization losses on several tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0).

training_args = DPOConfig(..., loss_type="discopop", discopop_tau=0.05)

by @fanconic in #2323
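
For intuition, here is a scalar sketch of a log-ratio modulated loss as we understand it from the DiscoPOP paper: a sigmoid gate with temperature tau blends the logistic (DPO-style) loss with an exponential loss. This is a plain-Python illustration, not the TRL implementation; the function names and the default `beta=0.1` are assumptions made for the example.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def discopop_loss(policy_chosen_logp, policy_rejected_logp,
                  ref_chosen_logp, ref_rejected_logp, beta=0.1, tau=0.05):
    """Log-ratio modulated loss for a single preference pair (illustrative)."""
    # Scaled difference between the policy and reference log-ratios.
    logits = beta * ((policy_chosen_logp - policy_rejected_logp)
                     - (ref_chosen_logp - ref_rejected_logp))
    modulation = sigmoid(logits / tau)   # gate between the two loss terms
    logistic_term = -log_sigmoid(logits)  # DPO-style logistic loss
    exponential_term = math.exp(-logits)  # exponential loss
    return (1.0 - modulation) * logistic_term + modulation * exponential_term


# When policy and reference agree, the gate is 0.5 and both terms contribute equally.
neutral = discopop_loss(-1.0, -1.0, -1.0, -1.0)
print(neutral)
```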

🧑‍🍳 Add precompute batch size argument in DPOTrainer for reference model

We can now control the batch size used when precomputing the reference model's log probabilities.

training_args = DPOConfig(
    ...,
    precompute_ref_log_probs=True,
    precompute_ref_batch_size=4,
)

by @SwayamInSync in #2426

📦 Support for packing tokenized datasets for SFT

SFTTrainer already supported packing datasets for faster training. It now supports packing tokenized datasets as well.

by @kmehant in #2011
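
As a reminder of what packing does, the sketch below concatenates already-tokenized sequences into one stream, separated by an end-of-sequence token, and cuts the stream into fixed-size blocks. It is a generic illustration rather than SFTTrainer's implementation; `pack_tokenized`, the token ids, and the choice to drop the incomplete final block are all assumptions.

```python
def pack_tokenized(sequences, block_size, eos_token_id):
    """Pack lists of token ids into fixed-length blocks (illustrative helper)."""
    stream = []
    for token_ids in sequences:
        stream.extend(token_ids)
        stream.append(eos_token_id)  # mark the boundary between examples
    # Keep only full blocks; the trailing remainder is dropped in this sketch.
    num_blocks = len(stream) // block_size
    return [stream[i * block_size:(i + 1) * block_size] for i in range(num_blocks)]


tokenized = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]
packed = pack_tokenized(tokenized, block_size=4, eos_token_id=0)
print(packed)  # [[5, 6, 7, 0], [8, 9, 0, 10], [11, 12, 13, 0]]
```

Every packed block has the same length, so no padding tokens are needed within a batch.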

📉 Add PEFT support for PPOTrainer

PPOTrainer now supports PEFT for efficient training.

PPOTrainer(
    ...,
    peft_config=peft_config,
)

by @ccs96307 in #2344

💾 Deprecate config in favor of args in PPOTrainer

config has been deprecated in favor of args in PPOTrainer.

  PPOTrainer(
-   config=training_args,
+   args=training_args,
  )

by @qgallouedec in #2384

👮 Deprecate policy in favor of model in PPOTrainer

policy has been deprecated in favor of model in PPOTrainer.

  PPOTrainer(
-   policy=model,
+   model=model,
  )

by @qgallouedec in #2386

Full Changelog: v0.12.0...v0.13.0
