Transformers v5 release notes
- Highlights
- Significant API changes: dynamic weight loading, tokenization
- Backwards Incompatible Changes
- Bugfixes and improvements
Highlights
We are excited to announce the initial release of Transformers v5. This is the first major release in five years, and it is a significant one: 800 commits have been pushed to main since the latest minor release. This release removes many long-overdue deprecations, introduces several refactors that significantly simplify our APIs and internals, and comes with a large number of bug fixes.
We give an overview of our focus for this release in the following blogpost. In these release notes, we'll focus directly on the refactors and new APIs coming with v5.
This release is a release candidate (RC). It is not the final v5 release, and we will publish it on PyPI as a pre-release. This means that the current release is purely opt-in: installing transformers without specifying this exact release will install the latest stable version instead (v4.57.3 as of writing).
In order to install this release, please do so with the following:
pip install transformers --pre
For us to deliver the best package possible, it is imperative that we get feedback on how the toolkit is currently working for you. Please try it out, and open an issue in case you're facing anything inconsistent or a bug.
Transformers version 5 is a community endeavor, and this is the last mile. Let's ship this together!
Significant API changes
Note
👀 Nothing is final and things are still actively moving. We have a section dedicated to what is planned for future release candidates but is known not to work in RC0; look for "Disclaimers for the RC0".
We'll be eagerly awaiting your feedback in our GitHub issues!
Tokenization
Just as we moved towards a single backend library for model definition, we want our tokenizers, and the Tokenizer object, to be a lot more intuitive. With v5, tokenizer definition is much simpler; you can now initialize an empty LlamaTokenizer and train it directly on your corpus.
Defining a new tokenizer object should be as simple as this:
from transformers import TokenizersBackend, generate_merges
from tokenizers import pre_tokenizers, Tokenizer
from tokenizers.models import BPE
class Llama5Tokenizer(TokenizersBackend):
    def __init__(self, unk_token="<unk>", bos_token="<s>", eos_token="</s>", vocab=None, merges=None):
        if vocab is None:
            self._vocab = {
                str(unk_token): 0,
                str(bos_token): 1,
                str(eos_token): 2,
            }
        else:
            self._vocab = vocab
        if merges is not None:
            self._merges = merges
        else:
            self._merges = generate_merges(self._vocab)
        self._tokenizer = Tokenizer(
            BPE(vocab=self._vocab, merges=self._merges, fuse_unk=True)
        )
        self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
            replacement="▁", prepend_scheme="first", split=False  # "first" only prepends the space to the first word
        )
        super().__init__(
            tokenizer_object=self._tokenizer,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
        )
Once the tokenizer is defined as above, you can instantiate it with Llama5Tokenizer(). Doing so returns an empty, trainable tokenizer that follows the definition of the authors of Llama5 (it does not exist yet 😉).
The above is the main motivation behind the tokenization refactor: we want tokenizers to behave similarly to models, trained or empty, and with exactly what is defined in their class definition.
Backend Architecture Changes: moving away from the slow/fast tokenizer separation
Up to now, transformers maintained two parallel implementations for many tokenizers:
- "Slow" tokenizers (
tokenization_<model>.py) - Python-based implementations, often using SentencePiece as the backend. - "Fast" tokenizers (
tokenization_<model>_fast.py) - Rust-based implementations using the 🤗 tokenizers library.
In v5, we consolidate to a single tokenizer file per model: tokenization_<model>.py. This file will use the most appropriate backend available:
- TokenizersBackend (preferred): Rust-based tokenizers from the 🤗 tokenizers library. It generally provides optimal performance, and it also offers a lot more features that are commonly adopted across the ecosystem:
  - handling additional tokens
  - a full Python API for setting and updating
  - automatic parallelization
  - automatic offsets
  - customization
  - training
- SentencePieceBackend: for tokenizers requiring the sentencepiece library. It inherits from PythonBackend.
- PythonBackend: a Python implementation of the features provided by tokenizers. It basically allows adding tokens.
- MistralCommonBackend: relies on the mistral-common tokenization library (previously known as the MistralCommonTokenizer).
The AutoTokenizer automatically selects the appropriate backend based on available files and dependencies. This is transparent: you continue to use AutoTokenizer.from_pretrained() as before. It also makes transformers future-proof and modular, so future backends can easily be supported.
Defining a tokenizer outside of the existing backends
We enable users and tokenizer builders to define their own tokenizers from top to bottom. Tokenizers are usually defined using a backend such as tokenizers, sentencepiece or mistral-common, but we offer the possibility to design the tokenizer at a higher level, without relying on those backends.
To do so, you can import the PythonBackend (which was previously known as PreTrainedTokenizer). This class encapsulates all the logic related to added tokens, encoding, and decoding.
If you want something even higher up the stack, then PreTrainedTokenizerBase is what PythonBackend inherits from. It contains the very basic tokenizer API features:
- encode
- decode
- vocab_size
- get_vocab
- convert_tokens_to_ids
- convert_ids_to_tokens
- from_pretrained
- save_pretrained
- among a few others
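As an illustration, here is a minimal sketch of a word-level tokenizer built directly on top of PythonBackend. It assumes PythonBackend keeps the overridable hooks of the former PreTrainedTokenizer (_tokenize, _convert_token_to_id, _convert_id_to_token); adapt the vocabulary handling to your needs:

```python
from transformers import PythonBackend

class WhitespaceTokenizer(PythonBackend):
    """Toy word-level tokenizer defined without a tokenization backend."""

    def __init__(self, vocab=None, unk_token="<unk>", **kwargs):
        self._word_vocab = vocab or {str(unk_token): 0}
        self._ids_to_tokens = {i: t for t, i in self._word_vocab.items()}
        super().__init__(unk_token=unk_token, **kwargs)

    @property
    def vocab_size(self):
        return len(self._word_vocab)

    def get_vocab(self):
        return dict(self._word_vocab)

    def _tokenize(self, text):
        # Whitespace pre-tokenization, kept deliberately simple
        return text.split()

    def _convert_token_to_id(self, token):
        return self._word_vocab.get(token, self._word_vocab[str(self.unk_token)])

    def _convert_id_to_token(self, index):
        return self._ids_to_tokens.get(index, str(self.unk_token))
```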
API Changes
1. Direct tokenizer initialization with vocab and merges
Starting with v5, we enable initializing blank, untrained tokenizers backed by 🤗 tokenizers:
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer()
This tokenizer follows the definition of LlamaTokenizer as written in its class definition. It can then be trained on a corpus, as shown in the tokenizers documentation.
These tokenizers can also be initialized from vocab and merges (if necessary), like the previous "slow" tokenizers:
from transformers import LlamaTokenizer
vocab = {"<unk>": 0, "<s>": 1, "</s>": 2, "hello": 3, "world": 4}
merges = [("h", "e"), ("l", "l"), ("o", " ")]
tokenizer = LlamaTokenizer(vocab=vocab, merges=merges)
This tokenizer behaves as a Llama-like tokenizer, with an updated vocabulary. This allows comparing different tokenizer classes with the same vocab, therefore enabling the comparison of different pre-tokenizers, normalizers, etc.
⚠️ The vocab_file (as in, a path towards a file containing the vocabulary) cannot be used to initialize the LlamaTokenizer, as loading from files is reserved for the from_pretrained method.
2. Simplified decoding API
The batch_decode and decode methods have been unified to reflect behavior of the encode method. Both single and batch decoding now use the same decode method. See an example of the new behavior below:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-small")
inputs = ["hey how are you?", "fine"]
tokenizer.decode(tokenizer.encode(inputs))
This gives:
- 'hey how are you?</s> fine</s>'
+ ['hey how are you?</s>', 'fine</s>']
We expect encode and decode to behave as two sides of the same coin: encode, process, decode should just work.
Note
A common use case is: encode, model.generate, decode. However, generate returns batched sequences (list[list[int]]), which used to be incompatible with decode.
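For instance, a round trip with generate might look like the sketch below (checkpoint name is a placeholder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("hey how are you?", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=10)  # batched output, one row per sequence

# In v5, decode accepts the batched output of generate directly and
# returns one string per sequence, so batch_decode is no longer needed.
print(tokenizer.decode(generated, skip_special_tokens=True))
```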
3. Unified encoding API
The encode_plus method is deprecated in favor of the single __call__ method.
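Migration is a direct swap, for example:

```python
# v4
encoded = tokenizer.encode_plus("hello world", return_tensors="pt")

# v5
encoded = tokenizer("hello world", return_tensors="pt")
```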
4. apply_chat_template returns BatchEncoding
Previously, apply_chat_template returned input_ids for backward compatibility. Starting with v5, it now consistently returns a BatchEncoding dict like other tokenizer methods.
# v5
messages = [
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there!"}
]
# Now returns BatchEncoding with input_ids, attention_mask, etc.
outputs = tokenizer.apply_chat_template(messages, return_tensors="pt")
print(outputs.keys())  # dict_keys(['input_ids', 'attention_mask'])
5. Removed legacy configuration file saving
We simplify the serialization of tokenization attributes:
- special_tokens_map.json: special tokens are now stored in tokenizer_config.json.
- added_tokens.json: added tokens are now stored in tokenizer.json.
- added_tokens_decoder: only stored when there is no tokenizer.json.
When loading older tokenizers, these files are still read for backward compatibility, but new saves use the consolidated format. We're gradually moving towards consolidating attributes to fewer files so that other libraries and implementations may depend on them more reliably.
6. Model-Specific Changes
Several models that had identical tokenizers now import from their base implementation:
- LayoutLM → uses BertTokenizer
- LED → uses BartTokenizer
- Longformer → uses RobertaTokenizer
- LXMert → uses BertTokenizer
- MT5 → uses T5Tokenizer
- MVP → uses BartTokenizer
These modules will eventually be removed altogether.
Removed T5-specific workarounds
The internal _eventually_correct_t5_max_length method has been removed. T5 tokenizers now handle max length consistently with other models.
Testing Changes
A few testing changes specific to tokenizers have been applied:
- Model-specific tokenization test files now focus on integration tests.
- Common tokenization API tests (e.g., add_tokens, encode, decode) are now centralized and automatically applied across all tokenizers. This reduces test duplication and ensures consistent behavior.
For legacy implementations, the original BERT Python tokenizer code (including WhitespaceTokenizer, BasicTokenizer, etc.) is preserved in bert_legacy.py for reference purposes.
7. Deprecated / Modified Features
Special Tokens Structure:
- SpecialTokensMixin: merged into PreTrainedTokenizerBase to simplify the tokenizer architecture.
- special_tokens_map: now only stores named special token attributes (e.g., bos_token, eos_token). Use extra_special_tokens for additional special tokens (formerly additional_special_tokens).
- all_special_tokens: includes both named and extra tokens.
# v4
tokenizer.special_tokens_map # Included 'additional_special_tokens'
# v5
tokenizer.special_tokens_map # Only named tokens
tokenizer.extra_special_tokens  # Additional tokens
- special_tokens_map_extended and all_special_tokens_extended: removed. Access AddedToken objects directly from _special_tokens_map or _extra_special_tokens if needed.
- additional_special_tokens: still accepted for backward compatibility but is automatically converted to extra_special_tokens.
Deprecated Methods:
- sanitize_special_tokens(): already deprecated in v4, removed in v5.
- prepare_seq2seq_batch(): deprecated; use __call__() with the text_target parameter instead.
# v4
model_inputs = tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts, max_length=128)
# v5
model_inputs = tokenizer(src_texts, text_target=tgt_texts, max_length=128, return_tensors="pt")
model_inputs["labels"] = model_inputs.pop("input_ids_target")BatchEncoding.words(): Deprecated; useword_ids()instead.
Removed Methods:
- create_token_type_ids_from_sequences(): removed from the base class. Subclasses that need custom token type ID creation should implement this method directly.
- clean_up_tokenization(): removed from the base class. Now defined at the model class level for models that need it (e.g., PLBart, CLVP, Wav2Vec2).
- prepare_for_model(), build_inputs_with_special_tokens(), truncate_sequences(): moved from tokenization_utils_base.py to tokenization_python.py for PythonBackend tokenizers. TokenizersBackend provides model-ready input via tokenize() and encode(), so these methods are no longer needed in the base class.
- _switch_to_input_mode(), _switch_to_target_mode(), as_target_tokenizer(): removed from the base class. Use __call__() with the text_target parameter instead.
# v4
with tokenizer.as_target_tokenizer():
labels = tokenizer(tgt_texts, ...)
# v5
labels = tokenizer(text_target=tgt_texts, ...)
- parse_response(): removed from the base class.
Disclaimers for the RC0
PEFT + MoE:
Because we are switching away from the naive MoE implementation (nn.ModuleList for experts), we currently have an issue with MoEs that have adapters. For more details see #42491 (comment).
We aim for this to be fixed and released in a following release candidate in the week that follows RC0.
Tensor parallel and Expert parallel + MoE
We are streamlining the MoE support with vLLM; while this is being implemented, tensor parallelism and expert parallelism aren't working as expected.
This is known and actively being worked on.
We aim for this to be fixed and released in a following release candidate in the week that follows RC0.
Custom pretrained models:
For anyone inheriting from a transformers PreTrainedModel, the weights are automatically initialized with the common scheme:
@torch.no_grad()
def _init_weights(self, module):
    """
    Initialize the weights. This is quite general on purpose, in the spirit of what we usually do. For more complex
    initialization schemes, it should be overridden by the derived `PreTrainedModel` class. In case a model adds an explicit
    `nn.Parameter`, this method should also be overridden in order to initialize it correctly.
    """
    if hasattr(self.config, "initializer_range"):
        std = self.config.initializer_range or 0.02
    elif hasattr(self.config, "init_std"):
        std = self.config.init_std
    elif hasattr(self.config, "initializer_factor"):
        std = self.config.initializer_factor
    else:
        # 0.02 is the standard default value across the library
        std = getattr(self.config.get_text_config(), "initializer_range", 0.02)

    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
        if getattr(module, "weight", None) is not None:
            init.normal_(module.weight, mean=0.0, std=std)
        if getattr(module, "bias", None) is not None:
            init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        if getattr(module, "weight", None) is not None:
            init.normal_(module.weight, mean=0.0, std=std)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])
    elif isinstance(module, nn.MultiheadAttention):
        # This uses torch's original init
        module._reset_parameters()
    # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
    # between modelings (because they are prefixed with the model name)
    elif (
        isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
        or "LayerNorm" in module.__class__.__name__
        or "RMSNorm" in module.__class__.__name__
    ):
        # Norms can exist without weights (in which case they are None from torch primitives)
        if hasattr(module, "weight") and module.weight is not None:
            init.ones_(module.weight)
        if hasattr(module, "bias") and module.bias is not None:
            init.zeros_(module.bias)
If you want to avoid that, for now you should just do:
class CustomModel(Qwen3VLForConditionalGeneration):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.action_head = nn.Linear(1024, 7)
        self.positional_embedding = nn.Parameter(torch.randn(16, 1152))
        self.post_init()

    def _init_weights(self, module):
        pass
There is a tracker for that here: #42418.
Library-wide changes with lesser impact
use_auth_token
The use_auth_token argument/parameter is deprecated in favor of token everywhere.
You should be able to search and replace use_auth_token with token and get the same logic.
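For example (repository name is a placeholder):

```python
from transformers import AutoModel

# v4
model = AutoModel.from_pretrained("my-org/private-model", use_auth_token=True)

# v5
model = AutoModel.from_pretrained("my-org/private-model", token=True)
```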
Linked PR: #41666
Attention-related features
We decided to remove some features in v5 as they are currently only supported in a few old models and are no longer integrated in new model additions. It's recommended to stick to v4.x in case you need them. The following features are affected:
- No more head masking, see #41076. This feature allowed turning off certain heads during the attention calculation and only worked with eager attention.
- No more relative positional biases in BERT-like models, see #41170. This feature was introduced to allow relative position scores within attention calculations (similar to T5). However, it is barely used in official models and added a lot of complexity. It also only worked with eager attention.
- No more head pruning, see #41417 by @gante. As the name suggests, it allowed pruning heads within attention layers.
Updates to supported torch APIs
We dropped support for two torch APIs: TorchScript and torch.fx symbolic tracing.
Those APIs were deprecated by the PyTorch team, and we're instead focusing on the supported dynamo and export APIs.
Quantization changes
We clean up the quantization API in transformers, and significantly refactor the weight loading as highlighted
above.
We drop support for two quantization arguments that have been deprecated for some time:
load_in_4bitload_in_8bit
We remove them in favor of the quantization_config argument which is much more complete. As an example, here is how
you would load a 4-bit bitsandbytes model using this argument:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-3B",
device_map="auto",
quantization_config=quantization_config
)
Configuration
- Methods to init a nested config such as from_xxx_config are deleted. Configs can be initialized from the __init__ method in the same way. See #41314.
- It is no longer possible to load a config class from a URL. Configs must be loaded from either a local path or a repo on the Hub. See #42383.
- All parameters configuring a model's rotary embedding are now stored under config.rope_parameters, including rope_theta and rope_type. config.rope_parameters is a simple dictionary in most cases, and can also be a nested dict in special cases (e.g. Gemma3 and ModernBert) with a different RoPE parameterization for each layer type. Trying to get config.rope_theta will throw an attribute error from now on. See #39847 and #42255, and the example after this list.
- Qwen-VL family configurations are in a nested format, and trying to access keys directly will throw an error (e.g. config.vocab_size). Users are expected to access keys from their respective sub-configs (config.text_config.vocab_size).
- Configurations of non-generative models (any model that doesn't call model.generate()) will no longer have a generation_config, and model.config.generation_config will throw an attribute error.
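As a rough illustration of the new access patterns (checkpoint name is a placeholder):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-3B")  # placeholder checkpoint

# v4: theta = config.rope_theta  (now raises an AttributeError)
# v5: RoPE parameters live in a single dict
theta = config.rope_parameters["rope_theta"]

# For nested configs (e.g. the Qwen-VL family), read values from the sub-config:
# config.vocab_size             -> raises an error
# config.text_config.vocab_size -> works
```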
Processing
Tokenization
- Slow tokenizer files (tokenization_<model>.py) are removed in favor of fast tokenizer files; tokenization_<model>_fast.py is renamed to tokenization_<model>.py. As fast tokenizers are 🤗 tokenizers-backed, they include a wider range of features that are maintainable and reliable.
- Other backends (sentencepiece, etc.) are supported with a light layer if loading a fast tokenizer fails.
- Legacy files like special_tokens_map.json and added_tokens.json are removed.
- _eventually_correct_t5_max_length is removed.
- encode_plus --> __call__
- batch_decode --> decode
apply_chat_template used to return naked input_ids rather than a BatchEncoding dict by default.
This was inconvenient: it should return a BatchEncoding dict like tokenizer.__call__(), but we were stuck with
it for backward compatibility. The method now returns a BatchEncoding.
Linked PRs:
Processing classes
- In processing classes, each attribute will be serialized under processor_config.json as a nested dict, instead of serializing attributes in their own config files. Loading will be supported for all old-format processors (#41474).
- XXXFeatureExtractor classes are completely removed in favor of the XXXImageProcessor class for all vision models (#41174).
- Minor change: XXXFastImageProcessorKwargs is removed in favor of XXXImageProcessorKwargs, which is shared between fast and slow processors (#40931).
Modeling
- Some RotaryEmbedding layers will start returning a dict of tuples when the model uses several RoPE configurations (Gemma2, ModernBert). Each value will be a tuple of "cos, sin" per RoPE type.
- The config attribute for the RotaryEmbedding layer is unified and accessed via config.rope_parameters. The rope_theta config attribute might not be accessible anymore for some models, and will instead be in config.rope_parameters['rope_theta']. BC will be supported for a while as much as possible, and in the near future we'll gradually move to the new RoPE format (#39847).
- Vision-language models no longer have shortcut access to their language and vision components from the generative model via model.language_model. It is recommended to access the module with model.model.language_model or model.get_decoder() instead, as shown below. See #42156.
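A small sketch of the new access pattern (model class and checkpoint are placeholders):

```python
from transformers import LlavaForConditionalGeneration  # placeholder VLM class

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")  # placeholder

# v4: language_model = model.language_model  (shortcut removed in v5)
# v5: go through the inner model, or use the generic accessor
language_model = model.model.language_model
decoder = model.get_decoder()
```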
Generate
- Old, deprecated output type aliases were removed (e.g. GreedySearchEncoderDecoderOutput). We now only have 4 output classes built from the following matrix: decoder-only vs encoder-decoder, uses beams vs doesn't use beams (#40998).
- Removed deprecated classes regarding decoding methods that were moved to the Hub due to low usage (constraints and beam scores) (#41223).
- If generate doesn't receive any KV cache argument, the default cache class is now defined by the model (as opposed to always being DynamicCache) (#41505).
- Generation parameters are no longer accessible via the model's config. If generation parameters are serialized in config.json for an old model, they will be loaded back into the model's generation config. Users are expected to access or modify generation parameters only through the generation config, e.g. model.generation_config.do_sample = True (see the snippet below).
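For example (checkpoint is a placeholder):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")  # placeholder

# v4: model.config.do_sample = True  (no longer supported)
# v5: generation parameters live on the generation config only
model.generation_config.do_sample = True
model.generation_config.temperature = 0.7
```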
Trainer
Removing arguments without deprecation cycle in TrainingArguments due to low usage
- mp_parameters -> legacy param that was later on added to the sagemaker trainer
- _n_gpu -> not intended for users to set; we will initialize it correctly instead of putting it in the TrainingArguments
- overwrite_output_dir -> replaced by resume_from_checkpoint; it was only used in example scripts, no impact on Trainer
- logging_dir -> only used for tensorboard, set the TENSORBOARD_LOGGING_DIR env var instead
- jit_mode_eval -> use use_torch_compile instead, as torchscript is not recommended anymore
- tpu_num_cores -> it is actually better to remove it, as it is not recommended to set the number of cores. By default, all TPU cores are used. Set the TPU_NUM_CORES env var instead
- past_index -> it was only used for a very small number of models with special architectures like Transformer-XL, and it was not documented at all how to train those models
- ray_scope -> only a minor arg for the Ray integration. Set the RAY_SCOPE env var instead
- warmup_ratio -> use warmup_steps instead. We combined both args by allowing float values in warmup_steps (see the example after this list)
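For instance, a 10% warmup could be expressed as the sketch below, assuming warmup_steps accepts a float ratio as described above:

```python
from transformers import TrainingArguments

# v4
# args = TrainingArguments(output_dir="out", warmup_ratio=0.1)

# v5: a float value of warmup_steps is interpreted as a ratio of the total training steps
args = TrainingArguments(output_dir="out", warmup_steps=0.1)
```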
Removing deprecated arguments in TrainingArguments
- fsdp_min_num_params and fsdp_transformer_layer_cls_to_wrap -> use fsdp_config
- tpu_metrics_debug -> debug
- push_to_hub_token -> hub_token
- push_to_hub_model_id and push_to_hub_organization -> hub_model_id
- include_inputs_for_metrics -> include_for_metrics
- per_gpu_train_batch_size -> per_device_train_batch_size
- per_gpu_eval_batch_size -> per_device_eval_batch_size
- use_mps_device -> mps will be used by default if detected
- fp16_backend and half_precision_backend -> we will only rely on torch.amp, as everything has been upstreamed to torch
- no_cuda -> use_cpu
- include_tokens_per_second -> include_num_input_tokens_seen
- use_legacy_prediction_loop -> we only use the evaluation_loop function from now on
Removing deprecated arguments in Trainer
- tokenizer in initialization -> processing_class
- model_path in train() -> resume_from_checkpoint
Removed features for Trainer
- SigOpt integration for hp search was removed as the library was archived and the API stopped working
- drop support for sagemaker API <1.10
- bump accelerate minimum version to 1.1.0
New defaults for Trainer
- use_cache in the model config will be set to False. You can still change the cache value through the TrainingArguments use_cache argument if needed; see the example below.
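A short sketch, assuming the TrainingArguments use_cache flag mentioned above:

```python
from transformers import TrainingArguments

# The Trainer now sets the model config's use_cache to False by default during training.
# Re-enable it explicitly if you need it:
args = TrainingArguments(output_dir="out", use_cache=True)
```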
Pipeline
- Image text to text pipelines will no longer accept images as a separate argument along with conversation chats. Image data has to be embedded in the chat's "content" field. See #42359
PushToHubMixin
- Removed deprecated organization and repo_url from PushToHubMixin. You must pass a repo_id instead.
- Removed ignore_metadata_errors from PushToHubMixin. In practice, if we ignore errors while loading the model card, we won't be able to push the card back to the Hub, so it's better to fail early rather than provide the option to fail later.
- push_to_hub no longer accepts **kwargs. All accepted parameters are explicitly documented.
- Arguments of push_to_hub are now keyword-only to avoid confusion. Only repo_id can be positional since it's the main argument (see the example below).
- Removed the use_temp_dir argument from push_to_hub. We now use a temporary directory in all cases.
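For instance (repository id and keyword arguments are illustrative):

```python
# v4
model.push_to_hub("my-model", organization="my-org", use_temp_dir=True)

# v5: repo_id is the only positional argument, everything else is keyword-only
model.push_to_hub("my-org/my-model", private=True, commit_message="v5 upload")
```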
Linked PR: #42391.
CLI
The deprecated transformers-cli ... command has been removed; transformers ... is now the only CLI entry point.
The transformers CLI has been migrated to Typer, making it easier to maintain and adding some nice features out of
the box (improved --help section, autocompletion).
The biggest breaking change is in transformers chat. This command starts a terminal UI to interact with a chat model.
It used to also be able to start a Chat Completion server powered by transformers and chat with it. In this revamped
version, this feature has been removed in favor of transformers serve. The goal of splitting transformers chat
and transformers serve is to define clear boundaries between client and server code. It helps with maintenance
but also makes the commands less bloated. The new signature of transformers chat is:
Usage: transformers chat [OPTIONS] BASE_URL MODEL_ID [GENERATE_FLAGS]...
Chat with a model from the command line.
It works hand in hand with transformers serve, which means that if transformers serve is running on its default endpoint, transformers chat can be launched as follows:
transformers chat HuggingFaceTB/SmolLM3-3B
It can however use any OpenAI API-compatible HTTP endpoint:
transformers chat HuggingFaceTB/SmolLM3-3B https://router.huggingface.co/v1
Linked PRs:
Removal of the run method
The transformers run command (previously transformers-cli run) is an artifact of the past: it was not documented nor tested,
and isn't part of any public documentation. We're removing it for now; please let us know if this is a command you rely on,
in which case we will bring it back with better support.
Linked PR: #42447
Environment variables
- Legacy environment variables like TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, and PYTORCH_PRETRAINED_BERT_CACHE have been removed. Please use HF_HOME instead.
- Constants HUGGINGFACE_CO_EXAMPLES_TELEMETRY, HUGGINGFACE_CO_PREFIX, and HUGGINGFACE_CO_RESOLVE_ENDPOINT have been removed. Please use huggingface_hub.constants.ENDPOINT instead.
Linked PR: #42391.
Requirements update
transformers v5 pins the huggingface_hub version to >=1.0.0. See this migration guide to learn more about this major release. Here are the main aspects to know about:
- The HTTP backend switched from requests to httpx. This change was made to improve performance and to support both synchronous and asynchronous requests the same way. If you are currently catching requests.HTTPError errors in your codebase, you'll need to switch to httpx.HTTPError (see the snippet after this list).
- Related to the above, it is no longer possible to set proxies from your script. To handle proxies, you must set the HTTP_PROXY/HTTPS_PROXY environment variables.
- hf_transfer, and therefore HF_HUB_ENABLE_HF_TRANSFER, has been completely dropped in favor of hf_xet. This should be transparent for most users. Please let us know if you notice any downside!
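For example, error handling around a download moves from requests to httpx (repository and filename are placeholders):

```python
import httpx
from huggingface_hub import hf_hub_download

try:
    path = hf_hub_download("meta-llama/Llama-3.2-3B", "config.json")
except httpx.HTTPError as err:  # was requests.HTTPError with huggingface_hub < 1.0
    print(f"Download failed: {err}")
```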
typer-slim has been added as a required dependency; it is used to implement both the hf and transformers CLIs.
New model additions in v5
CWM
The Code World Model (CWM) model was proposed in CWM: An Open-Weights LLM for Research on Code Generation with World Models by Meta FAIR CodeGen Team. CWM is an LLM for code generation and reasoning about code that has, in particular, been trained to better represent and reason about how code and commands affect the state of a program or system. Specifically, we mid-trained CWM on a large number of observation-action trajectories from Python execution traces and agentic interactions in containerized environments. We post-trained with extensive multi-task RL in verifiable coding, math, and multi-turn software engineering environments.
- Add Code World Model (CWM) by @jacobkahn in #41199
SAM3
SAM3 (Segment Anything Model 3) was introduced in SAM 3: Segment Anything with Concepts.
The SAM3 integration adds four new architectures:
- Sam3
- Sam3Tracker
- Sam3TrackerVideo
- Sam3Video
SAM3 performs Promptable Concept Segmentation (PCS) on images. PCS takes text and/or image exemplars as input (e.g., "yellow school bus"), and predicts instance and semantic masks for every single object matching the concept.
Sam3Tracker and Sam3TrackerVideo perform Promptable Visual Segmentation (PVS) on images. PVS takes interactive visual prompts (points, boxes, masks) or text inputs to segment a specific object instance per prompt. This is the task that SAM 1 and SAM 2 focused on, and SAM 3 improves upon it. Sam3Tracker and Sam3TrackerVideo are updated versions of SAM2 Video that maintain the same API while providing improved performance and capabilities.
SAM3 Video performs Promptable Concept Segmentation (PCS) on videos. PCS takes text as input (e.g., "yellow school bus"), and predicts instance and semantic masks for every single object matching the concept, while preserving object identities across video frames. The model combines a detection module (SAM3) with a tracking module (SAM2-style tracker) to enable robust object tracking across video frames using text prompts.
- Add SAM3 to 🤗 Transformers by @yonigozlan in #42285
LFM2 MoE
LFM2-MoE is a Mixture-of-Experts (MoE) variant of LFM2. The LFM2 family is optimized for on-device inference by combining short‑range, input‑aware gated convolutions with grouped‑query attention (GQA) in a layout tuned to maximize quality under strict speed and memory constraints.
LFM2‑MoE keeps this fast backbone and introduces sparse MoE feed‑forward networks to add representational capacity without significantly increasing the active compute path. The first LFM2-MoE release is LFM2-8B-A1B, with 8.3B total parameters and 1.5B active parameters. The model excels in quality (comparable to 3-4B dense models) and speed (faster than other 1.5B class models).
- [Model] Lfm2Moe by @paulpak58 in #41401
VideoLlama 3
The VideoLLaMA3 model is a major update to VideoLLaMA2 from Alibaba DAMO Academy.
AudioFlamingo 3
Audio Flamingo 3 (AF3) is a fully open large audio–language model designed for robust understanding and reasoning over speech, environmental sounds, and music. AF3 pairs a Whisper-style audio encoder with a causal language model and performs replace-in-place audio–text fusion: the processor aligns post-pool audio frames to a dedicated placeholder token and the model replaces those token slots with projected audio embeddings during the forward pass.
The model checkpoint is available at: nvidia/audio-flamingo-3-hf
Highlights:
- Unified audio encoder across speech, sound, and music.
- Long-audio support via windowing and post-pool alignment (up to 10 minutes maximum). The model processes audio in 30-second windows with a hard limit of 20 windows (10 minutes total). Audio longer than 10 minutes will be truncated.
- Deterministic fusion that preserves sequence length by replacing audio placeholder tokens with audio embeddings.
Nanochat
NanoChat is a compact decoder-only transformer model designed for educational purposes and efficient training. The model features several fundamental architectural choices that are common in modern transformer models, which makes it a good starting point for understanding the principles of modern transformers. NanoChat is a variant of the Llama architecture, with a simplified attention mechanism and normalization layers.
- [MODEL] Nanochat implementation by @burtenshaw in #41634
Bugfixes and improvements
- [JetMoe] Fix jetmoe after #40132 by @ArthurZucker in #41324
- Fixed tiny incorrect import in gemma3 by @Sai-Suraj-27 in #41354
- Rope for Qwen2-5-vl by @zucchini-nlp in #41173
- 🚨 Bump to Python 3.10 and rework how we check 3rd-party libraries existence by @Cyrilvallez in #41268
- Standardize PretrainedConfig to PreTrainedConfig by @Cyrilvallez in #41300
- Fix trainer for py3.9 by @SunMarc in #41359
- Check model inputs - hidden states by @zucchini-nlp in #40994
- [ModularChecker] QOL for the modular checker by @ArthurZucker in #41361
- Fixing a typo for BLT model by @Narsil in #41325
- 🚨 [v5] Remove relative position embeddings (for bert like models) by @vasqu in #41170
- Fix typo in model proposal template by @Ombucha in #41352
- Better typehints for apply_chat_template by @Samoed in #41355
- 🚨 Remove BetterTransformer by @Cyrilvallez in #41367
- [testing] update test_longcat_generation_cpu by @ydshieh in #41368
- Fix flash_attention.py: wrong argument passing for attn_implementation by @TKONIY in #41347
- Use canonical get_size_with_aspect_ratio (with max_size) from transformers.image_transforms to fix #37939 by @sonianuj287 in #41284
- Fixes in check_model_inputs, GPTBigCodeModel and ImageGPTModel by @IlyasMoutawwakil in #40811
- Remove unnecessary list comprehension by @cyyever in #41305
- make some ut cases pass on xpu w/ latest torch by @yao-matrix in #41337
- Remove unused function patameters by @cyyever in #41358
- [CB] Refactors the way we access paged by @ArthurZucker in #41370
- serve: add non-streaming mode to /v1/responses; stream event parity; remove placeholder logprobs by @antznette1 in #41353
- Update from pretrained error when loading by @ArthurZucker in #33380
- [v5] Sync Bert and Bart eager attention by @vasqu in #41248
- fix asr ut failures by @yao-matrix in #41332
- fix resample in asr pipeline by @yhzx233 in #41298
- Correct numerical regression in vision embeddings by @i3hz in #41374
- [kernels] Kernel Config by @MekkCyber in #41232
- [Cache] lfm2 cache: allocate empty kv layers during init by @paulpak58 in #41396
- Fix test for model with dotted name and relative imports by @st81 in #41343
- Prefer raising TypeError exception for invalid type by @Sai-Suraj-27 in #41346
- [v5] Bump accelerate to 1.1.0 by @SunMarc in #41234
- Fix incorrect assignment in update_device_map for GPTQ quantizer by @Sai-Suraj-27 in #41328
- [v5] Delete left traces of feature extractor by @zucchini-nlp in #41321
- Remove deprecation warning by @Cyrilvallez in #41425
- Fix overriding common_kwargs defaults in processor calls by @yonigozlan in #41381
- v5 dev version by @LysandreJik in #41436
- Tiny Cleanup - Removed duplicate class field definition's by @Sai-Suraj-27 in #41293
- 🚨🚨 Remove all traces of legacy cache format by @Cyrilvallez in #41378
- 🚨 [v5] Prune prune_heads by @gante in #41417
- [v5] Bump min version of bitsandbytes to 0.46.1 by @SunMarc in #41283
- Fixing comments in init file by @MekkCyber in #41414
- Use accelerator API to free device memory by @cyyever in #41195
- enable new model uts to xpu and fix some failures on xpu by @yao-matrix in #41386
- [torchao] Add regex support for ModuleFqnToConfig by @jerryzh168 in #41242
- 🤦 CB nit! by @ArthurZucker in #41413
- Remove Python 3.9 classifier by @cyyever in #41410
- [JetMoe] Fix KV head repetition and padding free by @vasqu in #41423
- [testing] Fix JetMoeIntegrationTest by @ydshieh in #41377
- Add Top-H decoding (entropy-bounded truncation) as a LogitsWarper for text generation by @ErfanBaghaei in #40837
- Validate processing kwargs with @strict from huggingface_hub by @zucchini-nlp in #40793
- Update hqq.md by @prathamesh-chavan-22 in #41452
- enable some falcon-mamba uts on xpu by @yao-matrix in #41428
- Fix generate outputs and simplify cache tests by @Cyrilvallez in #41440
- Fix doc by @Cyrilvallez in #41457
- 🚨 [v5] Rename left traces of past_key_value in BERT-like models by @zucchini-nlp in #41448
- Subconfig is a class attribute by @zucchini-nlp in #41308
- [v5] rm utils/tf_ops/ by @gante in #41402
- Update GLM-4.1V MMRope implementation by @zRzRzRzRzRzRzR in #41182
- [kernels] Cleanup deta kernel by @MekkCyber in #41470
- 🚨 [v5] Rendundant code in nested configs by @zucchini-nlp in #41314
- Remove KERAS_NLP_IMPORT_ERROR by @cyyever in #41468
- Fix auto model configuration for encoder of perceptionlm by @fschlatt in #41464
- Fix tests fsdp by @SunMarc in #41422
- Import Callable from collections.abc by @cyyever in #41130
- Pickle - part 2 by @ydshieh in #41476
- Remove infer_device by @cyyever in #41088
- Change RT-Detr docs to reflect fixed 640x640 input size by @konstantinos-p in #41364
- Cleaning hub kernels by @MekkCyber in #41477
- [v5] remove load_in_4bit and load_in_8bit by @SunMarc in #41287
- 🚨 [Attention Masks] Bidirectional masks for encoder and encoder-decoder models by @vasqu in #41265
- [Fix] Fix test file error by @YangKai0616 in #40973
- enhance patched_tearDown to support python 3.11+ by @yao-matrix in #41429
- RT-Detr correct 2d positional embeddings for non-square images by @konstantinos-p in #41380
- Fix bnb fsdp loading for pre-quantized checkpoint by @SunMarc in #41415
- Remove SigOpt by @SunMarc in #41479
- Remove past_index by @SunMarc in #41384
- Remove deprecated args in Trainer for v5 by @SunMarc in #41404
- Update GLM-4.6 doc by @zRzRzRzRzRzRzR in #41471
- report_to default changed to "none" + cleaning deprecated env var by @SunMarc in #41375
- deprecate overwrite_output_dir by @SunMarc in #41323
- [CI] Fix copies on main by @vasqu in #41486
- [Trainer] deprecate ray scope by @SunMarc in #41403
- deprecate jit_mode_eval by @SunMarc in #41376
- Remove local_rank arg from TrainingArguments by @SunMarc in #41382
- Update philosophy by @molbap in #41438
- Remove DISABLE_KERNEL_MAPPING flag by @MekkCyber in #41475
- Streaming should be handled at the request level rather than at the instance level by @LysandreJik in #41444
- fix bnb model loading by @jiqing-feng in #41499
- [kernels] Remove RWKV kernel finally ! by @MekkCyber in #41493
- [kernels] rm yoso kernel by @MekkCyber in #41495
- Try to remove pickle - BloomTokenizerFast by @ydshieh in #41466
- Fixed tiny incorrect imports in glm4v by @Sai-Suraj-27 in #41483
- [Parakeet] unnecessary warning & auto mapping by @eustlb in #41412
- [causallm tester] automate pipeline mappings + bloom tests by @gante in #41318
- Fix some tests by @Cyrilvallez in #41503
- fix gemma3n case failure by @yao-matrix in #41426
- [voxtral] language detection + skipping lang:xx by @eustlb in #41225
- Set truncation to False in Qwen3Omni to avoid default truncation by @BakerBunker in #41473
- [QoL] modular conversion shows LoC saved by @molbap in #41500
- More trainer cleaning by @SunMarc in #41489
- Bump to hfh 1.0.0.rc5 to fix test by @Wauplin in #41508
- Revert local_rank deletion and some cleaning by @SunMarc in #41504
- Fix detectron2 import by @Cyrilvallez in #41510
- add Trainer import to .md in appropriate cell block for training.ipynb transformers_doc by @benkeene in #41484
- Remove outdated flags by @Cyrilvallez in #41512
- remove tpu_num_cores by @SunMarc in #41383
- Allow optuna's catch kwargs passthrough by @nicha-api in #41496
- Fix Latex typesetting in documentation by @cyyever in #41177
- [testing] reduce runtime of HunYuanMoEV1IntegrationTest:test_model_generation by @ydshieh in #41373
- [Qwen3VL] fix: hidden_states in place modification error by @HollowMan6 in #41535
- Add MLlama fast image processor by @yonigozlan in #41391
- Fixed Type-hints in function defintions by @Sai-Suraj-27 in #41525
- [SAM] Fix typing hints by @zucchini-nlp in #41506
- Restore cuda graphs to continuous batching by @remi-or in #41421
- Add AMD developer cloud support by @fan-amd in #41126
- Enable modular files from other libraries by @regisss in #41372
- 🚨 [v5] generate delegates default cache initialization to the model by @gante in #41505
- Fixed typos and formatting by @julian-st in #34215
- Add VideoMAE video processor by @Aki-07 in #41534
- [from_pretrained] Small refactor from_pretrained: move around unrelated stuff by @ArthurZucker in #41445
- Remove references to AutoModelForVision2Seq by @Rocketknight1 in #41513
- [Qwen3VL] fix device mismatch error for FSDP2 training by @HollowMan6 in #41536
- Patch MistralCommonTokenizer by @juliendenize in #41439
- Fix an import error with PreTrainModel by @remi-or in #41571
- [Qwen3VLMoe] Fixed: Expected self.dtype to be equal to src.dtype - routing_weights casting by @danielquintas8 in #41420
- [kernels] rm mra kernels by @MekkCyber in #41507
- delete some tokenizer tests using pickle by @ydshieh in #41514
- Add DINOv3Backbone for ConvNext variant by @merveenoyan in #40651
- Add conditional checks to _check_and_adjust_attn_implementation() by @zheliuyu in #41542
- add rmsnorm kernels support for Intel XPU by @kaixuanliu in #41563
- Revert "add rmsnorm kernels support for Intel XPU" by @MekkCyber in #41579
- [VisionEncoderDecoderModel] Update loss function by @NielsRogge in #40863
- Add iter to DynamicCache by @remi-or in #41569
- Revert some breaking changes bnb by @SunMarc in #41581
- Fix typsetting and content of llm_tutorial_optimization.md by @cyyever in #41172
- Gemma3 fixes by @remi-or in #41572
- Benchmark overhaul by @remi-or in #41408
- Enable non-streaming mode in transformers serve by @LysandreJik in #41446
- [device_map] Accelerate loading by computing device_map much faster by @Cyrilvallez in #41548
- Add logits_to_keep to many older CausalLM models by @philiproeleveld in #41335
- fix some case failures lead by "torch.compile recompiled part of th… by @sywangyi in #41558
- remove ray_scope and check_quantized_param by @SunMarc in #41587
- Update issue template by @SunMarc in #41573
- [Docs] Fix changed references by @vasqu in #41614
- Import expand_device_map instead of redefining it by @Cyrilvallez in #41608
- Fix trainer simple tests by @SunMarc in #41449
- More markdown file fixes by @cyyever in #41599
- torch 2.9 don't ❤️ torchcodec 💔 by @ydshieh in #41610
- Update a dataset reop link by @ydshieh in #41618
- Add fast path for bidirectional mask creation to fix regression by @i3hz in #41586
- enable sdpa enable gqa logic for Ascend NPU by @FightingZhen in #41601
- Fix video processing channel format by @zucchini-nlp in #41603
- [chat template] update when "push_to_hub" by @zucchini-nlp in #39815
- Remove the head masking block in some vision models by @ydshieh in #41620
- Remove deprecated code by @SunMarc in #41616
- Fix quantization base class by @SunMarc in #41613
- [docs] Duplicate entry by @stevhliu in #41591
- Update executorch.md by @jackzhxng in #41582
- Add Backbone API fine-tuning tutorial by @merveenoyan in #41590
- 🚨 [v5] Toggle the serialization format in processors by @zucchini-nlp in #41474
- Add aux loss for GLM-4.5V by @zRzRzRzRzRzRzR in #41564
- Allow passing tp_plan in from_pretrained directly by @Cyrilvallez in #41435
- Fix tokenization test by @Cyrilvallez in #41649
- Remove randomly added script by @Cyrilvallez in #41650
- Add missing dates to docs by @yonigozlan in #41576
- Migrate transformers cli to Typer by @Wauplin in #41487
- Fix FP-Quant quantization fallback CPU dispatch. by @BlackSamorez in #41619
- fix check inputs for text2text pipeline by @jiqing-feng in #41556
- [Executorch] Simplify for encoder models by @vasqu in #41627
- [Ernie 4.5 Moe] Fix Moe and offloading by @vasqu in #41385
- [CI] Build translated docs by @stevhliu in #41632
- Fix fp32_ln for various models by @remi-or in #41605
- Adjust device logging level and add minor fixes by @mario-koddenbrock in #41636
- Fix EncoderDecoder cache by @remi-or in #41612
- Format MarkDown documentation and tiny fixes by @cyyever in #41638
- Fix typos in documentation by @cyyever in #41641
- Fix confusing cls assignment by @cyyever in #41642
- Double router compute? by @molbap in #41653
- [kernels] refactor function kernel calling by @MekkCyber in #41577
- [Fix] Deepseek V3 expert bias routing by @fjosw in #41647
- purge HF_HUB_ENABLE_HF_TRANSFER; promote Xet by @Vaibhavs10 in #41656
- [Masks] Fix mask handling in eager for vision models by @vasqu in #41625
- Use | for Optional and Union typing by @cyyever in #41646
- Switch to CB if cache_implementation == paged by @remi-or in #41655
- Add in-out modalities as class attribute per model by @zucchini-nlp in #41366
- Fix dtype casting with quantization by @Cyrilvallez in #41665
- Fix serving continuous batching by @SunMarc in #41624
- Small changes to benchmarking script by @remi-or in #41662
- Improve package version check by @Cyrilvallez in #41661
- improve utils/check_bad_commit.py by @ydshieh in #41658
- Erroring when KernelConfig is passed without use_kernels = True by @MekkCyber in #41657
- [Trainer] [Breaking change] use_cache default to False by @SunMarc in #41585
- 🌐 [i18n-KO] Translated chat_extras.md to Korean by @Judy-Choi in #39863
- 🌐 [i18n-KO] Translated sam_hq.md to Korean by @HyunZ118 in #41340
- [i18n-KO] Translated big_bird.md to Korean by @ssum21 in #40445
- 🌐 [i18n-KO] Translated code_llama.md to Korean by @Judy-Choi in #40558
- 🌐 [i18n-KO] Translated llama4.md to Korean by @TaskerJang in #40396
- 🌐 [i18n-KO] Translated ko-LFM2.md to Korean by @ssum21 in #41502
- Adding superglue fast image processing by @AlphaOrOmega in #41394
- Fix ckpt in docs by @zucchini-nlp in #41659
- torch 2.9 still don't ❤️ torchcodec 0.8 💔 by @ydshieh in #41686
- Remove deprecated use_auth_token parameter by @Wauplin in #41666
- Remove require_torch_bf16_gpu by @cyyever in #40979
- path validation for security reason by @ydshieh in #41256
- 🚨 Remove torchscript support by @Cyrilvallez in #41688
- Fix MarkDown syntax by @cyyever in #41676
- Use | for Optional and Union typing by @cyyever in #41675
- 🚨 [v5] Refactor RoPE for layer types by @zucchini-nlp in #39847
- Enable faiss-cpu on Windows by @cyyever in #41678
- Fix Pylint warnings by @cyyever in #41644
- 🚨 Remove torch.fx support by @Cyrilvallez in #41683
- Remove skipped tests without parents by @Cyrilvallez in #41691
- Enable FURB rules in ruff by @cyyever in #41395
- Remove upper version bound of pandas by @cyyever in #41677
- [Attn] Allow dynamic causality in SDPA via Kwargs by @vasqu in #41692
- Simplify GQA conditions in sdpa_attention.py by @justinchuby in #41699
- [docs] Manual tp-plan by @stevhliu in #41674
- 🌐 [i18n-KO] Translated gemma3n.md to Korean by @HyunZ118 in #40873
- pin torchcodec on CI docker image by @ydshieh in #41703
- Update run_name docs in TrainingArguments by @tobiasofsn in #41705
- further improve utils/check_bad_commit.py by @ydshieh in #41658
- feat: add benchmark v2 ci with results pushed to dataset by @McPatate in #41672
- Gemma3 conversion script maintenance by @RyanMullins in #41704
- Fix Qwen3-Omni inference when mixing video and image inputs in one batch by @BakerBunker in #41741
- Fix typo in LFM-VL by @zucchini-nlp in #41742
- Revert "Remove upper version bound of pandas" by @ydshieh in #41744
- [doc] remove broken notebooks on AMD Dev Cloud by @pagezyhf in #41743
- Update type hints in tokenization_utils.py to use | syntax by @faizan842 in #41713
- Fix documentation issues by @cyyever in #41726
- Apply RUFF PIE rules by @cyyever in #41727
- Small Fix for imports by @MekkCyber in #41411
- Docs(zh-hans): Refine wording for professionalism in README by @Ri-Nai in #40943
- Add vision contribution guide by @molbap in #41456
- upgrade xpu docker file to torch 2.8 by @yao-matrix in #41551
- [v5] Delete videos from image processing classes by @zucchini-nlp in #41607
- Fixed incorrect model_type for qwen2vl and qwen2.5vl when config is saved and loaded again by @i3hz in #41758
- [kernels] Add version to function mapping by @MekkCyber in #41685
- Reduce warning noise caused by Tensor.new_tensor by @st81 in #41748
- Fix graphormer model compilation with Cython 3.1.4 by @alexmalyshev in #41671
- Update type hints in modeling_rope_utils.py to use | syntax by @faizan842 in #41714
- [v5] Remove deprecated tranformers.onnx by @echarlaix in #41700
- Modernize CLIP modeling code by @molbap in #41546
- Simplify pipeline padding logic by @Rocketknight1 in #41667
- Chat response parsing by @Rocketknight1 in #40894
- Add LightGlue fast image processor by @yonigozlan in #41670
- Fix bark after #41445 by @ydshieh in #41645
- Remove invalid @staticmethod from module-level get_device_and_memory_breakdown by @albertvillanova in #41747
- Fix CUDA index out of bounds for q_idx in VLM token type masking for Gemma3, PaliGemma, and example modular by @albertvillanova in #41757
- fix: Gemma 3 weights conversion vision and multimodal projector paths by @RyanMullins in #41767
- [v5] Delete legacy chat template saving by @zucchini-nlp in #41648
- [quantization] fix compressed_tensors tests by @MekkCyber in #41780
- [quantization] Skip Fp8 tests when hardware capability < 8.9 by @MekkCyber in #41785
- Swap columns and rows of the grid layout in LFM2-VL by @ankke in #41755
- fix type annotation typo in docstring by @johntheprime in #41788
- Fix chat schema tests by @Rocketknight1 in #41793
- Fix attention mask in mamba layers by @zucchini-nlp in #41790
- [quantization] fix torchao tests after 0.14.0 release by @MekkCyber in #41777
- [Onnx docs] Remove some traces by @vasqu in #41791
- flash attn pytest marker by @ydshieh in #41781
- Bump AMD docker by @remi-or in #41792
- make apollo test case pass by @yao-matrix in #41805
- Add a safeguard around a flaky test in gemma2 by @remi-or in #41811
- Fix Qwen3Next dtype API usage by @SrijanUpadhyay in #41735
- [Trainer] remove env vars by @SunMarc in #41697
- Fixed grammar mistakes by @FrogWarlord in #41799
- Fixed some grammar mistakes by @FrogWarlord in #41802
- transformers cli default flag fix by @ArjunPimpale in #41761
- Deprecate warmup_ratio by @SunMarc in #41326
- transformers serve quantization docs + some api fixes for bitsandbytes by @SunMarc in #41253
- [Parakeet] add output_attention_mask by @eustlb in #41694
- unpin torch/torchcodec for CircleCI by @ydshieh in #41839
- extend bitnet cases to xpu, all 8 cases pass by @yao-matrix in #41831
- extend 2 trainer test cases to xpu by @yao-matrix in #41829
- extend 2 blip2 and falcon_h1 test cases to xpu by @yao-matrix in #41825
- further reducing flakiness in utils/check_bad_commit.py by @ydshieh in #41658
- Remove redundant code from Qwen3VLProcessor by @Xqle in #41836
- Fix MXFP4 quantizer to support variable num_local_experts and hidden_size by @marksverdhei in #41795
- Fix Qwen2Audio flash attention mask format for generation by @Abdennacer-Badaoui in #41843
- Fix const parsing for dict inputs in chat schemas by @Rocketknight1 in #41824
- Share embedding modules in BART, not only weights by @githubnemo in #41821
- Fix TypeError: find_adapter_config_file() got an unexpected keyword argument '_adapter_model_path' by @albertvillanova in #41604
- 🚨 [Clip] Fix masking and enable flash attention on all model types by @vasqu in #41750
- CI workflow for Flash Attn by @ydshieh in #41857
- Fix torch.no_grad decorator in VLMS by @yaswanth19 in #41888
- Fix installation cmds in docs by @yaswanth19 in #41887
- revert changes in _is_package_available by @MekkCyber in #41891
- make lfm2_moe integration test pass on XPU by @yao-matrix in #41796
- Fix: avoid duplicate token in maybe_load_adapters by @luaenrique in #41903
- speed up loading checkpoints for zero stage 3 by @ri938 in #41850
- evaluate>=0.4.6 is needed by @stas00 in #41920
- Add 6 huggingface notebooks on AMD dev cloud by @fan-amd in #41883
- Fix invalid examples in QwenVL model docstrings and add Qwen3VL example by @Xqle in #41812
- Allow parse_response to accept token IDs by @Rocketknight1 in #41849
- Fix Florence2 conversion script model_type KeyError by @i3hz in #41866
- Update some workflow files by @ydshieh in #41892
- fix some ut failures on XPU w/ torch 2.9 by @yao-matrix in #41923
- Cache latest pytorch amd image locally on mi325 CI runner cluster by @jitesh-gupta in #41926
- Minor fix in docker image build workflow by @ydshieh in #41949
- fix some ut failures on XPU w/ torch 2.9 by @yao-matrix in #41941
- Fix rope_parameters for gemma3 weights conversion script by @douglas-reid in #41922
- Fix: Gemma3TextConfig rope scaling assignments by @RyanMullins in #41934
- fix prepare_config_and_inputs_for_common bug in llava test by @yao-matrix in #41942
- Fix: prevent .gitignore truncation in run_clm_no_trainer.py by @luaenrique in #41957
- V4.57.1 training ci: Refactor test_tensor_parallel.py by @3outeille in #41918
- [v5] Return a BatchEncoding dict from apply_chat_template by default by @Rocketknight1 in #41626
- make recurrent_gemma and voxtral cases pass on xpu by @yao-matrix in #41958
- Fix typo in image_processing_lfm2_vl_fast by @yonigozlan in #41940
- Run slow v2 by @ydshieh in #41914
- Fix detectron2 installation in docker files by @ydshieh in #41975
- Fix autoawq[kernels] installation in quantization docker file by @ydshieh in #41978
- add support for saving encoder only so any parakeet model can be loaded for inference by @nithinraok in #41969
- Use indices as position_ids in modernebert by @remi-or in #41789
- test tensor parallel: make tests for dense model more robust by @3outeille in #41968
- fix: dict[RopeParameters] to dict[str, RopeParameters] by @RyanMullins in #41963
- docs: add continuous batching page by @McPatate in #41847
- Fix torchcodec version in quantization docker file by @ydshieh in #41988
- [kernels] Add Tests & CI for kernels by @MekkCyber in #41765
- Move the Mi355 to regular docker by @remi-or in #41989
- More data in benchmarking by @remi-or in #41848
- fix (CI): Refactor SSH runners by @glegendre01 in #41991
- fix 3 failed test cases for video_llama_3 model on Intel XPU by @kaixuanliu in #41931
- Integrate colqwen2.5 using colqwen2 modelling code by @sahil-kabir in #40600
- Fixed wrong padding value in OWLv2 by @gjamesgoenawan in #41938
- Fix run slow v2: empty report when there is only one model by @ydshieh in #42002
- [kernels] change import time in KernelConfig by @MekkCyber in #42004
- DOC Fix typo in argument name: pseudoquant by @BenjaminBossan in #41994
- Fix torch+deepspeed docker file by @ydshieh in #41985
- Correct syntax error in trainer.md by @Yacklin in #42001
- Reduce the number of benchmark in the CI by @remi-or in #42008
- Fix continuous batching tests by @Rocketknight1 in #42012
- add back logging_dir by @SunMarc in #42013
- Fix issue with from pretrained and kwargs in image processors by @yonigozlan in #41997
- Fix default image_rows and image_cols initialization in Idefics3 and SmolVLM processors by @MilkClouds in #41871
- Add GLPNImageProcessorFast by @Aravind-11 in #41725
- add fuyu fast image processors by @DeXtAr47-oss in #41817
- [kernels] Fix XPU layernorm kernel by @MekkCyber in #41583
- [v5] Deprecate Text2Text and related pipelines by @Rocketknight1 in #41996
- [FPQuant] MXFP8 and MXFP4 backwards support by @BlackSamorez in #41897
- fix deeepspeed in AMD docker file by @ydshieh in #42025
- CodeQL workflow for security analysis by @paulinebm in #42015
- [tests] Add Context-parallel CI tests by @kashif in #41860
- extend fp_quant cases to xpu by @yao-matrix in #41833
- Change trigger time for AMD CI by @ydshieh in #42034
- Fix the order of methods in processor loading by @zucchini-nlp in #42031
- 🔴 Isolate prefill from generation loops by @manueldeprada in #40652
- update huggingface_hub dependency version by @hanouticelina in #42033
- Remove some custom datasets defined in codebase by @ydshieh in #41511
- Cleanup workflow - part 1 by @ydshieh in #42023
- Fix pr_slow_ci_suggestion.yml after #42023 by @ydshieh in #42049
- Fix AutoImageProcessor.register and documentation in auto processing modules by @MilkClouds in #41864
- Fix Qwen3-Omni RoPE by @zucchini-nlp in #41778
- Avoid explicit checkout in workflow by @ydshieh in #42057
- Annoying typo in attention error message by @manueldeprada in #42037
- Be careful at explicit checkout actions by @ydshieh in #42060
- Fix another Argument list too long in pr_slow_ci_suggestion.yml by @ydshieh in #42061
- Fix KeyError in GPT-OSS weight conversion script by @Aznix07 in #42007
- Fix KeyError in _is_package_available for packages with dotted names by @yashwantbezawada in #42050
- Revert back to use GitHub context by @ydshieh in #42066
- Fix missing arg in check_docstring by @yonigozlan in #42054
- [deepspeed tests fixes] by @stas00 in #41925
- Fix logic in setting self.fsdp when it is False by @roychan in #41974
- fix tensor device placement issue of 2 UT cases by @yao-matrix in #41921
- add workflow to check permissions and advise a set of permissions req… by @paulinebm in #42071
- Fix security issue 5 by @paulinebm in #42072
- Fix inconsistency of commit sha during the workflow run by @ydshieh in #42074
- QwenVL: add skipped keys in `setattr` as well by @zucchini-nlp in #41808
- permissions worflows fix by @paulinebm in #42080
- 4.1V Model and GLM-4.5V Model Conversion Code Updates by @zRzRzRzRzRzRzR in #41784
- feat(ci): add continuous batching to benchmarks by @McPatate in #41916
- Fix modular docstring for Mixtral by @diegoakel in #42041
- Fix Auto classes to support dynamically registered processors by @MilkClouds in #41865
- Reinstate self.scaling in Gemma3nTextAttention by @RyanMullins in #41751
- [v5] 🚨Refactor subprocessors handling in processors by @yonigozlan in #41633
- add xpu support in test_modeling_janus.py::JanusIntegrationTest::test… by @sywangyi in #41986
- Revert "permissions worflows fix" by @ydshieh in #42110
- Fix return metadata checking logic by @Xqle in #42108
- Correctly handle unbatched audio inputs in Gemma3nAudioFeatureExtractor by @kho in #42076
- [Bugfix] fix qwen3vl expand generation with video by @JJJYmmm in #42089
- Fix base model prefix in VLMs by @zucchini-nlp in #42059
- fix continuous batching issues, extend ut cases to xpu by @yao-matrix in #41830
- 📝 docs(smolvlm): fix variable name in batch inference example by @gorkachea in #42123
- fix qwen2vl/qwen3vl video processor temporal padding when num_frames%temporal_patch_size!=1 by @yaogang2060 in #42083
- [`Attn Masks`] Non-vmap default for attention masks by @vasqu in #41852
- Fix GPT-2 Flash Attention 2 generation with left-padding by @Abdennacer-Badaoui in #41966
- Fix model name test for compressed tensors by @SunMarc in #42128
- Fix MaskFormer/Mask2Former fast image processors by @yonigozlan in #41393
- Remove unused functions in `image_transforms.py` by @yaswanth19 in #42044
- update deps table by @ArthurZucker in #42120
- fix: improve video processing fps assignment logic by @Xqle in #42009
- Fix T5Gemma module structure by @Cyrilvallez in #42145
- DataCollatorForLanguageModeling warning error fixed by @mjaliz in #42144
- Bugfix/remove emojis from print by @7amim in #42091
- Avoid mutating user-provided arguments in preprocessing utils by @LeonardoEmili in #42126
- Enforce check_auto_docstring by @yonigozlan in #41635
- Add dinov3 autobackbone by @vijayabhaskar-ev in #41276
- Fix logic error in `prepare_inputs_for_generation` cache slicing condition by @albertvillanova in #41764
- 🚨 Fix gradient checkpointing for several models and improve test robustness by @githubnemo in #41818
- [`T5Gemma`] Fix cross attention cache by @vasqu in #41890
- T5 migration to new masking interface by @Aravind-11 in #41804
- fix: improve visibility of ValueError root causes in model config loading by @scottzh8 in #41972
- add xpu to valid hardware for torch.compile by @sywangyi in #42079
- extend test_beam_search_early_stop_heuristic case to other device by @sywangyi in #42078
- fix failure of tests/models/shieldgemma2/test_modeling_shieldgemma2.p… by @sywangyi in #42022
- Fixes Flash Attention implementation for models by @i3hz in #42149
- fix test failure of speculative_generation on xpu by @sywangyi in #42052
- add rmsnorm kernels support for npu by @zheliuyu in #42106
- update torchao doc by @jiqing-feng in #42139
- feat(kernels): add opt-out flag to disable kernels hub usage through the lib by @mfuntowicz in #41990
- handle inputs from Siglip/Siglip2 non-automapped encoder layers by @molbap in #41930
- Add slow to some examples tests by @SunMarc in #42164
- fix(ci): unexpected keyword argument `streaming` by @McPatate in #42102
- pin `pytest<9` for now by @ydshieh in #42162
- Docs/i18n updates by @lilin-1 in #42006
- Fix in-place modification of user-input in SAM2 embed boxes by @xenova in #42173
- [`Pop2Piano`] Fix cache usage by @vasqu in #42170
- Fix helper fn for new processor config format by @zucchini-nlp in #42085
- Remove unnecessary slicing in sdpa_attention_forward by @justinchuby in #41900
- [`PEFT`] Fix prefix tuning by @vasqu in #41696
- [typo] fix mrope-interleave annotation to avoid ambiguity by @JJJYmmm in #42177
- Update transformers to support `FqnToConfig` by @jcaip in #41894
- [`PEFT`] Fix the general test for prefix tuning by @vasqu in #42185
- [TP] Fix parameter detection issue and some invalid TP-plans by @Cyrilvallez in #42129
- Refactor weight loading by @ArthurZucker in #41580
- 🚨 Delete deprecations with end-cycle in v4.xx and v5.0 by @zucchini-nlp in #41681
- Add AutoTokenizer mapping for mistral3 and ministral by @patrickvonplaten in #42198
- Fix checkpoint loading with DeepSpeed ZeRO3 by @tohtana in #42201
- [`Pop2Piano`] Fix tied weights by @vasqu in #42193
- New docker from AMD by @remi-or in #42208
- Add cross links for model contribution by @zucchini-nlp in #42207
- Stop inheriting tests! by @Rocketknight1 in #42192
- Refactor check_auto_docstring using AST by @yonigozlan in #41432
- [`BLT`] Fix cache usage by @vasqu in #42188
- Update `test_dynamic_cache_exportability_multiple_run` (failing on torch 2.10 nightly) by @ydshieh in #42212
- Much more efficient and clear weight initialization and tie weights by @Cyrilvallez in #42191
- GLM-V update with new processor by @zRzRzRzRzRzRzR in #42122
- Fix initialization guard for pytest by @Cyrilvallez in #42234
- Fix TP plans for MoE models by @Cyrilvallez in #42236
- Add prefix sharing to continuous batching by @remi-or in #42094
- Loading optimization by @Cyrilvallez in #42239
- calls `AttentionMaskConverter._unmask_unattended` for xpu device before by @kaixuanliu in #42230
- FIX Broken PEFT adapter loading by @BenjaminBossan in #42187
- Fix processor test for glm by @molbap in #42233
- Fix UnboundLocalError in RT-DETR loss computation by @yashwantbezawada in #42224
- Stop inheriting tests (again) by @Rocketknight1 in #42247
- [loading] Fix device when source and target are different by @Cyrilvallez in #42246
- Reduce timing on CircleCI - part 1 (Use @slow for IntegrationTests) by @ydshieh in #42206
- 🚨 Delete generation params from model config by @zucchini-nlp in #41695 (see the sketch after this list)
- Allow VLMs to have a correct `base_model` by @zucchini-nlp in #41589
- Make tests run in less time by reducing `batch_size` by @ydshieh in #42213
- Revert "Make tests run in less time by reducing `batch_size`" by @ydshieh in #42258
- Cleanup reference to TFBertTokenizer and TFGPT2Tokenizer by @Rocketknight1 in #42182
- delete already deprecated models by @ydshieh in #42235
- Fix bnb for the weights refactor by @SunMarc in #42043
- Fix looping in torch guard decorator by @Cyrilvallez in #42260
- 🚨 Generalize `get_decoder()` for multimodal and delete redundant code 🔪 by @zucchini-nlp in #42156
- Audio Flamingo3 - fix attention masking by @zucchini-nlp in #42278
- Add support for torch device objects in device validator by @yonigozlan in #42267
- Remove doc files of other langs for deleted models by @ydshieh in #42276
- [testing] fix `cwm` by @ydshieh in #42261
- fix a typo: pbd -> pdb by @jaeminoh in #42268
- Enable glm46v UTs on XPU by @YangKai0616 in #42274
- [testing] fix some cases in xpu by @sywangyi in #42273
- Remove random flag by @Cyrilvallez in #42282
- Fix accelerate integration by @Cyrilvallez in #42264
- Fix validation checks order in benchmark_v2 by @Abdennacer-Badaoui in #42280
- Update torchcodec to match torchaudio version by @remi-or in #42288
- Use `torch.get_autocast_dtype` instead of `torch.get_autocast_gpu_dtype` by @qgallouedec in #42055 (illustrated after this list)
- perf: Optimization for Min-p sampling implementation by @casinca in #42248
- Fix device_map computation part 2 by @Cyrilvallez in #42290
- Fixed the docstring for `WhisperFeatureExtractor` by @TopCoder2K in #42286
- avoiding conditional indexing in positionalencoding to avoid possibil… by @ppadjinTT in #42090
- ENH: Add support for LoRA hotswapping by @BenjaminBossan in #41297
- Fix Break change of AWQ FusedModules due to Attention Refactor by @fanqiNO1 in #41909
- Remove error string test that was failing by @Rocketknight1 in #42301
- Properly protect the is_compiling checks by @Cyrilvallez in #42304
- Remove outdated methods in modeling_utils.py by @Cyrilvallez in #42302
- Fix Mac mps dataloader_num_workers > 1 causes RuntimeError: share_filename: only available on CPU by @AmitMY in #38819
- Fix the init_weights for the MoE models by @Cyrilvallez in #42306
- Update link to generation strategies documentation by @omkar-334 in #42252
- Update conversion mapping to separate renaming from converting by @ArthurZucker in #42254
- fix(granitemoe*): Only create block_sparse_moe if num_local_experts > 0 by @gabe-l-hart in #42036
- [SAM3 Video] Add support for multi prompts by @yonigozlan in #42293
- Add Pix2Struct fast image processor by @yonigozlan in #42020
- Fix post processing methods in keypoints matching models by @yonigozlan in #42018
- fix tests/models/xcodec/test_modeling_xcodec.py::XcodecIntegrationTest by @sywangyi in #42272
- [loading] Fix device detection by @Cyrilvallez in #42323
- Fix typo from side_dict to size_dict by @nihui in #42319
- HF Trainer: ALST/Ulysses sequence parallelism integration via HF Accelerate by @stas00 in #41832
- Fix gpt2 modeling tests by @Abdennacer-Badaoui in #42321
- [loading] Use fewer threads by default for much better performances by @Cyrilvallez in #42324
- Allow LayoutLMV3Processor to accept rescale_factor by @Rocketknight1 in #42305
- Correctly create tied key mapping in post_init, and dynamic tie weight by @Cyrilvallez in #42270
- [`CI`] Skip `EfficientLoFTR` test by @vasqu in #42327
- [XPU] Add flash_attn2 support for XPU by @YangKai0616 in #41956
- [`Attn Masks`] Lift bidirectional mask restriction on eager by @vasqu in #42325
- fix bug when gemma3n model run on multiple device by @kaixuanliu in #42303
- Fix ChineseCLIPModel.get_text_features by @JiangJQ2000 in #42351
- Gemma3 hybrid fix by @remi-or in #42287
- fix(benchmarks): correct sdpa_backend inconsistency and attn_implementation for continuous batching by @engmohamedsalah in #42339
- Auto convert tekken.json by @ArthurZucker in #42299
- [loading] Re-add and improve disk offloading support by @Cyrilvallez in #42242
- Fix typo - indentation in JSON dump example by @anthropikos in #42332
- Fix tied weight for Bart (for BC) by @Cyrilvallez in #42355
- Fix reference to yelp dataset by @JuanFKurucz in #42349
- Fix documentation reference to pytorch max memory allocated by @JuanFKurucz in #42350
- Fix reference to imagenet 1k dataset by @JuanFKurucz in #42348
- Fix typos by @omahs in #42354
- Protect `torch.distributed` imports by @Cyrilvallez in #42361
- Expand npu device for KernelConfig by @zheliuyu in #42358
- Replace Optional and Union typing with | in some source files by @cyyever in #42294 (example after this list)
- Fix code examples to load gpt 1 openai community model by @JuanFKurucz in #42347
- fix tekken pattern matching by @ArthurZucker in #42363
- Fixed-wrong-ZeRO3-json-snippet-found-in-deepspeed-markdown-file by @Yacklin in #42346
- Make benchmarking lighter: clean-up result files and remove non-needed arguments by @remi-or in #42357
- Add image processor fast vitpose by @yonigozlan in #42021
- Small tp fix by @ArthurZucker in #42366
- Remove test inheritance for EfficientLoftr, rename KeypointMatchingOutput to model specific name by @yonigozlan in #42365
- Tiny doc fix by @molbap in #42296
- Fix TimesFM patch normalization instability by @AnMakc in #42099
- [core] Fix torchao by @MekkCyber in #42289
- Fix tp by @ArthurZucker in #42368
- [`Attn Masks`] Add skip option for non-packed sequences by @vasqu in #42367
- 📚 docs(granite-speech): add comprehensive usage examples by @gorkachea in #42125
- Xcodec fix by @eustlb in #42095
- Replace Optional and Union typing with | in some source files by @cyyever in #42372
- [`Mistral Tokenizers`] Fix tokenizer detection by @vasqu in #42389
- misc don't recreate it by @ArthurZucker in #42394
- [SAM3] Fix precompute vision_embeds or text_embeds for inference by @yonigozlan in #42407
- 🚨 Image-text pipeline expects correctly formatted chat by @zucchini-nlp in #42359
- Many small fixes for the CI by @remi-or in #42364
- [core] fix mxfp4 by @MekkCyber in #42382
- fixed json syntax error for zero2 configuration file found in deepspeed.md by @Yacklin in #42406
- GLM4V - delete duplicate config attribute by @zucchini-nlp in #42416
- 🚨 Remove generic output_attentions warning by @Aravind-11 in #42334
- Bart config doesn't need generation parameters by @zucchini-nlp in #42337
- Simplify and standardize processor tests by @yonigozlan in #41773
- Clean bnb integration using weight converter by @SunMarc in #42426
- Any to any pipeline and auto-mapping by @zucchini-nlp in #40884
- Fix processor usage + add chat_template support to TTS pipeline, and shift common chat template logic to base class. by @ebezzam in #42326
- [fp8] fix scales param name by @MekkCyber in #42434
- Fix an edge case for `get_encoder()` by @zucchini-nlp in #42295
- Disable loss rounding in training stats log by @AnMakc in #42104
- Benchmark simplification by @remi-or in #42408
- Future annotations break FastAPI by @LysandreJik in #42450
- [cleanup] Don't use Repository in create_dummy_models.py script by @Wauplin in #42380
- [cleanup] Remove deprecated load config from file by @Wauplin in #42383
- [`FA`] Cleanup loading logic by @vasqu in #41427
- tiny fix for deepseekocr support [vllm] by @molbap in #42423
- fix: Restore explicit .keys() calls for TensorDict compatibility by @pankajbaid567 in #42373
- Transformers serve -> list all generative models from the cache by @LysandreJik in #42146
- 🚨 [v5][PEFT] Bump min version requirement of PEFT to 0.18.0 by @BenjaminBossan in #41889
- [cleanup] Offline mode and cache dir from `huggingface_hub` constants + cleanup in `PushToHubMixin` by @Wauplin in #42391
- Correctly return finish reason length when finished by @LysandreJik in #42157
- FIX: Minimal fix for loading PEFT weights by @BenjaminBossan in #42387
- Let's break Qwen-VL 🚨 by @zucchini-nlp in #42420
- [`CI`] Add to run slow by @vasqu in #42459
- Fix the "test_offline" test by @LysandreJik in #42458
- `transformers chat` launched without base_url has a direct tie to localhost:8000 by @LysandreJik in #42463
- update with more recent tts models by @Deep-unlearning in #42328
- rm slow tokenizers by @itazap in #40936
- [loading/saving] Reverse all loading operations when saving by @Cyrilvallez in #42396
- Fix T5 tests: use generation_config for generation parameters by @Abdennacer-Badaoui in #42419
- remove reference to TF models from docs by @zucchini-nlp in #42443
- [Trainer] use output.loss when using liger-kernel by @kashif in #42444
- replace source_keys and target_keys by @SunMarc in #42471
- Update migration guide - generation config by @zucchini-nlp in #42470
- 🚨 Move `rotary_partial_emb` to RopeParams and delete unnecessary code 🔪 by @zucchini-nlp in #42255
- Fix doc builds by @Rocketknight1 in #42478
- extend CwmIntegrationTest to xpu by @sywangyi in #42314
- add require_deterministic_for_xpu to make the case pass in xpu by @sywangyi in #42439
- Skip failing irrelevant test for ColQwen2 by @Rocketknight1 in #42480
- [quantization] make torchao tests slow by @MekkCyber in #42482
- Fix gpt2 tokenizer `add_prefix_space` default value by @SunMarc in #42481
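To illustrate the "Delete generation params from model config" entry above (#41695) and the related migration-guide update (#42470): generation defaults no longer live on the model config. The snippet below is a minimal sketch of the resulting pattern, not code taken from the PRs; the gpt2 checkpoint and the chosen sampling parameters are just examples.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Generation defaults belong on a GenerationConfig, not on model.config.
model.generation_config = GenerationConfig(
    max_new_tokens=64,
    do_sample=True,
    temperature=0.7,
)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs)  # picks up model.generation_config
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```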
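For the `torch.get_autocast_dtype` entry above (#42055): the older helper is CUDA-specific, while its replacement takes the device type explicitly. A minimal sketch, assuming PyTorch 2.4+ where the newer API is available:

```python
import torch

# Deprecated, CUDA-only query of the autocast dtype.
old_dtype = torch.get_autocast_gpu_dtype()

# Device-agnostic replacement: the device type is passed explicitly,
# so the same call also works for "cpu", "xpu", etc.
new_dtype = torch.get_autocast_dtype("cuda")

print(old_dtype, new_dtype)  # both report the CUDA autocast dtype
```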
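Several entries (e.g. #42294 and #42372) switch type hints to the PEP 604 `|` syntax. A small before/after illustration on a hypothetical signature, not taken from the codebase:

```python
from typing import Optional, Union

# Before: Optional/Union annotations.
def resize(size: Optional[int] = None, scale: Union[int, float] = 1.0) -> Optional[int]:
    return size

# After: the equivalent `|` syntax now used across the source files.
def resize_new(size: int | None = None, scale: int | float = 1.0) -> int | None:
    return size
```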
Significant community contributions
The following contributors have made significant changes to the library over the last release:
- @ArthurZucker
- `JetMoe` Fix jetmoe after #40132 (#41324)
- [`ModularChecker`] QOL for the modular checker (#41361)
- [`CB`] Refactors the way we access paged (#41370)
- Update from pretrained error when loading (#33380)
- 🤦 CB nit! (#41413)
- [`from_pretrained`] Small refactor `from_pretrained`: move around unrelated stuff (#41445)
- update deps table (#42120)
- Refactor weight loading (#41580)
- Update conversion mapping to separate renaming from converting (#42254)
- Auto convert tekken.json (#42299)
- fix tekken pattern matching (#42363)
- Small tp fix (#42366)
- Fix tp (#42368)
- misc don't recreate it (#42394)
- @vasqu
- 🚨 [`v5`] Remove relative position embeddings (for bert like models) (#41170)
- [`v5`] Sync Bert and Bart eager attention (#41248)
- [`JetMoe`] Fix KV head repetition and padding free (#41423)
- 🚨 [`Attention Masks`] Bidirectional masks for encoder and encoder-decoder models (#41265)
- [`CI`] Fix copies on main (#41486)
- [`Docs`] Fix changed references (#41614)
- [`Executorch`] Simplify for encoder models (#41627)
- [`Ernie 4.5 Moe`] Fix Moe and offloading (#41385)
- [`Masks`] Fix mask handling in eager for vision models (#41625)
- [`Attn`] Allow dynamic causality in SDPA via Kwargs (#41692)
- [`Onnx docs`] Remove some traces (#41791)
- 🚨 [`Clip`] Fix masking and enable flash attention on all model types (#41750)
- [`Attn Masks`] Non-vmap default for attention masks (#41852)
- [`T5Gemma`] Fix cross attention cache (#41890)
- [`Pop2Piano`] Fix cache usage (#42170)
- [`PEFT`] Fix prefix tuning (#41696)
- [`PEFT`] Fix the general test for prefix tuning (#42185)
- [`Pop2Piano`] Fix tied weights (#42193)
- [`BLT`] Fix cache usage (#42188)
- [`CI`] Skip `EfficientLoFTR` test (#42327)
- [`Attn Masks`] Lift bidirectional mask restriction on eager (#42325)
- [`Attn Masks`] Add skip option for non-packed sequences (#42367)
- [`Mistral Tokenizers`] Fix tokenizer detection (#42389)
- [`FA`] Cleanup loading logic (#41427)
- [`CI`] Add to run slow (#42459)
- @ydshieh
- [testing] update `test_longcat_generation_cpu` (#41368)
- [testing] Fix `JetMoeIntegrationTest` (#41377)
- Pickle - part 2 (#41476)
- Try to remove `pickle` - `BloomTokenizerFast` (#41466)
- [testing] reduce runtime of `HunYuanMoEV1IntegrationTest:test_model_generation` (#41373)
- delete some tokenizer tests using pickle (#41514)
- torch 2.9 don't ❤️ torchcodec 💔 (#41610)
- Update a dataset reop link (#41618)
- Remove the head masking block in some vision models (#41620)
- improve `utils/check_bad_commit.py` (#41658)
- torch 2.9 still don't ❤️ torchcodec 0.8 💔 (#41686)
- path validation for security reason (#41256)
- pin torchcodec on CI docker image (#41703)
- further improve `utils/check_bad_commit.py` (#41658) (#41690)
- Revert "Remove upper version bound of pandas" (#41744)
- Fix bark after #41445 (#41645)
- flash attn pytest marker (#41781)
- unpin torch/torchcodec for CircleCI (#41839)
- further reducing flakiness in `utils/check_bad_commit.py` (#41658) (#41815)
- CI workflow for Flash Attn (#41857)
- Update some workflow files (#41892)
- Minor fix in docker image build workflow (#41949)
- Run slow v2 (#41914)
- Fix `detectron2` installation in docker files (#41975)
- Fix `autoawq[kernels]` installation in quantization docker file (#41978)
- Fix `torchcodec` version in quantization docker file (#41988)
- Fix `run slow v2`: empty report when there is only one model (#42002)
- Fix `torch+deepspeed` docker file (#41985)
- fix `deeepspeed` in AMD docker file (#42025)
- Change trigger time for AMD CI (#42034)
- Remove some custom datasets defined in codebase (#41511)
- Cleanup workflow - part 1 (#42023)
- Fix `pr_slow_ci_suggestion.yml` after #42023 (#42049)
- Avoid explicit checkout in workflow (#42057)
- Be careful at explicit checkout actions (#42060)
- Fix another `Argument list too long` in `pr_slow_ci_suggestion.yml` (#42061)
- Revert back to use GitHub context (#42066)
- Fix inconsistency of commit sha during the workflow run (#42074)
- Revert "permissions worflows fix" (#42110)
- pin `pytest<9` for now (#42162)
- Update `test_dynamic_cache_exportability_multiple_run` (failing on torch 2.10 nightly) (#42212)
- Reduce timing on CircleCI - part 1 (Use @slow for IntegrationTests) (#42206)
- Make tests run in less time by reducing `batch_size` (#42213)
- Revert "Make tests run in less time by reducing `batch_size`" (#42258)
- delete already deprecated models (#42235)
- Remove doc files of other langs for deleted models (#42276)
- [testing] fix `cwm` (#42261)
- @cyyever
- Remove unnecessary list comprehension (#41305)
- Remove unused function patameters (#41358)
- Use accelerator API to free device memory (#41195)
- Remove Python 3.9 classifier (#41410)
- Remove KERAS_NLP_IMPORT_ERROR (#41468)
- Import Callable from collections.abc (#41130)
- Remove infer_device (#41088)
- Fix Latex typesetting in documentation (#41177)
- Fix typsetting and content of llm_tutorial_optimization.md (#41172)
- More markdown file fixes (#41599)
- Format MarkDown documentation and tiny fixes (#41638)
- Fix typos in documentation (#41641)
- Fix confusing cls assignment (#41642)
- Use | for Optional and Union typing (#41646)
- Remove require_torch_bf16_gpu (#40979)
- Fix MarkDown syntax (#41676)
- Use | for Optional and Union typing (#41675)
- Enable faiss-cpu on Windows (#41678)
- Fix Pylint warnings (#41644)
- Enable FURB rules in ruff (#41395)
- Remove upper version bound of pandas (#41677)
- Fix documentation issues (#41726)
- Apply RUFF PIE rules (#41727)
- Replace Optional and Union typing with | in some source files (#42294)
- Replace Optional and Union typing with | in some source files (#42372)
- @yao-matrix
- make some ut cases pass on xpu w/ latest torch (#41337)
- fix asr ut failures (#41332)
- enable new model uts to xpu and fix some failures on xpu (#41386)
- enable some falcon-mamba uts on xpu (#41428)
- enhance patched_tearDown to support python 3.11+ (#41429)
- fix gemma3n case failure (#41426)
- upgrade xpu docker file to torch 2.8 (#41551)
- make apollo test case pass (#41805)
- extend bitnet cases to xpu, all 8 cases pass (#41831)
- extend 2 trainer test cases to xpu (#41829)
- extend 2 blip2 and falcon_h1 test cases to xpu (#41825)
- make lfm2_moe integration test pass on XPU (#41796)
- fix some ut failures on XPU w/ torch 2.9 (#41923)
- fix some ut failures on XPU w/ torch 2.9 (#41941)
- fix prepare_config_and_inputs_for_common bug in llava test (#41942)
- make recurrent_gemma and voxtral cases pass on xpu (#41958)
- extend fp_quant cases to xpu (#41833)
- fix tensor device placement issue of 2 UT cases (#41921)
- fix continuous batching issues, extend ut cases to xpu (#41830)
- @MekkCyber
- [kernels] Kernel Config (#41232)
- Fixing comments in init file (#41414)
- [kernels] Cleanup deta kernel (#41470)
- Cleaning hub kernels (#41477)
- Remove DISABLE_KERNEL_MAPPING flag (#41475)
- [kernels] Remove RWKV kernel finally ! (#41493)
- [kernels] rm yoso kernel (#41495)
- [kernels] rm mra kernels (#41507)
- Revert "add rmsnorm kernels support for Intel XPU" (#41579)
- [kernels] refactor function kernel calling (#41577)
- Erroring when KernelConfig is passed without use_kernels = True (#41657)
- Small Fix for imports (#41411)
- [kernels] Add version to function mapping (#41685)
- [quantization] fix compressed_tensors tests (#41780)
- [quantization] Skip Fp8 tests when hardware capability < 8.9 (#41785)
- [quantization] fix torchao tests after 0.14.0 release (#41777)
- revert changes in _is_package_available (#41891)
- [kernels] Add Tests & CI for kernels (#41765)
- [kernels] change import time in KernelConfig (#42004)
- [kernels] Fix XPU layernorm kernel (#41583)
- [core] Fix torchao (#42289)
- [core] fix mxfp4 (#42382)
- [fp8] fix scales param name (#42434)
- [quantization] make torchao tests slow (#42482)
- @paulpak58
- @gante
- @zRzRzRzRzRzRzR
- @jacobkahn
- Add Code World Model (CWM) (#41199)
- @molbap
- Update philosophy (#41438)
- [QoL] modular conversion shows LoC saved (#41500)
- Double router compute? (#41653)
- Add vision contribution guide (#41456)
- Modernize CLIP modeling code (#41546)
- handle inputs from Siglip/Siglip2 non-automapped encoder layers (#41930)
- Fix processor test for glm (#42233)
- Tiny doc fix (#42296)
- tiny fix for deepseekocr support [vllm] (#42423)
- @Wauplin
- Bump to hfh 1.0.0.rc5 to fix test (#41508)
- Migrate transformers cli to Typer (#41487)
- Remove deprecated `use_auth_token` parameter (#41666)
- added more breaking changes
- [cleanup] Don't use Repository in create_dummy_models.py script (#42380)
- [cleanup] Remove deprecated load config from file (#42383)
- [cleanup] Offline mode and cache dir from `huggingface_hub` constants + cleanup in `PushToHubMixin` (#42391)
- @remi-or
- Restore cuda graphs to continuous batching (#41421)
- Fix an import error with PreTrainModel (#41571)
- Add iter to DynamicCache (#41569)
- Gemma3 fixes (#41572)
- Benchmark overhaul (#41408)
- Fix fp32_ln for various models (#41605)
- Fix EncoderDecoder cache (#41612)
- Switch to CB if cache_implementation == paged (#41655)
- Small changes to benchmarking script (#41662)
- Bump AMD docker (#41792)
- Add a safeguard around a flaky test in gemma2 (#41811)
- Use indices as position_ids in modernebert (#41789)
- Move the Mi355 to regular docker (#41989)
- More data in benchmarking (#41848)
- Reduce the number of benchmark in the CI (#42008)
- New docker from AMD (#42208)
- Add prefix sharing to continuous batching (#42094)
- Update torchcodec to match torchaudio version (#42288)
- Gemma3 hybrid fix (#42287)
- Make benchmarking lighter: clean-up result files and remove non-needed arguments (#42357)
- Many small fixes for the CI (#42364)
- Benchmark simplification (#42408)
- @lkhl
- [model] Add VideoLLaMA3 implementation (#40499)
- @philiproeleveld
- Add `logits_to_keep` to many older CausalLM models (#41335)
- @AlphaOrOmega
- Adding superglue fast image processing (#41394)
- @echarlaix
- [v5] Remove deprecated tranformers.onnx (#41700)
- @Aravind-11
- @DeXtAr47-oss
- add fuyu fast image processors (#41817)
- @lashahub
- [models] Add AudioFlamingo3 integration (#40290)
- @lilin-1
- Docs/i18n updates (#42006)
- @burtenshaw
- [MODEL] Nanochat implementation (#41634)
- @itazap
- rm slow tokenizers (#40936)