This release consists of a major refactor that overhauls the training approach for rerankers, a.k.a. Cross Encoder models (introducing multi-GPU training, bf16, loss logging, callbacks, and much more), including an all-new Training Overview, Loss Overview, API Reference docs, training examples and more!
Install this version with:
# Training + Inference
pip install sentence-transformers[train]==4.0.1
# Inference only, use one of:
pip install sentence-transformers==4.0.1
pip install sentence-transformers[onnx-gpu]==4.0.1
pip install sentence-transformers[onnx]==4.0.1
pip install sentence-transformers[openvino]==4.0.1
Tip
My Training and Finetuning Reranker Models with Sentence Transformers v4 blogpost is an excellent place to learn 1) why finetuning rerankers makes sense and 2) how you can do it, too!
Reranker (Cross Encoder) training refactor (#3222)
The v4.0 release centers around this huge modernization of the training approach for CrossEncoder models, following v3.0, which introduced the same for SentenceTransformer models. Before v4.0, training was all about InputExample, DataLoader and model.fit. The new training approach relies on 5 components. You can learn more about these components in our Training and Finetuning Reranker Models with Sentence Transformers v4 blogpost. Additionally, you can read the new Training Overview, check out the Training Examples, or read this summary:
- Dataset
A training Dataset or DatasetDict. This class is much better suited for sharing & efficient modifications than lists/DataLoaders of InputExample instances. A Dataset can contain multiple text columns that will be fed in order to the corresponding loss function. So, if the loss expects (anchor, positive, negative) triplets, then your dataset should also have 3 columns. The names of these columns are irrelevant. If there is a "label" or "score" column, it is treated separately and used as the labels during training.
A DatasetDict can be used to train with multiple datasets at once, e.g.:
DatasetDict({
    natural_questions: Dataset({
        features: ['anchor', 'positive'],
        num_rows: 392702
    })
    gooaq: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 549367
    })
    stsb: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 5749
    })
})
When a DatasetDict is used, the loss parameter to the CrossEncoderTrainer must also be a dictionary with these dataset keys, e.g.:
{
    'natural_questions': CachedMultipleNegativesRankingLoss(...),
    'gooaq': CachedMultipleNegativesRankingLoss(...),
    'stsb': BinaryCrossEntropyLoss(...),
}
- Loss Function
A loss function, or a dictionary of loss functions as described above.
- Training Arguments
A CrossEncoderTrainingArguments instance, a subclass of the transformers TrainingArguments class. This powerful class controls the specific details of the training.
- Evaluator
An optional SentenceEvaluator instance. Unlike before, models can now be evaluated on an evaluation dataset with some loss function and/or with a SentenceEvaluator instance.
- Trainer
The new CrossEncoderTrainer instance, based on the transformers Trainer. This instance can be initialized with a CrossEncoder model, a CrossEncoderTrainingArguments instance, a SentenceEvaluator, a training and evaluation Dataset/DatasetDict, and a loss function/dict of loss functions. Most of these parameters are optional. Once they are provided, all you have to do is call trainer.train().
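The column rules above can be illustrated with a small pure-Python sketch. It is not the library's data collator. The functions split_row and check_loss_keys are hypothetical helpers, and they encode only what the summary states: text columns reach the loss in column order regardless of their names, a "label" or "score" column is pulled out as the training label, and a loss dictionary has to use the same keys as the DatasetDict.

```python
from __future__ import annotations

from typing import Any

# Column names that this sketch treats as labels, following the rule described above
LABEL_COLUMNS = ("label", "score")


def split_row(row: dict[str, Any]) -> tuple[list[Any], Any]:
    """Hypothetical helper: split one dataset row into loss inputs and an optional label.

    Text columns keep their original order; their names are not inspected.
    """
    inputs = [value for name, value in row.items() if name not in LABEL_COLUMNS]
    label = next((row[name] for name in LABEL_COLUMNS if name in row), None)
    return inputs, label


def check_loss_keys(dataset_keys, losses: dict[str, Any]) -> None:
    """Hypothetical helper: a loss dictionary must use exactly the DatasetDict keys."""
    missing = sorted(set(dataset_keys) - set(losses))
    unknown = sorted(set(losses) - set(dataset_keys))
    if missing or unknown:
        raise ValueError(f"missing losses for {missing}, unknown loss keys {unknown}")


# A triplet row: three text columns, fed to the loss as (anchor, positive, negative)
print(split_row({"anchor": "q", "positive": "p", "negative": "n"}))
# A pair row with a "label" column: the label is separated from the two texts
print(split_row({"sentence1": "a", "sentence2": "b", "label": 0.8}))

# Plain strings stand in for loss objects such as BinaryCrossEntropyLoss(model)
check_loss_keys(
    ["natural_questions", "gooaq", "stsb"],
    {"natural_questions": "cmnrl", "gooaq": "cmnrl", "stsb": "bce"},
)
```

Because the inputs are positional, renaming columns has no effect. Reordering them, however, changes which text the loss sees as the anchor, positive or negative.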
Some of the major features that are now implemented include:
- Multi-GPU Training (Data Parallelism (DP) and Distributed Data Parallelism (DDP))
- bf16 training support
- Loss logging
- Evaluation datasets + evaluation loss
- Improved callback support (built-in via Weights and Biases, TensorBoard, CodeCarbon, etc., as well as custom callbacks)
- Gradient checkpointing
- Gradient accumulation
- Improved model card generation
- Warmup ratio
- Pushing to the Hugging Face Hub on every model checkpoint
- Resuming from a training checkpoint
- Hyperparameter Optimization
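Several of these features are configured through fields that CrossEncoderTrainingArguments inherits from the transformers TrainingArguments. The rough sketch below maps some of the listed features to those fields. Keep the following assumptions in mind:

- The output directory and all numeric values are illustrative.
- bf16 requires bf16-capable hardware.
- Pushing to the Hub requires being logged in.
- Older transformers releases name eval_strategy evaluation_strategy.

The values are collected in a plain dict so the snippet runs without a GPU or network access. The commented lines show the assumed way to pass them on.

```python
# Illustrative keyword arguments for CrossEncoderTrainingArguments, a subclass of
# transformers.TrainingArguments. All values are assumptions chosen for this sketch.
training_kwargs = {
    "output_dir": "models/reranker-example",  # hypothetical output path
    "num_train_epochs": 1,
    "per_device_train_batch_size": 16,
    "gradient_accumulation_steps": 4,  # gradient accumulation
    "gradient_checkpointing": True,  # gradient checkpointing
    "bf16": True,  # bf16 training, needs bf16-capable hardware
    "warmup_ratio": 0.1,  # warmup ratio
    "eval_strategy": "steps",  # evaluation loss on the eval dataset
    "eval_steps": 100,
    "logging_steps": 100,  # loss logging
    "save_strategy": "steps",
    "save_steps": 100,
    "push_to_hub": True,  # push to the Hugging Face Hub...
    "hub_strategy": "every_save",  # ...every time a checkpoint is saved
    "report_to": ["tensorboard"],  # one of the built-in logging integrations
}

# Samples contributing to each optimizer step on a single device
effective_batch_size = (
    training_kwargs["per_device_train_batch_size"] * training_kwargs["gradient_accumulation_steps"]
)
print(effective_batch_size)  # 64

# Assumed usage with sentence-transformers v4 installed:
#   from sentence_transformers.cross_encoder import CrossEncoderTrainingArguments
#   args = CrossEncoderTrainingArguments(**training_kwargs)
#   trainer = CrossEncoderTrainer(model=model, args=args, train_dataset=..., loss=...)
#   trainer.train(resume_from_checkpoint=True)  # resume from the latest checkpoint in output_dir
```

Gradient accumulation trades speed for memory. Here, 4 accumulated steps of 16 samples behave like a per-device batch of 64 for the optimizer.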
This script is a minimal example (no evaluator, no training arguments) of training mpnet-base on a part of the sentence-transformers/hotpotqa dataset using BinaryCrossEntropyLoss:
from datasets import load_dataset
from sentence_transformers import CrossEncoder, CrossEncoderTrainer
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss

# 1. Define the model. Either from scratch or by loading a pre-trained model
model = CrossEncoder("microsoft/mpnet-base")

# 2. Load a dataset to finetune on, then turn every (anchor, positive, negative)
#    triplet into two labeled pairs: (anchor, positive, 1) and (anchor, negative, 0)
dataset = load_dataset("sentence-transformers/hotpotqa", "triplet", split="train")

def triplet_to_labeled_pair(batch):
    anchors = batch["anchor"]
    positives = batch["positive"]
    negatives = batch["negative"]
    return {
        "sentence_A": anchors * 2,
        "sentence_B": positives + negatives,
        "labels": [1] * len(positives) + [0] * len(negatives),
    }

dataset = dataset.map(triplet_to_labeled_pair, batched=True, remove_columns=dataset.column_names)
train_dataset = dataset.select(range(10_000))
eval_dataset = dataset.select(range(10_000, 11_000))

# 3. Define a loss function
loss = BinaryCrossEntropyLoss(model)

# 4. Create a trainer & train
trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
)
trainer.train()

# 5. Save the trained model
model.save_pretrained("models/mpnet-base-hotpotqa")
# model.push_to_hub("mpnet-base-hotpotqa")
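To see what the batched map in step 2 produces, the same conversion can be run on a tiny hand-written batch in plain Python. The example strings are made up. With batched=True, Dataset.map passes in a dict of column lists, which is the shape this sketch imitates. Every triplet yields two rows: (anchor, positive) with label 1 and (anchor, negative) with label 0. The converted dataset therefore has twice as many rows, and within each batch all positive pairs come before all negative pairs.

```python
def triplet_to_labeled_pair(batch):
    # Same conversion as in the training script above
    anchors = batch["anchor"]
    positives = batch["positive"]
    negatives = batch["negative"]
    return {
        "sentence_A": anchors * 2,  # each anchor appears once with its positive, once with its negative
        "sentence_B": positives + negatives,
        "labels": [1] * len(positives) + [0] * len(negatives),
    }


# A hand-written batch in the dict-of-lists shape that Dataset.map(batched=True) passes in
batch = {
    "anchor": ["who wrote hamlet", "capital of france"],
    "positive": ["Hamlet is a play by William Shakespeare.", "Paris is the capital of France."],
    "negative": ["Othello is another tragedy.", "Lyon is a large French city."],
}

pairs = triplet_to_labeled_pair(batch)
for text_a, text_b, label in zip(pairs["sentence_A"], pairs["sentence_B"], pairs["labels"]):
    print(label, "|", text_a, "|", text_b)
```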
Additionally, trained models now automatically produce extensive model cards. Each of the following models was trained using a script from the Training Examples, and none of the model cards were edited manually:
- tomaarsen/reranker-MiniLM-L12-gooaq-bce
- tomaarsen/reranker-msmarco-MiniLM-L12-H384-uncased-lambdaloss
- tomaarsen/reranker-distilroberta-base-nli
Prior to the Sentence Transformers v4 release, all reranker models were trained using the CrossEncoder.fit method. Rather than deprecating this method, starting from v4.0, it uses the CrossEncoderTrainer behind the scenes. This means that your old training code should still work, and it even benefits from the new features, such as multi-GPU training and loss logging. That said, the new training approach is much more powerful, so it is recommended to write new training scripts using the new approach.
To help you out, all of the Cross Encoder (a.k.a. reranker) training scripts were updated to use the new Trainer-based approach.
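If an existing script builds lists of InputExample objects, one way to start migrating is to turn those examples into columns for a Dataset. The sketch below makes the following assumptions:

- It uses a hypothetical LegacyExample dataclass as a stand-in for InputExample, which holds texts and a label, so it runs without any dependencies.
- The helper examples_to_columns is illustrative, not a library API.
- The column names are illustrative too.

```python
from dataclasses import dataclass, field


@dataclass
class LegacyExample:
    """Hypothetical stand-in for sentence_transformers.InputExample (texts + label)."""

    texts: list = field(default_factory=list)
    label: float = 0.0


def examples_to_columns(examples):
    """Hypothetical helper: convert pair examples into one list per column."""
    columns = {"sentence_A": [], "sentence_B": [], "label": []}
    for example in examples:
        if len(example.texts) != 2:
            raise ValueError(f"expected a text pair, got {len(example.texts)} texts")
        columns["sentence_A"].append(example.texts[0])
        columns["sentence_B"].append(example.texts[1])
        columns["label"].append(example.label)
    return columns


train_examples = [
    LegacyExample(texts=["how to boil an egg", "Place the egg in boiling water for 8 minutes."], label=1),
    LegacyExample(texts=["how to boil an egg", "Eggs are a source of protein."], label=0),
]
columns = examples_to_columns(train_examples)
print(columns["label"])  # [1, 0]

# Assumed next step with the datasets library installed:
#   from datasets import Dataset
#   train_dataset = Dataset.from_dict(columns)
```

Since the pair texts become ordinary columns and the label ends up in a "label" column, the result follows the Dataset conventions described in the summary above.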
Is finetuning worth it?
Finetuning reranker models on your data is very valuable. Consider for example these 2 models that I finetuned on 100k samples from the GooAQ dataset in 30 minutes and 1 hour, respectively. After finetuning, my models heavily outperformed general-purpose reranker models, even though GooAQ is a very generic dataset/domain!
Read my Training and Finetuning Reranker Models with Sentence Transformers v4 blogpost for many more details on these models and how they were trained.
Resources:
- How to use Cross Encoder models? Cross Encoder > Usage
- What Cross Encoder models can I use? Cross Encoder > Pretrained Models
- How do I train/finetune a Cross Encoder model? Cross Encoder > Training Overview
Refactor Stats
- Code:
- New Trainer, Training Arguments, Data Collator, Model Card generation + template, with backwards compatibility
- 11 new losses
- 1 new, 3 refactored, 6 deprecated evaluators
- Tests:
- Docs:
- All new Training Overview, Loss Overview, API Reference docs
- 5 new, 1 refactored training examples docs pages
- 13 new, 6 refactored training scripts
- Migration guide (2.x -> 3.x, 3.x -> 4.x)
Small Features
- Introduce show_progress_bar for the InformationRetrievalEvaluator (#3227)
- Replace SubsetRandomSampler with RandomSampler in the default batch sampler, which should result in reduced memory usage and increased training speed! (#3261)
- Allow resuming from checkpoint when training with the deprecated SentenceTransformer.fit (#3269)
- Allow truncation and setting model.max_seq_length for CLIP models (#2969)
Bug Fixes
- Fixed MatryoshkaLoss crashing when used with n_dims_per_step and an unsorted matryoshka_dims (#3203)
- Fixed GISTEmbedLoss failing with some base models whose tokenizers don't have the vocab attribute (#3219, #3226)
- Fixed support for Asym-based SentenceTransformer models (#3220, #3244)
- Fixed some evaluator outputs not being converted to a Python float, i.e. staying as numpy or torch types (#3277)
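The MatryoshkaLoss fix (#3203) keeps the sampled dimensions in descending order when only n_dims_per_step of the matryoshka_dims are used per training step, even if matryoshka_dims was passed unsorted. The following is a simplified pure-Python sketch of that sampling step, not the library's code. The function sample_dims is a hypothetical helper. Treating n_dims_per_step <= 0 as "use all dimensions" is an assumption of this sketch.

```python
import random


def sample_dims(matryoshka_dims, matryoshka_weights, n_dims_per_step, rng=random):
    """Hypothetical helper: pick n_dims_per_step (dim, weight) pairs, largest dim first."""
    # Sort dimensions (with their weights) in descending order once,
    # so that a lower index always means a larger dimension
    pairs = sorted(zip(matryoshka_dims, matryoshka_weights), key=lambda p: p[0], reverse=True)
    if n_dims_per_step <= 0 or n_dims_per_step >= len(pairs):
        return pairs
    # Sample indices, then sort them so the sampled dimensions stay in descending order
    indices = sorted(rng.sample(range(len(pairs)), n_dims_per_step))
    return [pairs[i] for i in indices]


rng = random.Random(0)
dims = [64, 768, 256, 128, 512]  # deliberately unsorted
weights = [1, 1, 1, 1, 1]
for _ in range(3):
    print(sample_dims(dims, weights, n_dims_per_step=2, rng=rng))
```

Sorting once up front and then sorting only the sampled indices keeps each step's subset ordered without re-sorting the dimensions themselves.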
Examples
Note
The v4.0.0 version did not include the model_card_template.md in the package; this has been resolved in v4.0.1 via ba1260d.
What's Changed
- fix MatryoshkaLoss bug: sort sampled dimension indices to maintain descending dimension order by @emapco in #3203
- [docs] Resolve broken URL due to weird & behaviour in pretrained ST models by @tomaarsen in #3213
- Update Evaluation Script for Reranking by @milistu in #3198
- [docs] Update incorrect name: pairwise_similarity -> similarity_pairwise by @tomaarsen in #3224
- [fix] Use .get_vocab() instead of .vocab for checking tokenizer vocabulary by @tomaarsen in #3226
- [feat] Add progress bar support for corpus in IR Evaluator by @tomaarsen in #3227
- Update CoSENTLoss.py documentation by @johneckberg in #3230
- NoDuplicatesDataLoader Compatability with Asymmetric models by @OsamaS99 in #3220
- [fix] Fix Syntax issue; move 'as fIn' to after the if-else in STSDataReader by @tomaarsen in #3235
- Model Card Compatability & BinaryClassificationEvaluator with Asymmetric Models by @OsamaS99 in #3244
- [fix] Changed value error for missing model into FileNotFoundError by @PhorstenkampFuzzy in #3238
- Add check for hpu and wrap_in_hpu_graph availability. by @vshekhawat-hlab in #3249
- Replacing SubsetRandomSampler by RandomSampler in BATCH_SAMPLER by @NohTow in #3261
- Fix: Reorder dataset columns for DenoisingAutoEncoderLoss in TSADE examples by @HuangBugWei in #3263
- Update to fit_mixin.fit to allow fine tuning to resume from a checkpoint by @NRamirez01 in #3269
- Fix: dynamic noise addition during training in TSADE examples by @HuangBugWei in #3265
- [typing] Fix the type hints in CGISTEmbedLoss by @tomaarsen in #3272
- typing: fix typing on encode by @stephantul in #3270
- feat: add 'Path' parameter for ModelCard template by @sam-hey in #3253
- Always convert the evaluation metrics to float, also without a 'name' by @tomaarsen in #3277
- Add truncation to CLIP model by @MrLoh in #2969
- [v4] CrossEncoder Training refactor - MultiGPU, loss logging, bf16, etc. by @tomaarsen in #3222
- Bump jinja2 from 3.1.5 to 3.1.6 in /docs by @dependabot in #3282
- Update the core README in preparation for the v4.0 release by @tomaarsen in #3283
- Make minor updates to docs by @tomaarsen in #3285
- Add the .htaccess to git, automatically include it in builds by @tomaarsen in #3286
- Update main description by @tomaarsen in #3287
New Contributors
- @emapco made their first contribution in #3203
- @OsamaS99 made their first contribution in #3220
- @PhorstenkampFuzzy made their first contribution in #3238
- @vshekhawat-hlab made their first contribution in #3249
- @NohTow made their first contribution in #3261
- @HuangBugWei made their first contribution in #3263
- @NRamirez01 made their first contribution in #3269
- @stephantul made their first contribution in #3270
- @sam-hey made their first contribution in #3253
- @MrLoh made their first contribution in #2969
A special shoutout to @milistu for contributing the LambdaLoss & ListNetLoss and @yjoonjang for contributing the ListMLELoss, PListMLELoss, and RankNetLoss. Much appreciated, you really helped improve this release!
Full Changelog: v3.4.1...v4.0.1