Summary
This release brings major improvements to enable efficient distributed training, quantization, and CPU support in Burn.
To achieve true multi-GPU parallelism, we had to rethink several core systems: we implemented multi-stream execution to keep all GPUs busy, optimized device transfers to avoid unnecessary synchronization, and redesigned our locking strategies to eliminate bottlenecks in autotuning, fusion, and autodiff. We also introduced burn-collective for gradient synchronization and refactored our training loop to support different distributed training strategies.
Additionally, we added comprehensive quantization support, allowing models to use significantly less memory while maintaining performance through fused dequantization and optimized quantized operations.
Finally, we introduced a new CPU backend powered by MLIR and LLVM, bringing the same JIT compilation, autotuning, and fusion capabilities from our GPU backends to CPU execution.
As with previous releases, this version includes various bug fixes, further optimizations, and enhanced documentation. ONNX support has also been expanded with additional operators and bug fixes for broader operator coverage.
For more details, check out the release post on our website.
Changelog
Breaking
We've introduced several breaking API changes with this release. The affected interfaces are detailed in the sections below.
Learning Strategy
We refactored the Learner to better support distributed training strategies. Instead of registering a list of devices, you now specify a learning strategy.
let learner = LearnerBuilder::new(artifact_dir)
.metric_train_numeric(AccuracyMetric::new())
.metric_valid_numeric(AccuracyMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(CompactRecorder::new())
- .devices(vec![device.clone()])
+ .learning_strategy(LearningStrategy::SingleDevice(device.clone()))
.num_epochs(config.num_epochs)
.summary()
.build(
config.model.init::<B>(&device),
config.optimizer.init(),
config.learning_rate,
);
Learner Training Result
The Learner previously lacked an evaluation loop. We extended its return type so that the full training state is returned in a TrainingResult, which includes the trained model and a metrics renderer.
- let model_trained = learner.fit(dataloader_train, dataloader_valid);
+ let result = learner.fit(dataloader_train, dataloader_valid);
- model_trained
+ result
+ .model
.save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
.expect("Trained model should be saved successfully");This enables the renderer to be reused by the new evaluator so that training and evaluation metrics appear together in the TUI dashboard:
let mut renderer = result.renderer;
let evaluator = EvaluatorBuilder::new(artifact_dir)
.renderer(renderer)
.metrics((AccuracyMetric::new(), LossMetric::new()))
.build(result.model.clone());
evaluator.eval(name, dataloader_test);
Interface Changes
Config
The Config trait now requires Debug:
- #[derive(Config)]
+ #[derive(Config, Debug)]
pub struct TrainingConfig {
// ...
}
BatchNorm
BatchNorm no longer requires the spatial dimension generic:
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
conv: nn::conv::Conv2d<B>,
- norm: BatchNorm<B, 2>,
+ norm: BatchNorm<B>,
pool: Option<MaxPool2d>,
activation: nn::Relu,
}
Backend::seed
Seeding is now device-specific:
- B::seed(seed);
+ B::seed(&device, seed);
Tensor
For consistency with other methods like unsqueeze() / unsqueeze_dim(dim), squeeze(dim) was renamed:
- tensor.squeeze(dim)
+ tensor.squeeze_dim(dim)
We've also added a tensor.squeeze() method which squeezes all singleton dimensions.
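For example (a minimal sketch, assuming a generic B: Backend and a device in scope; as with other rank-changing tensor ops, the output rank is assumed to be supplied via turbofish):
let x = Tensor::<B, 4>::ones([1, 3, 1, 5], &device);
// Remove a single dimension (previously squeeze(dim)).
let y = x.clone().squeeze_dim::<3>(0); // shape [3, 1, 5]
// Remove all remaining singleton dimensions with the new squeeze().
let z = x.squeeze::<2>(); // shape [3, 5]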
Finally, we removed the tensor ^ T syntax, which was clunky.
- use burn::tensor::T;
- tensor ^ T
+ tensor.t()
tensor.t() is also a simple alias for tensor.transpose().
Module & Tensor
- Fix unsqueeze rank check (#3429) @laggui
- Feat/quant block (#3442) @laggui
- Kill `tensor^T` magic transpose marker in favor of `tensor.t()`. (#3452) @crutcher
- ADD GLU activation function (#3444) @bn-c
- Add quantization params precision (#3453) @laggui
- Improve select_assign check (#3483) @laggui
- Add grid_sample function (#3495 #3523 #3522) @Cielbird
- save_tensor_as_image utility (#3520) @Cielbird
- Add affine_grid_2d (#3526) @Cielbird
- ADD missing Debug derive for embedding (#3547) @bn-c
- Dot Product Op (#3537) @kikefdezl
- Lift .full()/.full_like() into base Tensor - support Tensor<B, D, Bool>::full()/full_like(). (#3562) @crutcher
- Make `Distribution::Default` the `Default::default()`. (#3582) @crutcher
- Implement int matmul (#3575) @wingertge
- Feat/quant formats (#3613) @laggui
- Switch Tensor::swap_dims/permute to AsIndex dim support. (#3619) @crutcher
- Tensor::flatten() => AsIndex dims support. (#3620) @crutcher
- Remove D param from `BatchNorm<B, D>`. (#3625) @crutcher
- nn.activation; Activation (#3603 #3693) @crutcher
- Add q4 q2 quantization (#3617) @laggui
- Introduce `NormLayer` abstraction for unified normalization layers. (#3630) @crutcher
- Add dtype to trait creation ops (#3670) @laggui
- Make Config require Debug (#3689) @crutcher
- Add NormalizationConfig::with_num_features() and related (#3688) @crutcher
- Module quantization w/ tests (#3637) @nathanielsimard
- Add NumPy-like take operation with multi-dimensional index support (#3681) @antimora
- Added trace and diag with batch support for linalg crate (#3703) @niklund
- Add step support to tensor `slice` operations (#3748) @antimora
- Tensor::unfold(dim, size, step) (#3751 #3782 #3783) @crutcher
- Slice assign with steps (#3776) @antimora
- Add `bool_xor` operation for boolean tensors (#3785) @crutcher
- [Breaking] Make squeeze/squeeze_dim consistent with other APIs (#3790) @laggui
- Add cross product (#3743) @SinanGncgl
- Enable stepped slicing for slice_fill and complete slice API cleanup (#3784) @antimora
- Tensor::rank() (#3797) @crutcher
- AsIndex dim handling for Numeric ops (#3795) @crutcher
- Add outer and outer_batch ops in linalg (#3786) @huy209vn
- Tensor::_dims() (#3811) @crutcher
- Add `tensor.cumsum(dim)` first implementation (#3806) @antimora
- slice_fill() should pick a compatible dtype (#3826) @crutcher
- Default LU decomposition implementation (#3816) @DimitriTimoz
- Add `tensor.square` and fast-path int-power exponents. (#3847) @crutcher
- Add cumulative operations: cumprod, cummin, and cummax (#3819) @antimora
- Add Tensor::sum_dims_squeeze(dims) (#3817) @crutcher
- Allow linear to use quantized matmul (#3913) @wingertge
Datasets & Training
- Pre-Shuffle Multithread DataLoaders on Shuffle (#3390) @crutcher
- PixelDepth + Copy (#3419) @crutcher
- Add Dice-Sorenson Coefficient Metric (#3407) @MathijsdeBoer
- Add SelectionDataset, refactor ShuffledDataset, and add transform tests. (#3406) @crutcher
- Evenly distribute complete chunks/batches across partial dataset splits (#3476) @laggui
- Distributed Data Parallel (#3456) @Cielbird
- Use tensor ops for clip_by_norm (#3485) @laggui
- `SamplerDataset` distribution fix; constructors and builder. (#3490) @crutcher
- Unify transform usage of RngOptions. (#3577) @crutcher
- Fix bugs with ddp learning (#3581) @Cielbird
- Add support for CIFAR-10 and CIFAR-100 datasets (#3579) @buttfa
- Add with_interrupter for LearnerBuilder (#3611) @amfaber
- Improved Burn Train (#3614 #3935) @nathanielsimard @laggui
- Add 'TextFolderDataset' struct and `AgNewsDataset` (#3698) @buttfa
- Add PerplexityMetric for language model evaluation (#3707) @TheDarkchip
- Adding CER/WER metrics (#3418) @yazanmashal03
- Fix/autodiff/multi threads (#3793) @nathanielsimard
- Add `cautious_weight_decay` to AdamW optimizer. (#3869) @crutcher
- Fix evaluator dataloader device (#3893) @laggui
Backends
- Migrate to new cubecl multi tensor handle changes (#3136) @wingertge
- More memory control with scoped static memory management (#3410) @nathanielsimard
- Feat/fusion quant (#3454) @nathanielsimard
- Expose client utilities (#3559) @allenqm
- New CPU backend based on MLIR (#3411) @marcantoinem
- feat: ndarray dynamic tensor types and int tensor cast (#3647) @wingertge
- Implement optimized bool_select for primary backends (#3710) @TheDarkchip
- Add backend level is_nan / is_inf implementations (#3809) @laggui
- Feat/persistent memory (#3842) @nathanielsimard
- feat: add backend implementations for `Trunc` op (#3860) @mooori
Bug Fixes
- Fix ndarray interpolate coord precision at boundaries (#3481) @laggui
- Fix ndarray conv2d groups channels (#3415) @laggui
- Fix candle mask broadcasting (#3489) @laggui
- Update cubecl: fix wgpu vec to scalar cast (#3496) @Cielbird
- Fix/conv2d groups backward (#3521) @laggui
- Fix/conv3d backward groups (#3533) @laggui
- [Fix] Add some missing handling for flex32 (#3551) @wingertge
- Fix backward scatter dim (#3555) @laggui
- fix: Use correct datatype when filling boolean tensors (#3593) @wingertge
- fix: Ensure output layout is the same for non-inplace SIMD ops in ndarray (#3604) @wingertge
- Fix scalar binop not contiguous (#3636) @laggui
- Fix dtype dispatch in cubecl module ops (#3658) @laggui
- Fix wgpu bool and/or (#3664) @laggui
- Fix tch bool ones and rand int (#3684) @laggui
- fix: Select assign + bool cast (#3730) @wingertge
- Fix register_float_tensor to use the correct dtype (#3774) @A2va
- Fix: autotune errors with fusion (added fallback) (#3778) @nathanielsimard
- Fix `mask_where` broadcasted line size (#3823) @laggui
- Fix adaptive avg pool2d backward line size (#3840) @laggui
- Fix line size regression bug (#3850) @nathanielsimard
- Correctly set `cubecl::random::seed(seed)` (#3878) @laggui
- Fix indexing for permuted tensors with cumulative ops (#3891) @wingertge
- Fix quantized reshape and `into_contiguous` (#3903) @wingertge
- Fix fusion matmul inputs (#3905) @laggui
- Fix powf vectorization on WGPU (#3916) @nathanielsimard
Documentation & Examples
- [Docs] Add python prerequisite disclaimer for `HuggingfaceDatasetLoader` (#3484) @laggui
- Mnist example augmented data (#3534) @Cielbird
- Improve DataLoaderBuilder docs. (#3482) @crutcher
- Readme + Burn Book performance section (#3686) @nathanielsimard
- Update README for improved ONNX import documentation (#3738) @antimora
- Some updates to the book (#3906) @louisfd
Fixes
- fix: link in examples (#3475) @domenicocinque
- Fix webassembly description + fusion usage + missing device (#3474) @laggui
- Fix dataset split docs (#3508) @laggui
- docs: fix example (#3498) @domenicocinque
- Fix tensor docs examples (#3525) @laggui
- Fix MNIST example model (#3549) @Cielbird
- Fix/conv2d docs display (#3586) @huy209vn
- Fix KaTeX docs (#3787) @laggui
- Fix typo in getting-started (#3868) @Charles23R
ONNX Support
- Add ONNX IsNaN and IsInf ops (#3393) @Friedrich-S
- Add support onnx bernoulli (#3394) @tye-singwa
- fix onnx reshape op elem_type inference (#3395) @tye-singwa
- Adding bitwise ONNX ops (#3120) @AshAnand34
- Add ONNX Attention op (#3423) @Friedrich-S
- Add support and tests for ONNX Abs operator (#3536) @antimora
- Infer conv spatial dims from weight rank (#3538) @laggui
- Debug log new name during ONNX renames (#3539) @torsteingrindvik
- Proto conversion: Allow f16 tensors by casting via bytemuck from raw data (#3541) @torsteingrindvik
- Fix onnx `auto_pad` and `ceil_mode` attrs handling (#3542) @laggui
- Support int min/max types in clip_config (#3544) @antimora
- Make onnx-ir parse error more informative. Handle more data type variants in TryFrom -> Argument (#3545) @torsteingrindvik
- Add Identity node support and fix initializer handling (#3543) @antimora
- Use `try_cast_vec` with fallback in proto conversion (#3546) @laggui
- onnx-ir: Infer conv2d kernel shape from weight tensor (#3554) @torsteingrindvik
- Add comprehensive Shape type support for ONNX operations (#3381) @antimora
- Extend onnx reduce op support (#3497) @tye-singwa
- Enhance ConstantOfShape to support static shape input (#3550) @torsteingrindvik
- Don't panic on allowzero since reshape supports it (#3573) @torsteingrindvik
- ONNX enhancements to support CLIP ViT-B-32 (#3560) @antimora
- Use `prettyplease` to format burn-import output rust files (#3578) @n1ght-hunter
- Fix ONNX import rank inference for nodes downstream of Shape-type constant conversion (#3564) @antimora
- Support dynamic shape and tensor sizes in ONNX resize (#3563) @antimora
- Refactor backend selection for onnx-tests (#3584) @antimora
- Add broadcasting support for add, sub, mul, and div ops (#3589) @antimora
- Fix ONNX Slice operation axes parameter handling (#3594) @antimora
- ONNX model checking: Yolo11x (#3599) @antimora
- CLIP ViT-B/32 text model ONNX verification & backend fixes (#3623) @antimora
- clip-vit-b-32-vision model verifications and fixes (#3673) @antimora
- Implemented MatMulInteger ONNX in burn-import and Uint8/int8 element types (#3672) @huy209vn
- Fix ONNX import: Integer constants serialization and MatMulInteger broadcasting (#3696) @antimora
- Add EyeLike ONNX operation support (#3731) @TheDarkchip
- Support ONNX Squeeze with axes input and no axes (#3736) @antimora
- Enhance ONNX PRelu config initialization with alpha and num_parameters (#3746) @antimora
- Add support for negative indices in Gather shape ops (#3749) @antimora
- Update ONNX dependency to stable version (#3772) @antimora
- Add NonZero ONNX operation support (#3745) @TheDarkchip
- Add static shape propagation and broadcasting support for ONNX IR operations (#3763) @antimora
- `trunc`, `fmod` and `Mod` ONNX ops (#3767) @antimora
- Add uint16 to onnx-ir (#3791) @TheGhostHuCodes
- Add YOLO model family check with ONNX import and test (#3750) @antimora
- ONNX albert model check and bug fix (#3810) @antimora
- Add ModernBERT-base model check (#3814) @antimora
- Add all-MiniLM-L6-v2 ONNX model check (#3813) @antimora
- ONNX: support broadcasting for `bool_and` (#3829) @mooori
- Lift constants for ReduceMax and ReduceMean nodes (#3827) @TheGhostHuCodes
- Burn import refactor to node-based registry architecture (#3825) @antimora
- ONNX: support broadcasting for `bool_or`, `bool_xor` (#3839) @mooori
- Update ONNX model support version to Opset 16+ (#3870) @jc-cr
- Handle empty tensor constants in ONNX import (#3904) @antimora
Enhancements
- Add more operations support in fusion (#3552) @nathanielsimard
- Perf/linear layout (#3587) @nathanielsimard
- Perf/data transfer (#3695) @nathanielsimard
- Perf: GPU to CPU Copy (#3708) @nathanielsimard
- feat: Matmul quant (#3874 #3910) @wingertge
- Fix/matmul/fusion (#3899) @nathanielsimard
Refactoring
- Refactor burn-train (#3451) @Cielbird
- [chore] Migrate to memory management API refactor (#3477) @wingertge
- Update cubecl: matmul refactor (#3493) @louisfd
- Refactor/quant (#3500) @nathanielsimard
- chore: Update cubecl with new changes to Item and layouts (#3626) @wingertge
- Refactor/seed (#3641) @nathanielsimard
- Reorganize activation layer sources into `nn.activation` module (#3627) @crutcher
- Remove backend `QuantizedEncoding` type and unused candle/tch impl (#3645) @laggui
- chore: Update cubecl with stacked view changes (#3687) @wingertge
- chore: Update cubecl for split traits (#3700) @wingertge
- Use bytes from cubecl (#3701) @nathanielsimard
- Update cubecl runtime features (#3711) @wingertge
- Use `ScalarIr` to represent scalars generically (#3706) @laggui
- chore: Update cubecl to tile refactor PR (#3728) @wingertge
- Refactor/broadcast layout (#3733) @wingertge
- Add cubecl re-export, root Tensor, doc updates and Noam scheduler fix (#3742) @laggui
- Move nn components to `burn-nn` (#3740) @laggui
- Update cubecl (#3752) @wingertge
- Chore update cubecl (#3764) @nathanielsimard
- Chore: update cubecl + fix no-std (#3771) @laggui
- Move optimizer components to `burn-optim` (#3773) @laggui
- Feat/multi streams (#3775) @nathanielsimard
- chore: Update cubecl for quant refactor and other changes (#3828) @wingertge
- chore: Update for launch refactor (#3841) @wingertge
- Refactor `Shape` manipulations (#3845) @laggui
- refactor: Refactor matmul to use views for its inputs (#3846) @wingertge
- Refactor/cubecl client (#3873) @nathanielsimard
Miscellaneous
- chore: update dependencies (#3389) @reneleonhardt
- Use member name as filter for wgpu tests (#3405) @laggui
- Fix fusion no default feat (#3408) @laggui
- Bump MSRV from 1.85 to latest stable 1.87 (#3424) @Friedrich-S
- Add benchmarks.toml (#3430 #3457) @syl20bnr
- Test benchmark execution on an Nvidia A100 (#3435 #3446) @syl20bnr
- Burn-collective base (#3288) @Cielbird
- ci: split tests on GitHub runners and on GPU runners (#3382) @syl20bnr
- ci: bench on multiple machines (#3455) @syl20bnr
- ci: fix wgpu-info (#3466) @syl20bnr
- `HuggingfaceDatasetLoader` automatically check for pip (#3479) @Puranjay-del-Mishra
- Refactor/collective (#3450) @nathanielsimard
- cfg-mask ddp constructor (#3488) @crutcher
- Update MSRV to 1.88 (#3492) @laggui
- Fix various warnings reported by `run-checks` (#3512) @crutcher
- Burn-vision transforms (#3527) @Cielbird
- Add feature flag to bytemuck due to usage of API extern_crate_alloc (#3556) @torsteingrindvik
- Fix shape type annotation in test (#3576) @laggui
- Refactor burn-collective (#3572) @Cielbird
- fix: Fix bug with scalar tail in morphology op (#3588) @wingertge
- apply clippy fixes to burn-ndarray (#3618) @torsteingrindvik
- Derive clone for Record Items (#3601) @amfaber
- Add `From` implementations for `ActivationConfig` and cleanup tests (#3631) @crutcher
- Fix new stable clippy lints (#3643) @janhohenheim
- Fix stable clippy lints (#3644) @janhohenheim
- Fix obvious problems (#3646) @nathanielsimard
- Limit cubecl cpu target (#3656) @laggui
- Bump cubecl to use wgpu 26 (#3657) @janhohenheim
- add some missing `default-features = false` (#3675) @dcrewi
- Fix no-std support for `burn-no-std-tests` and warning clean up (#3671) @antimora
- Strengthen Doc Lints (#3691) @crutcher
- From impls for Activation (#3692) @crutcher
- Remove DimSwappedActivation (#3693) @crutcher
- Shape: into_iter(), into_ranges(), to_vec(), slice() (#3694) @crutcher
- Add `burn-store` crate for model storage with safetensors support (#3666) @antimora
- Add #[allow(clippy::too_many_arguments)] to config constructor (#3737) @crutcher
- Remove empty indices tests (#3747) @laggui
- Fix various clippy lints (#3766) @wingertge
- chore: remove redundant words (#3770) @juejinyuxitu
- Fix segfaults from fusion panics with simple workaround (#3777) @wingertge
- Remove vulkan/mesa no-std CI setup (#3781) @laggui
- ci: add dispatch trigger publish workflow and bump xtask to 2.1.10 (#3788) @syl20bnr
- Slice: Copy, full(), default() (#3796) @crutcher
- Fix tests with hardcoded types (#3805) @wingertge
- Add PytorchStore for optimized model loading and in-house pickle reader (#3741) @antimora
- Fix ndarray compilation when cubecl-common enables `rayon` but ndarray doesn't (#3848) @wingertge
- PyTorch reader: Add F16, BF16, and unsigned integer support (#3849) @antimora
- Fix minor typo in POEM.md (#3851) @jc-cr
- BurnpackStore (#3792) @antimora
- Expose de/serialize numericentry (#3890) @Charles23R
- Bump tch to 0.22.0 (#3892) @laggui
- update cubecl (#3896) @louisfd
- Disable no-std safetensors store (#3902) @antimora