Summary
This release marks a major turning point for the ecosystem with the introduction of CubeK. Our goal was to solve a classic challenge in deep learning: achieving peak performance on diverse hardware without maintaining fragmented codebases.
By unifying CPU and GPU kernels through CubeCL, we've managed to squeeze maximum efficiency out of everything from NVIDIA Blackwell GPUs to standard consumer CPUs.
Beyond performance, this release makes the library more robust, flexible, and significantly easier to debug.
This release also features a complete overhaul of the ONNX import system, providing broader support for a wide range of ONNX models. In addition, various bug fixes and new tensor operations enhance stability and usability.
For more details, check out the release post on our website.
Changelog
Breaking
We've introduced several breaking API changes with this release. The affected interfaces are detailed in the sections below.
Training
We refactored `burn-train` to better support different abstractions and custom training strategies. As part of this, the `LearnerBuilder` has been replaced by the `LearningParadigm` flow:
```diff
- let learner = LearnerBuilder::new(ARTIFACT_DIR)
+ let training = SupervisedTraining::new(ARTIFACT_DIR, dataloader_train, dataloader_valid)
.metrics((AccuracyMetric::new(), LossMetric::new()))
.num_epochs(config.num_epochs)
- .learning_strategy(burn::train::LearningStrategy::SingleDevice(device))
- .build(model, config.optimizer.init(), lr_scheduler.init().unwrap());
+ .summary();
- let result = learner.fit(dataloader_train, dataloader_valid);
+ let result = training.launch(Learner::new(
+ model,
+ config.optimizer.init(),
+ lr_scheduler.init().unwrap(),
+ ));
```
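Put together, the new flow reads roughly as follows. This is simply the added lines from the diff above assembled into one place; the model, optimizer config, and learning-rate scheduler are assumed to be set up the same way as in previous releases.

```rust
// Sketch assembled from the migration diff above; `model`, `config`, and
// `lr_scheduler` are assumed to be created as before.
let training = SupervisedTraining::new(ARTIFACT_DIR, dataloader_train, dataloader_valid)
    .metrics((AccuracyMetric::new(), LossMetric::new()))
    .num_epochs(config.num_epochs)
    .summary();

let result = training.launch(Learner::new(
    model,
    config.optimizer.init(),
    lr_scheduler.init().unwrap(),
));
```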
Interface Changes

The `scatter` and `select_assign` operations now require an `IndexingUpdateOp` to specify the update behavior.

```diff
- let output = tensor.scatter(0, indices, values);
+ let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add);
```
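For context, `IndexingUpdateOp::Add` keeps the previous accumulate-on-scatter behavior. A minimal sketch of what that means in practice, assuming the NdArray backend and that `IndexingUpdateOp` is re-exported from `burn::tensor` (adjust the import to wherever it lives in your version):

```rust
use burn::backend::NdArray;
use burn::tensor::{Int, Tensor};
// Assumed import path for the new enum; adjust if it is exported elsewhere.
use burn::tensor::IndexingUpdateOp;

type B = NdArray;

fn main() {
    let device = Default::default();

    let base = Tensor::<B, 1>::from_floats([1.0, 1.0, 1.0], &device);
    let indices = Tensor::<B, 1, Int>::from_ints([0, 2], &device);
    let values = Tensor::<B, 1>::from_floats([10.0, 20.0], &device);

    // With `Add`, the scattered values are accumulated into the target
    // positions along dim 0: expected result is [11.0, 1.0, 21.0].
    let output = base.scatter(0, indices, values, IndexingUpdateOp::Add);
    println!("{output}");
}
```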
API calls for `slice`, `slice_assign`, and `slice_fill` no longer require const generics for dimensions, which cleans up the syntax quite a bit:

```diff
- let prev_slice = tensor.slice::<[Range<usize>; D]>(slices.try_into().unwrap());
+ let prev_slice = tensor.slice(slices.as_slice());
```
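As a small self-contained sketch of the lighter syntax (NdArray backend assumed), ranges can now be built at runtime and passed as a plain slice:

```rust
use std::ops::Range;

use burn::backend::NdArray;
use burn::tensor::Tensor;

type B = NdArray;

fn main() {
    let device = Default::default();
    let tensor = Tensor::<B, 2>::zeros([4, 6], &device);

    // No `[Range<usize>; D]` const generic or `try_into` needed anymore.
    let slices: Vec<Range<usize>> = vec![1..3, 0..4];
    let view = tensor.slice(slices.as_slice());

    assert_eq!(view.dims(), [2, 4]);
}
```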
The `grid_sample_2d` operation now supports different options.
To preserve the previous behavior, make sure to specify the matching options:

```diff
- let output = tensor.grid_sample_2d(grid, InterpolateMode::Bilinear);
+ let options = GridSampleOptions::new(InterpolateMode::Bilinear)
+ .with_padding_mode(GridSamplePaddingMode::Border)
+ .with_align_corners(true);
+ let output = tensor.grid_sample_2d(grid, options);
```

The `QuantStore` variants used in `QuantScheme` have been updated to support a packing dimension:
```diff
pub enum QuantStore {
/// Native quantization doesn't require packing and unpacking.
Native,
+ /// Store packed quantized values in a natively supported packing format (i.e. e2m1x2).
+ PackedNative(usize),
/// Store packed quantized values in a 4-byte unsigned integer.
- U32,
+ PackedU32(usize),
}
```
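Purely for illustration (this helper is hypothetical, not part of the API), the new packing dimension carried by the packed variants can be read like this:

```rust
// Assumed import path; adjust to where QuantStore is exported in your version.
use burn::tensor::quantization::QuantStore;

// Hypothetical helper: spells out what each variant of the updated enum
// encodes, including the new packing dimension.
fn describe(store: &QuantStore) -> String {
    match store {
        // Values are stored in their native quantized type; nothing to unpack.
        QuantStore::Native => "native quantized values, no packing".to_string(),
        // Values packed in a natively supported packed format (e.g. e2m1x2)
        // along the given dimension.
        QuantStore::PackedNative(dim) => format!("natively packed along dim {dim}"),
        // Values packed into 4-byte unsigned integers along the given dimension.
        QuantStore::PackedU32(dim) => format!("packed into u32 words along dim {dim}"),
    }
}
```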
Finally, `Shape` no longer implements `IntoIterator`. If you need to iterate by-value over dimensions, access the `dims` field directly:

```diff
- for s in shape {
+ for s in shape.dims {
```

Module & Tensor
- Generalize linalg::outer semantics; add linalg::outer_dim (#3923) @crutcher
- Use square() where appropriate. (#3900) @crutcher
- Add linalg matvec (#3967) @huy209vn
- Add GaussianNoise layer (#4022) @kul-sudo
- Make TransformerEncoderLayer fields public (#4053) @Mnwa
- Workaround MPS embedding allocation error in LibTorch (#4073) @antimora
- Fix Slice operation to handle empty ranges (#4083) @antimora
- Handle empty tensors in cat and slice_assign ops (#4095) @antimora
- [Breaking] Add `IndexingUpdateOp` to `scatter` and `select_assign` (#4070) @laggui
- Add CrossAttention module to burn-nn (#4101) @huy209vn
- Add reflect and edge padding modes to tensor.pad (#4105 #) @antimora
- Fix GLU and quiet softmax activations (#4121) @laggui
- Add ceil_mode support to pooling operations (MaxPool, AvgPool) (#4112) @antimora
- [Breaking] Remove D2 const generic from slice / SliceArg (#4127) @crutcher
- Add backend supports_dtype (#4155) @laggui
- Fix repeat 0 times (#4216) @laggui
- feat: add hardswish activation (#4209) @mertalev
- Add more trig ops (#4282) @laggui
- Add empty/zeros/ones/full `TensorCreationOptions` (#4285) @laggui
- feat: nms op (#4246) @mertalev
Datasets & Training
- Refactor metric loggers (#3895 #4017) @Charles23R
- Add support for custom learning strategy (#3921) @Charles23R
- Feat/optim/distributed (#4018) @nathanielsimard
- Refactor MetricEntry (#4031) @Charles23R
- Feature muon (#3925) @NewBornRustacean
- Add warmup epochs to `MetricEarlyStoppingStrategy` (#4041) @crutcher
- Log running values (#4199) @Charles23R
- Fix checkpoint and summary log level (#4201) @J-F-Liu
- [Breaking] Burn train api refactor (#4223 #4283) @Charles23R
- Fix checkpointer interrupt (#4268) @Charles23R
Backends
- Add candle device seeding (#3959) @laggui
- feat: Enable tuning for MMA matmul (#3961) @wingertge
- feat: TMA autotuning (#3986) @wingertge
- feat: Enable tuning specialized matmul (#4026) @wingertge
- Add CubeCL Flash Attention module (#4089 #4192) @louisfd
- Zero-copy tensor loading for NdArray backend (#4178) @antimora
- feat: Implicit GEMM weight gradients for convolution (#4182) @wingertge
- Perf/reduce cpu + Fix OOB (#4197 #4204) @nathanielsimard
- feat: Accelerated convolution data gradient (#4220) @wingertge
- Remove linux-only constraint for cpu (#4233) @louisfd
- Perf/into contiguous (#4257) @nathanielsimard
- fix: grid sample using excessive memory (#4236 #4242) @mertalev
- Add fast-path for batched vector–matrix matmul (#4300) @louisfd
Bug Fixes
- Fix async barrier & TMA checks (#4007) @nathanielsimard
- Fix fusion reduce local already registered as output (#4014) @laggui
- Fix remainder int (#4015) @laggui
- Fix cuda mem error (#4020) @nathanielsimard
- Cleanup autodiff unused roots (#4039) @laggui
- Fix autotuner (#4049) @nathanielsimard
- Fix scatter values backward (#4064) @khoek
- More correctness fixes in autodiff ops (#4069) @khoek
- Fix transaction read (#4074) @laggui
- Fix tch bf16 kind (#4088 #4142 #4203) @laggui
- Fix cubecl cuda compilation error/typo (#4092) @BjornTheProgrammer
- Fix output dtype for argmin / argmax (#4195) @tzemanovic
- Return slice for each dimension in shape (#4152) @laggui
Documentation & Examples
- Update raspberry pi pico example (#4034 #4132) @BjornTheProgrammer
- Contributor Book: Update the "ONNX to Burn" Page (#4229) @softmaximalist
- docs: add examples for bool tensor operations (#4248) @qburke
- Update the "Adding New Operation" guide in the contributor book (#4284) @softmaximalist
- Refactor dop_timer for multiple trials (for warmup). (#4288) @crutcher
- Added documentation examples for more boolean tensor operations in burn-tensor (#4289) @qburke
Fixes
- Fix book (#3942) @laggui
- remove repetitive words in comment (#4029) @black5box
- Include katex header as symlink (#4118) @laggui
- Fix quantization docs (make it clear that only PTQ is currently supported) (#4316) @laggui
ONNX Support
- ONNX IR and import refactor to better support complex graphs (#3872 #4019 #4033 #4094) @antimora
- Add ONNX control flow operators: `If`, `Loop`, and `Scan` (#3936) @antimora
- Silero VAD ONNX model verification (#3999) @antimora
- Add support for yolo12x model variant (#4048) @antimora
- Remove burn-import abstraction layer and use onnx-ir types directly (#4033) @antimora
- Fix ConstantOfShape output size determination (#4085) @antimora
- Specify output rank in squeeze_dims for type inference (#4086) @antimora
- Fix Expand operation to use ONNX max-semantics (#4082) @antimora
- [Breaking] Add ONNX GridSample op support and tests (#4084) @antimora
- Add RF-DETR model check for burn-import (#4087) @antimora
- Add LSTM operator support with configurable activations (#4106) @antimora
- Add memory-mapped ONNX loading with tensor data ref (#4097) @antimora
- Fix outer-scope variable references in ONNX subgraphs (If/Loop/Scan) (#4119) @antimora
- Add Reshape scalar optimization and Gather scalar input support (#4146) @antimora
- Update GELU ONNX test to use native op and fix expected values (#4161) @antimora
- Add ONNX CumSum operator support (#4162) @antimora
- Remove global ONNX opset version restriction, recommend opset 16 (#4168) @antimora
- Handle 1D slope when importing prelu from onnx (#4205) @mertalev
- Fix handling scalar scan outputs in ONNX loop nodes (#4210) @antimora
- Add ONNX external data support for models >2GB (#4158) @antimora
- fix: handle negative indices in onnx gather op (#4207) @mertalev
- Split backend tensor ops tests (#4232) @laggui
- Do not use alloc import in burn-import codegen (#4286) @laggui
- Fix ONNX where broadcasted dims (#4315) @laggui
Enhancements
- Feat/pinned memory staging (#4016) @nathanielsimard
- burn-store enhancements for troubleshooting and new enum skip flag (#4051) @antimora
- Feat/runtime error (#4079 #4110) @nathanielsimard
- Perf/improve reduce autotuning + plane non uniform control flow check (#4208) @nathanielsimard
- Packed quantized matmul with `QuantStore` changes (#4310 #4323) @wingertge
Refactoring
- chore: Update to batch caching PR for `cubecl` (#3948) @wingertge
- Refactor IR to define outputs as a function of the operation (#3877) @laggui
- Chore/update dtypes (#3998) @nathanielsimard
- Cleanup quantization strategy (CPU ref, ndarray only) (#4023) @laggui
- Refactor/dtype cubecl (#4032) @nathanielsimard
- Refactor of burn fusion and burn cubecl fusion (#4044) @nathanielsimard
- chore: Update to cubecl scalar refactor (#4062) @wingertge
- refactor: cubecl Runtime trait (#4065) @wingertge
- Refactor/autotuner (#4068) @nathanielsimard
- Move types from `burn-tensor` to `burn-std` and `burn-backend` (#4050) @laggui
- Feat/error handling cubecl (#4076) @nathanielsimard
- Refactor `RemoteDevice` and `RemoteSender`. (#4113 #4108) @crutcher
- Refactor `LocalCollectiveClient` and `LocalCollectiveServer` (#4125 #4126) @crutcher
- Move backend traits and types to `burn-backend` (#4111) @laggui
- Migrate ONNX import to burnpack format (removing Record type) (#4122) @antimora
- Refactor more basic ops (#4156) @laggui
- Refactor configurable backend tests (no more testgen macros) (#4129) @laggui
- Backends no longer depend on `burn-tensor`, but strictly `burn-backend` (#4169) @laggui
- Refactor/cube dim (#4217) @nathanielsimard
- Update ops subfolder file names (#4271) @softmaximalist
- refactor: Migrate to usize indexing (#4273) @wingertge
- Unify ReshapeArgs / Shape.reshape(args) (#4221 #4317) @crutcher @laggui
- chore: Update to refactor cubecl types and traits (#4297) @wingertge
Miscellaneous
- Add `Shape::ravel_index` for row-major raveling of indices. (#3879) @crutcher
- ci: let CI server dispatch the test-gpu workflow (#3938) @syl20bnr
- ci: check tag version against Cargo.toml version before publishing (#3939) @syl20bnr
- Implement error for DataError (#3960) @laggui
- Pin burn crates version (#4035) @Marc-AnthonyG
- Implement `FromStr` for `Slice` with parsing and error handling (#3983) @crutcher
- Enable no-std SafeTensors support and update hashbrown (#4071) @antimora
- Move network utilities to `burn-std` (#4104) @laggui
- Add 256-byte tensor alignment to burnpack format for mmap zero-copy support (#4100) @antimora
- Fix/autotune checks (#4114) @nathanielsimard
- Add direct tensor snapshot retrieval API to ModuleStore (#4131) @antimora
- Implement Slice iterator and utility methods. (#4042) @crutcher
- Shape FromStr/ToString (#4143) @crutcher
- Add contiguous index mapping for non-contiguous layer indices (#4150) @antimora
- Zero-copy loading for embedded burnpack weights (#4154) @antimora
- Add `flatten_dims` method to `Shape` and refactor tensor flattening API (#4189) @crutcher
- Make xtask validate run no-std checks first. (#4198) @crutcher
- Add tracing::instrument and refactor collective operations. (#4157 #4234) @crutcher
- Fix dtype preservation when loading tensors in burn-store (#4194) @antimora
- Fix burn-store quantized tensor storage data length calculation (#4180) @antimora
- Replace `canonicalize_dim` with `expect_dim` (#4196) @crutcher
- Refactor: Consolidate shape and slice error handling into `ExpressionError` (#4218) @crutcher
- Implement TODO tests and validation for Sum operation in onnx-ir (#4251) @softmaximalist
- Fix burn-store collector tuple modules (#4270) @laggui
- Fix rand os_rng (#4295) @laggui
- chore: update xtask to 4.9.0 (#4311) @syl20bnr