This minor version brings several improvements to contrastive learning: MultipleNegativesRankingLoss now supports alternative InfoNCE formulations (symmetric, GTE-style) and optional hardness weighting for harder negatives. Two new losses are introduced, GlobalOrthogonalRegularizationLoss for embedding space regularization and CachedSpladeLoss for memory-efficient SPLADE training. The release also adds a faster hashed batch sampler, fixes GroupByLabelBatchSampler for triplet losses, and ensures full compatibility with the latest Transformers v5 versions.
Install this version with:
# Training + Inference
pip install sentence-transformers[train]==5.3.0
# Inference only, use one of:
pip install sentence-transformers==5.3.0
pip install sentence-transformers[onnx-gpu]==5.3.0
pip install sentence-transformers[onnx]==5.3.0
pip install sentence-transformers[openvino]==5.3.0

Updated MultipleNegativesRankingLoss (a.k.a. InfoNCE)
MultipleNegativesRankingLoss received two major upgrades: support for alternative InfoNCE formulations from the literature, and optional hardness weighting to up-weight harder negatives.
Support other InfoNCE variants (#3607)
MultipleNegativesRankingLoss now supports several well-known contrastive loss variants from the literature through new directions and partition_mode parameters. Previously, this loss only supported the standard forward direction (query → doc). You can now configure which similarity interactions are included in the loss:
"query_to_doc"(default): For each query, its matched document should score higher than all other documents."doc_to_query": The symmetric reverse — for each document, its matched query should score higher than all other queries."query_to_query": For each query, all other queries should score lower than its matched document."doc_to_doc": For each document, all other documents should score lower than its matched query.
The partition_mode controls how scores are normalized: "joint" computes a single softmax over all directions, while "per_direction" computes a separate softmax per direction and averages the losses.
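As a rough illustration of partition_mode="per_direction" (a plain-Python sketch with toy numbers, not the library's implementation), the symmetric setting averages a separate InfoNCE cross-entropy over the rows (query → doc) and columns (doc → query) of the similarity matrix, whereas "joint" would instead pool all directions' logits into a single softmax:

```python
import math

def info_nce_term(sim_row, pos_idx, scale=20.0):
    # Cross-entropy InfoNCE term: the positive at pos_idx against every entry in the row.
    logits = [scale * s for s in sim_row]
    log_denom = math.log(sum(math.exp(l) for l in logits))
    return log_denom - logits[pos_idx]

# Toy cosine similarities for a batch of 2: sim[i][j] = cos(query_i, doc_j);
# the diagonal holds the positive pairs.
sim = [[0.9, 0.1], [0.2, 0.8]]

# query_to_doc direction: one softmax per row (each query against all documents).
q2d = sum(info_nce_term(sim[i], i) for i in range(2)) / 2

# doc_to_query direction: one softmax per column (each document against all queries).
d2q = sum(info_nce_term([sim[i][j] for i in range(2)], j) for j in range(2)) / 2

# partition_mode="per_direction": a separate softmax per direction, losses averaged.
per_direction_loss = (q2d + d2q) / 2
```

With well-separated toy similarities like these, both per-direction terms are small but strictly positive, since the softmax denominator always exceeds the positive's own exponentiated logit.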
These combine to reproduce several loss formulations from the literature:
Standard InfoNCE (default, unchanged behavior):
loss = MultipleNegativesRankingLoss(model)
# equivalent to directions=("query_to_doc",), partition_mode="joint"

Symmetric InfoNCE (Günther et al. 2024), which adds the reverse direction so both queries and documents are trained to find their match:
loss = MultipleNegativesRankingLoss(
model,
directions=("query_to_doc", "doc_to_query"),
partition_mode="per_direction",
)

GTE improved contrastive loss (Li et al. 2023), which adds same-type negatives (query <-> query, doc <-> doc) for a stronger training signal, especially useful with pairs-only data:
loss = MultipleNegativesRankingLoss(
model,
directions=("query_to_doc", "query_to_query", "doc_to_query", "doc_to_doc"),
partition_mode="joint",
)

Hardness-weighted contrastive learning (#3667)
Adds optional hardness weighting to MultipleNegativesRankingLoss and CachedMultipleNegativesRankingLoss, inspired by Lan et al. 2025 (LLaVE). This up-weights harder negatives in the softmax by adding hardness_strength * stop_grad(cos_sim) to selected negative logits. The feature is off by default (hardness_mode=None), so existing behavior is unchanged.
The hardness_mode parameter controls which negatives receive the penalty:
"in_batch_negatives": Penalizes in-batch negatives only (positives and hard negatives from other samples). Works with all data formats including pairs-only."hard_negatives": Penalizes explicit hard negatives only (columns beyond the first two). Only active when hard negatives are provided."all_negatives": Penalizes both in-batch and hard negatives, leaving only the positive unpenalized.
from sentence_transformers.losses import MultipleNegativesRankingLoss
loss = MultipleNegativesRankingLoss(
model,
hardness_mode="in_batch_negatives",
hardness_strength=9.0,
)

New loss: GlobalOrthogonalRegularizationLoss (#3654)
Introduces GlobalOrthogonalRegularizationLoss (Zhang et al. 2017), a regularization loss that encourages embeddings to be well-distributed in the embedding space. It penalizes two things: (1) high mean pairwise similarity across unrelated embeddings, and (2) high second moment of similarities (which indicates clustering). This loss is meant to be combined with a primary contrastive loss like MultipleNegativesRankingLoss. By wrapping both losses in a single module, you can share embeddings and only require one forward pass:
import torch
from datasets import Dataset
from torch import Tensor
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import GlobalOrthogonalRegularizationLoss, MultipleNegativesRankingLoss
from sentence_transformers.util import cos_sim
model = SentenceTransformer("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
"anchor": ["It's nice weather outside today.", "He drove to work."],
"positive": ["It's so sunny.", "He took the car to the office."],
})
class InfoNCEGORLoss(torch.nn.Module):
    def __init__(self, model: SentenceTransformer, similarity_fct=cos_sim, scale=20.0) -> None:
        super().__init__()
        self.model = model
        self.info_nce_loss = MultipleNegativesRankingLoss(model, similarity_fct=similarity_fct, scale=scale)
        self.gor_loss = GlobalOrthogonalRegularizationLoss(model, similarity_fct=similarity_fct)

    def forward(self, sentence_features: list[dict[str, Tensor]], labels: Tensor | None = None) -> Tensor:
        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        info_nce_loss: dict[str, Tensor] = {
            "info_nce": self.info_nce_loss.compute_loss_from_embeddings(embeddings, labels)
        }
        gor_loss: dict[str, Tensor] = self.gor_loss.compute_loss_from_embeddings(embeddings, labels)
        return {**info_nce_loss, **gor_loss}
loss = InfoNCEGORLoss(model)
trainer = SentenceTransformerTrainer(
model=model,
train_dataset=train_dataset,
loss=loss,
)
trainer.train()

New loss: CachedSpladeLoss for memory-efficient SPLADE training (#3670)
Introduces CachedSpladeLoss, a gradient-cached version of SpladeLoss that enables training SPLADE models with larger batch sizes without additional GPU memory. It applies the GradCache technique at the SpladeLoss wrapper level, so both the base loss and regularizers receive pre-computed embeddings — no changes to existing base losses or regularizers are needed.
from datasets import Dataset
from sentence_transformers.sparse_encoder import SparseEncoder, SparseEncoderTrainer
from sentence_transformers.sparse_encoder.losses import CachedSpladeLoss, SparseMultipleNegativesRankingLoss
model = SparseEncoder("distilbert/distilbert-base-uncased")
train_dataset = Dataset.from_dict({
"anchor": ["It's nice weather outside today.", "He drove to work."],
"positive": ["It's so sunny.", "He took the car to the office."],
})
loss = CachedSpladeLoss(
model=model,
loss=SparseMultipleNegativesRankingLoss(model),
document_regularizer_weight=3e-5,
query_regularizer_weight=5e-5,
mini_batch_size=32,
)
trainer = SparseEncoderTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()

Faster NoDuplicatesBatchSampler with hashing (#3611)
Adds a NO_DUPLICATES_HASHED batch sampler option, which uses the existing NoDuplicatesBatchSampler with precompute_hashes=True. This pre-computes xxhash 64-bit values for each sample, providing significant speedups for large batch sizes at a small memory cost. Requires the xxhash library.
from sentence_transformers import SentenceTransformerTrainingArguments
args = SentenceTransformerTrainingArguments(
batch_sampler="NO_DUPLICATES_HASHED" # Pre-computes hashes for faster duplicate checking
)

GroupByLabelBatchSampler improvements for triplet losses (#3668)
Fixes a critical issue where GroupByLabelBatchSampler produced ~99% single-class batches, causing zero gradients with triplet losses. The sampler now uses round-robin interleaving where each label emits 2 samples per round, with the label visit order reshuffled every round. This guarantees every batch contains multiple distinct labels, each with at least 2 samples.
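The interleaving scheme can be sketched in plain Python (a toy illustration of the description above, not the sampler's actual code):

```python
import random

def round_robin_order(label_to_indices, seed=12):
    """Each round, visit the remaining labels in a freshly shuffled order and
    emit 2 samples per label, so consecutive samples alternate between labels."""
    rng = random.Random(seed)
    queues = {label: list(indices) for label, indices in label_to_indices.items()}
    order = []
    # Continue while at least two labels can still contribute a pair.
    while sum(len(q) >= 2 for q in queues.values()) >= 2:
        labels = [label for label, q in queues.items() if len(q) >= 2]
        rng.shuffle(labels)  # reshuffle the label visit order every round
        for label in labels:
            order.extend(queues[label][:2])  # each label emits 2 samples per round
            del queues[label][:2]
    return order

# Three labels with 2 samples each: one round interleaves all 6 samples.
samples = {"a": [0, 1], "b": [2, 3], "c": [4, 5]}
order = round_robin_order(samples)
batch = order[:4]  # a batch of 4 spans two distinct labels, 2 samples each
```

Because each label contributes exactly one pair per round, any batch of 4 drawn from this ordering contains two distinct labels with 2 samples each, which is the property triplet losses need.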
Transformers v5 compatibility
This release includes full compatibility updates for Transformers v5:
- Compatibility with transformers 5.0.0rc01 and later versions (#3597, #3615)
- Support for T5Gemma and T5Gemma2 models (#3644)
- Transformers v5.2 compatibility for the trainer's `_nested_gather` method (#3664)
- Support for both `warmup_steps` and `warmup_ratio` until Transformers v4 support is dropped (#3645)
- Updated CI to test against full Transformers v5 (#3615)
Minor Features
- Add triplets/n-tuple support to AnglE by @tomaarsen in #3609
- Replace `requests` dependency with optional `httpx` dependency by @tomaarsen in #3618
- Specify numpy manually in dependencies by @tomaarsen in #3608
- Support excluding prompt tokens with pooling with left-padding tokenizer by @tomaarsen in #3598
Bug Fixes
- Fix InformationRetrievalEvaluator prediction export when output_path does not exist by @ignasgr in #3659
- Add padding for odd embedding dimensions in tensors (sparse encoders) by @jadermcs in #3623
- Fix IndexError in CrossEncoder `MultipleNegativesRankingLoss` when `num_negatives=None` by @fuutot in #3636
- Fix valid negatives selection in CrossEncoder `MultipleNegativesRankingLoss` by @fuutot in #3641
- Mention TSDAE incompatibility with transformers v5 by @tomaarsen in #3619
- Fix model card generation with set_transform with new column names by @tomaarsen in #3680
Performance Improvements
- Speed up NoDuplicatesBatchSampler iteration using NumPy arrays and linked lists by @hotchpotch in #3658
Training Script Migrations (v2 to v3)
- Migrate training_batch_hard_trec.py by @omkar-334 in #3624
- Migrate train_ct_from_file.py by @omkar-334 in #3625
- Migrate train_stsb_ct.py by @omkar-334 in #3626
- Migrate train_stsb_ct_improved by @omkar-334 in #3627
- Migrate train_askubuntu_ct-improved.py by @omkar-334 in #3628
- Migrate 2_programming_train_bi-encoder.py by @omkar-334 in #3629
- Migrate train_askubuntu_simcse.py by @omkar-334 in #3630
- Migrate train_simcse_from_file.py by @omkar-334 in #3631
- Migrate training_multi-task-learning.py by @harshitsharma496 in #3632
- Replace http_get with load_dataset - wiki1m_for_simcse and STSbenchmark by @omkar-334 in #3635
- Replace http_get with load_dataset - askubuntu and all-nli by @omkar-334 in #3638
- Update ContrastiveTensionLoss and ContrastiveTensionLossInBatchNegatives by @omkar-334 in #3639
- Migrate train_ct-improved_from_file.py from v2 to v3 by @omkar-334 in #3646
- Migrate train_askubuntu_ct.py from v2 to v3 by @omkar-334 in #3647
- Migrate train_stsb_simcse.py from v2 to v3 by @omkar-334 in #3648
- Update docstring for DenoisingAutoEncoderLoss.py by @omkar-334 in #3652
- Replace model.fit in test files by @omkar-334 in #3653
- Fix: pass batch_size args to CE evaluators by @omkar-334 in #3643
Documentation
- Add sample CLIP training script with datasets & MLFlow by @aardoiz in #3595
- Add Unsloth to Docs by @shimmyshimmer in #3613
- Add tips for adjusting batch size to improve processing speed by @tomaarsen in #3672
- CE trainer: Removed IterableDataset from train and eval dataset type hints by @tomaarsen in #3676
All Changes
- chore: Increment development version for 'main' by @tomaarsen in #3594
- Introduce compatibility with transformers 5.0.0rc01 by @tomaarsen in #3597
- docs: fix typo in custom models: reemain -> remain by @tomaarsen in #3596
- [`feat`] Support excluding prompt tokens with pooling with left-padding tokenizer by @tomaarsen in #3598
- Upgrade GitHub Actions for Node 24 compatibility by @salmanmkc in #3600
- Specify numpy manually in dependencies, as it's directly used/imported by @tomaarsen in #3608
- [`tests`] Relax the CI branches by @tomaarsen in #3610
- [`compat`] Expand test suite to full transformers v5 by @tomaarsen in #3615
- [`deps`] Replace requests dependency with optional httpx dependency by @tomaarsen in #3618
- Mention TSDAE incompatibility with transformers v5, update TSDAE snippet by @tomaarsen in #3619
- docs: add sample clip training script with datasets & mlfow implement… by @aardoiz in #3595
- [Fix] Add padding for odd embedding dimensions in tensors (sparse encoders) by @jadermcs in #3623
- [`feat`] Add triplets/n-tuple support to AnglE by @tomaarsen in #3609
- Replace `http_get` with `load_dataset` - `wiki1m_for_simcse` and `STSbenchmark` by @omkar-334 in #3635
- [`tests`] Use 120s HF Hub timeout for tests by @tomaarsen in #3637
- Fix IndexError in `MultipleNegativesRankingLoss` when `num_negatives=None` by @fuutot in #3636
- migrate training_multi-task-learning.py v2 to v3 by @harshitsharma496 in #3632
- [feat] Add NO_DUPLICATES_HASHED: optional hashing for NoDuplicatesBatchSampler by @hotchpotch in #3611
- Fix: select valid negatives as in-batch negatives by @fuutot in #3641
- migrate `2_programming_train_bi-encoder.py` from v2 to v3 by @omkar-334 in #3629
- migrate `train_simcse_from_file.py` from v2 to v3 by @omkar-334 in #3631
- Update `ContrastiveTensionLoss` and `ContrastiveTensionLossInBatchNegatives` by @omkar-334 in #3639
- Replace `http_get` with `load_dataset` - `askubuntu` and `all-nli` by @omkar-334 in #3638
- fix: pass `batch_size` args to CE evaluators by @omkar-334 in #3643
- replace `trec` dataset and migrate `training_batch_hard_trec.py` from v2 to v3 by @omkar-334 in #3624
- Add Unsloth to Docs by @shimmyshimmer in #3613
- migrate `train_stsb_ct.py` from v2 to v3 by @omkar-334 in #3626
- migrate `train_ct_from_file.py` from v2 to v3 by @omkar-334 in #3625
- migrate `train_askubuntu_ct-improved.py` from v2 to v3 by @omkar-334 in #3628
- migrate `train_stsb_ct_improved` from v2 to v3 by @omkar-334 in #3627
- migrate `train_askubuntu_simcse.py` from v2 to v3 by @omkar-334 in #3630
- migrate `train_stsb_simcse.py` from v2 to v3 by @omkar-334 in #3648
- migrate `train_askubuntu_ct.py` from v2 to v3 by @omkar-334 in #3647
- migrate `train_ct-improved_from_file.py` from v2 to v3 by @omkar-334 in #3646
- update docstring for `DenoisingAutoEncoderLoss.py` by @omkar-334 in #3652
- Replace `model.fit` in test files by @omkar-334 in #3653
- [`feat`] Add support for T5Gemma and T5Gemma2 models by @tomaarsen in #3644
- [feat] Refactor MultipleNegativesRankingLoss to support improved contrastive loss from GTE paper by @hotchpotch in #3607
- [`compat`] Allow for both warmup_steps and warmup_ratio until transformers v4 support is dropped by @tomaarsen in #3645
- [`feat`] Introduce GlobalOrthogonalRegularizationLoss by @tomaarsen in #3654
- fix: correct typo 'seperated' to 'separated' by @thecaptain789 in #3657
- fix: typos by @omkar-334 in #3660
- [`compat`] Introduce Transformers v5.2 compatibility: trainer _nested_gather moved by @tomaarsen in #3664
- Fix InformationRetrievalEvaluator prediction export when output_path does not exist by @ignasgr in #3659
- Fix typo in training_batch_hard_trec.py by @tomaarsen in #3669
- [`perf`] Speed up NoDuplicatesBatchSampler iteration (NO_DUPLICATES and NO_DUPLICATES_HASHED) by @hotchpotch in #3658
- [`fix`] `GroupByLabelBatchSampler` to guarantee multi-class batches for triplet losses by @MrLoh in #3668
- [`feat`] Introduce `CachedSpladeLoss` for memory-efficient SPLADE training by @yjoonjang in #3670
- [`docs`] Add tips for adjusting batch size to improve processing speed by @tomaarsen in #3672
- [`docs`] CE trainer: Removed IterableDataset from train and eval dataset type hints by @tomaarsen in #3676
- [`loss`] Disallow query_to_query/doc_to_doc with partition_mode="per_direction" due to negative loss by @tomaarsen in #3677
- [`feat`] Add hardness-weighted contrastive learning to losses by @yjoonjang in #3667
- [`fix`] Fix model card generation with set_transform with new column names by @tomaarsen in #3680
- [`tests`] Add slow reproduction tests for most common models by @tomaarsen in #3681
New Contributors
- @salmanmkc made their first contribution in #3600
- @aardoiz made their first contribution in #3595
- @jadermcs made their first contribution in #3623
- @fuutot made their first contribution in #3636
- @harshitsharma496 made their first contribution in #3632
- @hotchpotch made their first contribution in #3611
- @shimmyshimmer made their first contribution in #3613
- @thecaptain789 made their first contribution in #3657
- @yjoonjang made their first contribution in #3670
A big thanks to my repeat contributors; a lot of this release originated from your contributions. Much appreciated!
Full Changelog: v5.2.3...v5.3.0