Summary
Burn 0.21.0 brings 4 months of improvements that make the framework significantly faster and more reliable across the board. The gains span distributed workflows for training large models all the way down to small-model inference, where the reduced framework overhead becomes especially noticeable.
We rethought our distributed computing stack around differentiable collective operations. Kernel selection is now more reliable thanks to better autotuning and a new validation layer, and a project-level burn.toml file lets you tweak those internals (and many others) without recompiling. A reworked device handle reduces framework overhead, and a new burn-dispatch crate simplifies backend selection while paving the way for faster compile times. The release also ships burn-flex, a lightweight eager CPU backend for WebAssembly and embedded targets that replaces burn-ndarray. Finally, we added early off-policy reinforcement learning support and a fresh round of kernel work on GEMV, top-k, and FFT.
For more details, check out the release post on our website.
Changelog
Breaking
We've introduced a number of breaking changes with this release. The affected areas are detailed in the sections below.
burn-dataset cache directory
To respect platform conventions, we switched from a hardcoded ~/.cache directory root to the platform's standard cache directory for downloaded artifacts.
| Platform | Path |
|---|---|
| Linux | `$XDG_CACHE_HOME` or `~/.cache` |
| macOS | `~/Library/Caches` |
| Windows | `{FOLDERID_LocalAppData}` |
For Linux users without $XDG_CACHE_HOME configured, this change has no effect. The cache directory is still ~/.cache.
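If you want to resolve the new cache root programmatically, the same platform paths can be reproduced with the `dirs` crate; a minimal sketch, assuming `dirs` is in your dependencies (the `burn-dataset` subdirectory name here is illustrative, not the crate's exact layout):

```rust
use std::path::PathBuf;

// Resolve the platform cache root listed in the table above:
// $XDG_CACHE_HOME or ~/.cache on Linux, ~/Library/Caches on macOS,
// {FOLDERID_LocalAppData} on Windows.
fn dataset_cache_root() -> Option<PathBuf> {
    dirs::cache_dir().map(|root| root.join("burn-dataset")) // illustrative subdir
}
```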
Interface Changes
TensorData::shape now stores a Shape instead of a Vec<usize>. Existing binary records using BinFileRecorder or BinBytesRecorder are not forward-compatible and must be converted before upgrading.

```rust
static STATE_ENCODED: &[u8] = include_bytes!("model.bin");
let model: Model<B> = Model::new(&Default::default());
// Old format can still be loaded before upgrade, but must be re-saved in a forward-compatible format.
let record = BinBytesRecorder::<FullPrecisionSettings, &'static [u8]>::default()
.load(STATE_ENCODED, &Default::default())
.expect("Failed to decode state");
let model = model.load_record(record);
model.save_file("model.mpk", &NamedMpkFileRecorder::<FullPrecisionSettings>::new()).unwrap();
```
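Once re-saved, the record loads as usual on the new version; a minimal sketch, assuming your own model path and device:

```rust
// Load the converted record (sketch; `device` comes from your backend).
let model: Model<B> = Model::new(&Default::default())
    .load_file("model.mpk", &NamedMpkFileRecorder::<FullPrecisionSettings>::new(), &device)
    .expect("Failed to load converted record");
```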
The module derive macro has been improved, and the Ignored<T> wrapper is now deprecated. For fields that should not be considered modules, use #[module(skip)] instead.

```diff
pub struct Conv1d<B: Backend> {
- pub padding: Ignored<PaddingConfig1d>,
+ #[module(skip)]
+ pub padding: PaddingConfig1d,
}
```

We added support for explicit asymmetric padding. If you were using explicit padding, you must now specify a value for each side; to keep the previous symmetric behavior, use the same value for both sides of each pair. Note that PaddingConfig3d does not support asymmetric padding yet.

```diff
// Symmetric (left, right)
- PaddingConfig1d::Explicit(1)
+ PaddingConfig1d::Explicit(1, 1)
// Symmetric (top, left, bottom, right)
- PaddingConfig2d::Explicit(1, 1)
+ PaddingConfig2d::Explicit(1, 1, 1, 1)
```

The Gelu activation module can now be configured with tanh approximation. This only affects code that instantiated Gelu directly.

```diff
- let activation = Gelu;
+ let activation = Gelu::new(); // or Gelu::default()
```

The position-wise feed-forward module now has a configurable activation function. To keep it backwards compatible with previously saved records, the field is marked as #[module(skip)].

```diff
#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
// ...
- /// GELU activation function.
- pub gelu: Gelu,
+ /// Activation function.
+ #[module(skip)]
+ pub activation: Activation<B>,
}
```

The Shape fields are now private and some methods have been renamed. ShapeError has been renamed to MetadataError.

```diff
- let b = tensor.shape().dims[0];
+ let b = tensor.shape()[0];
- if let Err(ShapeError::RankMismatch{...}) = lhs.broadcast(&rhs) {
+ if let Err(MetadataError::RankMismatch{...}) = lhs.broadcast(&rhs) {
- let shape = shape.swap(1, 2).unwrap();
+ let shape = shape.swapped(1, 2).unwrap();
- let shape = shape.permute(&[0, 2, 1, 3]).unwrap();
+ let shape = shape.permuted(&[0, 2, 1, 3]).unwrap();
```

The boolean data type was expanded to include its storage type.

```diff
match bool_tensor.dtype() {
- DType::Bool => todo!(),
+ DType::Bool(BoolStore::Native) => todo!(),
+ DType::Bool(BoolStore::U8) => todo!(),
+ DType::Bool(BoolStore::U32) => todo!(),
_ => unreachable!(),
}
```

powf is no longer supported for Int tensors, as it previously relied on incorrect implicit truncation. These operations are now only available for Float tensors.

```diff
- let tensor_i = tensor_int.powf(tensor_float);
+ let tensor_f = tensor_int.float().powf(tensor_float);
- let tensor_i = tensor_int.powf_scalar(scalar_float);
+ let tensor_f = tensor_int.float().powf_scalar(scalar_float);
```
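For integer exponents without a float round-trip, the powi ops remain available on Int tensors (powi was promoted to a numeric op in #4646, listed in the changelog below); a minimal sketch, with the exact signature assumed rather than guaranteed:

```rust
// Sketch: integer exponentiation stays on Int tensors via powi_scalar (#4646).
let squared = tensor_int.powi_scalar(2);
```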
Backend tensor creation and conversion ops now take an explicit output dtype. This removes backend-specific dtype inference and ensures consistent behavior across backends. (Backend implementors only.)

```diff
impl BoolTensorOps<Self> for MyBackend {
- fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
+ fn bool_empty(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
// use `dtype` instead of inferring internally
}
- fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
+ fn bool_into_int(tensor: BoolTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
// use `out_dtype` instead of inferring internally
}
}
```

Associated types were moved from Backend to BackendTypes. Prefer the type aliases (Device<B>, FloatTensor<B>, etc.) to avoid type resolution issues.

```diff
impl BoolTensorOps<Self> for MyBackend {
- fn bool_empty(shape: Shape, device: &<Self as Backend>::Device, dtype: BoolDType) -> <Self as Backend>::BoolTensorPrimitive {
+ fn bool_empty(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
}
}
```

Module & Tensor
- Feat/device policy (#4373) @laggui
- Implement basic RNN module (#4460) @aditya0by0
- Add deg2rad and rad2deg (#4462) @softmaximalist
- Implement median tensor operation (#4454) @softmaximalist
- Add Selu activation function (#4439) @antimora
- Add CELU activation function (#4441) @antimora
- Add Elu activation function (#4438) @antimora
- Add BiGru (bidirectional GRU) module (#4442) @antimora
- Add ThresholdedRelu activation function (#4440) @antimora
- Add Softsign activation function (#4437) @antimora
- [Breaking] Add configurable activation and layer_norm_eps to transformer layers (#4410) @antimora
- [Breaking] Add asymmetric padding support for conv and pool operations (#4263) @antimora
- Implement HardShrink, SoftShrink and Shrink Activations (#4556) @aditya0by0
- feat: add align_corners support to InterpolateOptions (#4518) @antimora
- feat: support padding on arbitrary dimensions (#4507) @antimora
- feat: enhance attention() with scale, attn_bias, softcap, and is_causal (#4476) @antimora
- feat: Introduce Lanczos3 interpolation method (#4601) @ovr
- Add HannWindow operator to burn-tensor (#4631) @walkinggo
- [Breaking] Remove int powf and make powi numeric op (#4646) @laggui
- [Breaking] Add bool store dtype + remove bool elem from fusion (#4649) @laggui
- [Breaking] Use device settings to provide output dtype (#4653) @laggui
- feat: add categorical sampling for tensors (#4655) @majiayu000
- Add HammingWindow operator to burn-tensor (#4698) @RunjiaChen
- Fix: make module cloning efficient for CPU devices (#4703) @antimora
- feat: support cross-kind tensor casting via .cast() (#4713) @antimora
- Add `FloatInfo` for dtype-aware precision info (#4721) @antimora
- Fix `unsqueeze_dims` panic (#4755) @softmaximalist
- Fix unsqueeze_dims panic on duplicate sorted axes (#4764) @antimora
- feat(burn-nn): add native LocalResponseNorm module (#4765) @jcwal1516
- Add det (determinant) tensor operation (#4813) @softmaximalist
- Add Blackman window function to signal module (#4842) @softmaximalist
- Add STFT/ISTFT and thread n through FFT backend trait (#4835) @antimora
- Add linear op to ModuleOps for fused matmul+bias (#4747) @antimora
- Add native implementations for scatter_nd / gather_nd; provide autodiff for assign & add (#4709) @cu9hue
- Fix conv x-backward padding_out bug (#4806) @antimora
- Extract float math ops in a new trait (#4891) @skewballfox
- `linalg::lu`: Improve numerical handling and small perf cleanup (#4902) @softmaximalist
- Adding complex to complex FFT implementation (#4903) @RunjiaChen
- add autodiff for scatter_nd min/max/mul (#4909) @cu9hue
- fix: conv_transpose x-backward output size (#4916) @SAY-5
- Change pwff activation to #[module(skip)] for backward compat (stateless) (#4929)
Datasets & Training
- Implement SSIM vision metric (#4396) @softmaximalist
- add KLDivLoss and batch_mean in reduction (#4399) @donjuanplatinum
- Fix cubek matmul stage size (#4435) @laggui
- Implement the PSNR vision metric (#4379) @softmaximalist
- Implement Mean(L(P) Norm Error)Loss (#4341) @softmaximalist
- Feature flag + Tests for RL in burn-rl and burn-train (#4470) @Charles23R
- Burn rl (#4447) @Charles23R
- add AMSgrad support for Adam/AdamW (#4388) @donjuanplatinum
- add LBFGS optimizer (#4471) @donjuanplatinum
- Add SequenceOutput struct for sequence prediction outputs (#4474) @softmaximalist
- fix: OptimSharded strategy validation device mismatch (#4527) @Dreaming-Codes
- Implement CTC loss (#4529) @softmaximalist
- Add Smooth L1 loss (#4547) @softmaximalist
- Implements: LPIPS metrics for image quality (#4403) @koreaygj
- feat: Implements DISTS metric (#4574) @koreaygj
- Add multi-scale SSIM for image quality assessment (#4555) @softmaximalist
- Add Gram Matrix Loss for vision tasks (#4595) @softmaximalist
- Add evaluator summary (#4578) @laggui
- Fix cosine scheduler record in composed scheduler (#4617) @laggui
- Implement RNNT loss (#4623) @cong-or
- feat: add FID vision metric (#4644) @cong-or
- Add Adan optimizer implementation with tests (#4651) @sepcnt
- [Breaking] Split `TrainingStrategy` to decouple the `DistributedBackend` requirement (#4710) @laggui
- Fix `CrossEntropyLoss` with probabilities (#4829) @laggui
Backends
- More explicit global dtype support (#4400) @laggui
- opt(burn-cubecl): Optimized tensors by default (#4402) @wingertge
- Add device dtype usage (#4404) @laggui
- Attention: add autotune gate (#4554) @louisfd
- Attention autotune (#4552) @louisfd
- Attention: remove default impl and implement for all backends (#4544) @louisfd
- Add native sign unary ops for CubeCL float and int (#4513) @yash27-lab
- [Feat] Global backend `Dispatch` (#4508) @laggui
- allow flash attention with causal (#4509) @louisfd
- Perf: Improve fusion score (#4511) @nathanielsimard
- Dispatch autodiff checkpointing strategy support (#4629) @laggui
- Selector/attention (#4648) @louisfd
- update cubek and fix vecmat autotune (#4682) @louisfd
- update cubek and cubecl (#4699) @louisfd
- update cubek & fix gemv autotune (#4726) @louisfd
- Feat/add rfft (#4707) @Sublime12
- Feat/add irfft (#4719) @Sublime12
- Feat/implement fusion for rfft (#4735) @Sublime12
- Feat/implement fusion for irfft (#4736) @Sublime12
- Add burn-flex CPU backend (#4761) @antimora
- burn-flex: enable f16 tests and fix mean overflow, grid_sample and quantization (#4769) @antimora
- Add softmax and layer_norm backend trait hooks (#4797) @antimora
- burn-flex: implement softmax and layer_norm backend op (#4805) @antimora
- Matmul selection (#4773) @nathanielsimard
- Add native dispatch overrides and native tch ops for softmax, layer_norm (#4834) @antimora
- [Breaking] Split Associated Types from Backend into BackendTypes (#4868) @skewballfox
- Add ctc_loss backend trait hook + tch and cubecl impls (#4819) @antimora
- Update CubeK: tile matmul refactor (#4901) @louisfd
- Add argtopk for Cubecl backend (#4900) @Sublime12
- Add fusion integration for argtopk (#4904) @Sublime12
- Add cubecl integration to topk (#4906) @Sublime12
- Fusion tests (#4872) @nathanielsimard
- Enable & fix cubecl tests w/ fusion (#4917) @laggui
Bug Fixes
- Fix reduce line size parallel and mean accumulator precision (#4467) @laggui
- fix: default to single device strat when only 1 device (#4463) @Charles23R
- fix: use all dilation entries in `max_pool2d_with_indices_backward` (#4466) @fcasal
- Fix cubek matmul stage size (#4435) @laggui
- fix: Fix interpolate with NHWC input (#4363) @wingertge
- fix: Actually implement conv backwards ops for `burn-fusion`/`burn-router` (#4360) @wingertge
- Fix memory growth: use GraphLocator::remove_entry for orphan cleanup (#4342) @jnamika
- fix: Bool from_data_dtype panics on GPU backends (#4551) @antimora
- fix: resolve macOS build and test failures (#4545) @antimora
- Fix too many kernels (#4505) @nathanielsimard
- Fix quantization non-contiguous input (#4498) @laggui
- fix overflow in int_abs_elem for i64 min value (#4486) @Olexandr88
- Fix: create multiple elemwise fused block (#4497) @nathanielsimard
- Fix fusion cumulative op inputs (#4621) @laggui
- Fix dispatch autodiff feature propagation (#4592) @laggui
- Fix `conv2d_weight_backward` w/ strided channels and unit spatial dims (#4591) @laggui
- Fix(lpips): load ImageNet backbone weights for pretrained models (#4557) @koreaygj
- Fix tch int_zeros dtype in sync (#4664) @laggui
- Fix fusion kernel vector_size mismatch on f16 output writes (#4675) @AdrianEddy
- Fix fusion consistency checks and binding estimation (#4695) @nathanielsimard
- Fix attention_fallback NaN for fully-masked rows (#4697) @antimora
- fix output in attention tuner (#4702) @louisfd
- fix: use integer arithmetic for nearest-neighbor coordinate scaling (#4687) @wkrettek
- Fix cubecl cuda all-reduce + remove useless check in distributed server (#4720) @Charles23R
- Fix fusion scalar broadcasting in `write_output_aligned` (#4741) @laggui
- Fix quantization tests and flaky tolerance (#4743) @laggui
- Fix select_assign OOB (#4760) @nathanielsimard
- Fix burn-flex bool binary ops to broadcast operands (#4775) @antimora
- Fix burn-flex attention rejecting broadcasted mask/bias (#4777) @antimora
- fix(ndarray): grouped conv SIMD clamp + regressions (#4727) @dnvt
- Fix autotune context, remove unsafe code (#4781) @ArthurBrussee
- Fix cubecl cross product on non-last dimension (#4850) @dschulmeist
- Fix burn-flex to_contiguous fast path for prefix views (#4856) @antimora
- Fix burn-flex sum_dim reading contiguous storage on transposed input (#4861) @antimora
- Fix burn-flex argmax NaN ordering; tighten expand; precise erf (#4859) @antimora
- Fix fusion reduce broadcasted when multi block local might be a view (#4867) @laggui
- Fix select_assign OOB units (#4870) @laggui
- Update cubecl + cubek: fix matmul, reduce WASM and vector size check on strided tensors (#4874) @laggui
- Fix fusion read_quantized native type (#4923) @laggui
Documentation & Examples
- Update Burn Book: metrics and trig functions (#4413) @softmaximalist
- docs: add DataframeDataset example using Polars (#4298) @SameerVers3
- doc(notebook): add more basic operations and some examples (#4542) @Tyooughtul
- Update documentation link for burn-store (#4619) @softmaximalist
- Update building-blocks chapter (#4625) @softmaximalist
- Update ONNX import docs for LoadStrategy and from_bytes (#4607) @antimora
- Use burn-flex in docs and examples (#4841) @antimora
Fixes
- Add field docs to generated methods (#4408) @swfsql
- Fix typo in dataset.md in Burn Book (#4380) @softmaximalist
- Fix book guide training changes (#4340) @laggui
- Fix image-classification-web links (#4536) @laggui
- fix: replace ValidStep with InferenceStep in training.md (#4620) @TsaoLun
Enhancements
- Add `module.train()` to move a module back to the autodiff backend (#3975) @laggui
- Perf/fusion/reduce broadcasted (#4338) @nathanielsimard
- feat: Enable 64-bit indexing for kernels (#4502) @wingertge
- Refactor/device handle (#4593) @nathanielsimard
- All reduce backward (#4650 #4873) @Charles23R
- Perf/burn fusion overhead (#4645) @nathanielsimard
- Device service usage (#4839) @nathanielsimard
Refactoring
- Add `Scalar` runtime literal (#4337) @laggui
- Move ONNX crates to burn-onnx repository (#4393) @antimora
- chore: Update cubecl to runtime config refactor (#4489) @wingertge
- chore: deprecate burn-candle backend (#4416) @antimora
- Move ONNX import to `burn-onnx` crate (#4361) @laggui
- [Breaking] perf: Make backing storage of `Shape` more flexible (#4516) @wingertge
- refactor: Move from `CubeOption` to `Option` (#4543) @wingertge
- [Breaking] refactor: Metadata type/strides refactor (#4534) @wingertge
- Use shape in `TensorData` (#4603) @laggui
- refactor: Vector size generic (#4624) @wingertge
- refactor: View launch (#4639) @wingertge
- Refactor backend tests to set device settings at initialization + use `Dispatch` (#4666) @laggui
- Prep for Group Multi Optimizers (#4818) @crutcher
- Cleanup OptimizerAdaptor / GradAdaptor API. (#4822) @crutcher
- Remove unused M param from SimpleOptimizerMapper. (#4823) @crutcher
- Move tensor tests from burn-flex to burn-backend-tests (#4812) @antimora
- Fusion all reduce + refactor collective (#4803) @Charles23R
- Migrate benchmarks from burn-flex to burn-backend-tests (#4853) @antimora
- Migrate default test backend from NdArray to Flex (#4854) @antimora
- Update cubecl: refactor toml config, fix autotune priority and fix persistent memory pool reset (#4858) @nathanielsimard
- Add burn-std::config runtime configuration with fusion logging and search optimization (#4864) @nathanielsimard
- Update/cubecl to client (#4866) @Charles23R
- Centralize internal burn-* deps in [workspace.dependencies] (#4876) @antimora
- Remove optim::optim (#4924) @crutcher
Miscellaneous
- Update zip + time (#4468) @laggui
- Update cubecl wgpu v28 (#4244) @laggui
- [Breaking] Use `cache_dir()` instead of hardcoded `~/.cache` path (#4372) @antimora
- Make `ElementComparison` optional for dtypes (#4255) @skewballfox
- Performance tweaks to the lp_norm code. (#4318) @crutcher
- ensure that tensor is owned on iter_dim call (#4309) @tzemanovic
- Use NodeType to point to unimplemented node (#4334) @laggui
- Bump burn version 0.21 (#4333) @laggui
- feat(burn-store): add ModuleAdapter chaining (#4407) @huahuadeliaoliao
- Replace Vec-based TransitionBuffer with tensor-backed storage (#4504) @arferreira
- Optional Ordering for NdArrayElement (#4559) @skewballfox
- Move `burn-nn` module name checks in `burn-store` adapter to the test section (#4580) @softmaximalist
- Expose `BurnpackError` (#4585) @AdrianEddy
- Add HalfPrecisionAdapter for F32/F16 mixed-precision storage (#4594) @antimora
- Improve module derive + add `#[module(skip)]` attribute (#4618) @laggui
- Fix SSIM float types to f32 (#4602) @softmaximalist
- Fix function arg name inconsistencies (#4626) @softmaximalist
- Make Param Sync for parallel model inference (#4701) @antimora
- Fix flaky initializer_normal_init test (#4766) @leohenon
- Add Record<(R0,)> 1-Tuple (#4825) @crutcher
- Display FlexDevice as Cpu (#4857) @antimora
- Fix rustls-webpki audit (#4863) @laggui
- Fix `PytorchReader` bugs to load legacy files correctly (#4897) @softmaximalist
- Add Clone + 'static bounds to LrScheduler::Record and derive Clone for scheduler records (#4905) @crutcher
- Add ParamId::try_deserialize() (#4881) @crutcher
- Use gather_nd in RNN-T gather_loss (#4895) @antimora
- Re-enable fusion f16 conv + bn regression tests (#4920) @laggui
- rnnt.rs: Optimize extract_log_probs and init_alpha (#4922) @softmaximalist
- Fix some test tolerances (#4926) @laggui
Full Changelog: v0.20.0...v0.21.0