QLoRA RLHF, SFT Trainer and RewardTrainer
A new version of TRL that supports training larger models using QLoRA (4-bit quantization through `bitsandbytes`), plus the brand new `RewardTrainer` and `SFTTrainer` classes to easily conduct your RLHF projects end-to-end!
Introducing `SFTTrainer` and `RewardTrainer`
Use the brand new trainers to easily train your reward model and supervised fine-tuned (SFT) model with a few lines of code!
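As a quick taste, here is a minimal sketch of how the two trainers are meant to be used. The model names, the IMDb dataset, and the `pairwise_dataset` variable are illustrative placeholders rather than part of this release; check the docs for the exact dataset format your version expects.

```python
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, SFTTrainer

# --- Supervised fine-tuning: SFTTrainer can take a model name directly ---
sft_dataset = load_dataset("imdb", split="train")  # illustrative text dataset
sft_trainer = SFTTrainer(
    "facebook/opt-350m",          # illustrative base model
    train_dataset=sft_dataset,
    dataset_text_field="text",    # column that holds the raw text
    max_seq_length=512,
)
sft_trainer.train()

# --- Reward modeling: RewardTrainer scores chosen vs. rejected completions ---
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/opt-350m", num_labels=1
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# `pairwise_dataset` is a placeholder: it should provide tokenized
# input_ids_chosen / attention_mask_chosen and
# input_ids_rejected / attention_mask_rejected columns.
reward_trainer = RewardTrainer(
    model=reward_model,
    tokenizer=tokenizer,
    train_dataset=pairwise_dataset,
)
reward_trainer.train()
```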
- [`core`] officially support SFT (Supervised Finetuning) by @younesbelkada in #323
- [`SFT`] Fix sft issues by @younesbelkada in #336
- [`docs`] fix SFT doc by @younesbelkada in #367
- [`core`] Officially Support Reward Modeling by @younesbelkada in #303
- Resolve broken evaluation/prediction for `RewardTrainer` by @tomaarsen in #404
QLoRA integration
Pass 4-bit models directly into `PPOTrainer` for more memory-efficient training.
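For example, a minimal sketch of the 4-bit path with a LoRA adapter; the model id and LoRA hyperparameters below are illustrative, and a recent `bitsandbytes` and `peft` install is assumed.

```python
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

model_id = "edbeeching/gpt-neo-125M-imdb"  # illustrative causal LM

# Load the base weights in 4 bit via bitsandbytes and attach LoRA adapters,
# so only the adapter weights are updated during PPO.
lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_id,
    load_in_4bit=True,
    peft_config=lora_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# With a peft model, ref_model can stay None: TRL disables the adapters to
# recover the reference policy instead of keeping a second full copy in memory.
ppo_trainer = PPOTrainer(PPOConfig(batch_size=16), model, ref_model=None, tokenizer=tokenizer)
```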
- [`core`] Add 4bit QLora by @younesbelkada in #383
- [`bnb`] fix 4 bit SFT by @younesbelkada in #396
Updated StackLlama example
Great work by @mnoukhov, who managed to fix the issues related to StackLlama and the new versions of `accelerate`, `peft` and `transformers`. The completely reproducible examples are below:
- StackLLaMA: correctly merge peft model by @mnoukhov in #398
- StackLlama: fixed RL training and added args by @mnoukhov in #400
- Fixed some type annotations of trl.trainer.PPOTrainer by @JulesGM in #392
- StackLLaMA: fix supervised finetuning and reward model training by @mnoukhov in #399
Bug fixes and improvements
- [`core`] refactor peft API by @younesbelkada in #231
- Batched generation by @lvwerra in #228
- Reduce memory consumption in batched_forward_pass by @ohashi56225 in #234
- [`core`] Add warning when negative KL by @younesbelkada in #239
- adds early stopping by @edbeeching in #238
- PPO config init is bloated by @GauravVirmani in #241
- feat(ci): enable `pip` cache by @SauravMaheshkar in #198
- Improve logging for PPO + Docs page by @natolambert in #243
- Fix typo by @heya5 in #253
- Using batched generate in sentiment scripts by @GauravVirmani in #249
- [`core`] Fix DeepSpeed zero-3 issue by @younesbelkada in #182
- [`distributed`] Fix early stopping and DP by @younesbelkada in #254
- [`core`] Fix ds issue by @younesbelkada in #260
- Add LlaMa in tests + `create_reference_model` by @younesbelkada in #261
- Use active model to generate response in example on README (#269) by @rmill040 in #271
- stack-llama by @edbeeching in #273
- Adding pointer back to Meta's LLaMA. by @meg-huggingface in #277
- fix doc string problem in ppo trainer loss function by @thuwyh in #279
- Add LLaMA tutorial to docs by @natolambert in #278
- Fix swapped helper texts by @philipp-classen in #284
- fix typo in gpt2-sentiment.ipynb by @eltociear in #293
- add functionality to push best models to the hub during training by @Bearnardd in #275
- Small improvements / fixes to toxicity example by @natolambert in #266
- Fix arguments description by @lvzii in #298
- [`t5`] Fix negative kl issue by @younesbelkada in #262
- Log Token distribution of Query / Response by @natolambert in #295
- clean examples folder by @natolambert in #294
- fixed typo in error message by @soerenarlt in #312
- fix DS for peft ref_model in ppo trainer by @halfrot in #309
- [`CI`] Fix broken tests by @younesbelkada in #318
- [`Docs`] Add details on multi-GPU / multi-node by @younesbelkada in #320
- Give a key to the wandb PPOConfig config entry by @JulesGM in #315
- added doc for using torch.distributed.launch/run by @oroojlooy in #324
- Fix argument's description by @vinhkhuc in #339
- stack_llama: update instructions in README, fix broken _get_submodules and save tokenizer by @teticio in #358
- stack_llama: add parameter to control max_length (to mitigate OOM errors) by @teticio in #359
- [`PPO`] Relax negative KL constraint by @younesbelkada in #352
- [`PPOTrainer`] Fix tensorboard issue by @younesbelkada in #330
- 140/best n sampling by @metric-space in #326
- Fix bug when loading local peft model by @Opdoop in #342
- add is_trainable in kwargs by @Opdoop in #363
- Remove obsolete layer_norm_names parameter and add peft>=0.3.0 to requirements by @teticio in #366
- Delete test_training.py by @younesbelkada in #371
- [`core`] Fix warning issue by @younesbelkada in #377
- Update customization.mdx by @binganao in #390
- fix dataloader typo in ppo_trainer.py by @LZY-the-boys in #389
- from_pretrain with peft adapter on the hub (# 379) by @glerzing in #380
- keep state_dict kwargs instead of popping it in save_pretrained by @rizar in #393
- Remove unused imports in docs. by @vwxyzjn in #406
New Contributors
- @ohashi56225 made their first contribution in #234
- @GauravVirmani made their first contribution in #241
- @SauravMaheshkar made their first contribution in #198
- @heya5 made their first contribution in #253
- @rmill040 made their first contribution in #271
- @thuwyh made their first contribution in #279
- @philipp-classen made their first contribution in #284
- @Bearnardd made their first contribution in #275
- @lvzii made their first contribution in #298
- @soerenarlt made their first contribution in #312
- @halfrot made their first contribution in #309
- @oroojlooy made their first contribution in #324
- @vinhkhuc made their first contribution in #339
- @teticio made their first contribution in #358
- @metric-space made their first contribution in #326
- @Opdoop made their first contribution in #342
- @binganao made their first contribution in #390
- @LZY-the-boys made their first contribution in #389
- @glerzing made their first contribution in #380
- @rizar made their first contribution in #393
- @mnoukhov made their first contribution in #398
- @tomaarsen made their first contribution in #404
- @vwxyzjn made their first contribution in #406
Full Changelog: v0.4.1...v0.4.2