tracel-ai/burn v0.21.0

Summary

Burn 0.21.0 brings 4 months of improvements that make the framework significantly faster and more reliable across the board. The gains span distributed workflows for training large models all the way down to small-model inference, where the reduced framework overhead becomes especially noticeable.

We rethought our distributed computing stack around differentiable collective operations. Kernel selection is now more reliable thanks to better autotuning and a new validation layer, and a project-level burn.toml file lets you tweak those internals (and many others) without recompiling. A reworked device handle reduces framework overhead, and a new burn-dispatch crate simplifies backend selection while paving the way for faster compile times. The release also ships burn-flex, a lightweight eager CPU backend for WebAssembly and embedded targets that replaces burn-ndarray. Finally, we added early off-policy reinforcement learning support and a fresh round of kernel work on GEMV, top-k, and FFT.

For more details, check out the release post on our website.

Changelog

Breaking

We've introduced several breaking changes with this release. The affected areas are detailed in the sections below.

burn-dataset cache directory

To respect platform conventions, we switched the cache root for downloaded artifacts from a hardcoded ~/.cache directory to the platform's standard cache location.

Platform   Path
Linux      $XDG_CACHE_HOME or ~/.cache
macOS      ~/Library/Caches
Windows    {FOLDERID_LocalAppData}

For Linux users without $XDG_CACHE_HOME configured, this change has no effect. The cache directory is still ~/.cache.
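To check where the cache root resolves on your machine, you can query the platform cache directory with the dirs crate. This is a minimal sketch; the "burn-dataset" subdirectory name is an illustrative assumption, not the crate's actual layout.

use std::path::PathBuf;

// Resolve the platform cache root that burn-dataset now respects:
// Linux: $XDG_CACHE_HOME or ~/.cache, macOS: ~/Library/Caches,
// Windows: {FOLDERID_LocalAppData}.
fn dataset_cache_root() -> Option<PathBuf> {
    // The subdirectory name below is illustrative only.
    dirs::cache_dir().map(|root| root.join("burn-dataset"))
}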

Interface Changes

TensorData::shape now stores a Shape instead of a Vec<usize>. Existing binary records using BinFileRecorder or BinBytesRecorder are not forward-compatible and must be converted before upgrading.

static STATE_ENCODED: &[u8] = include_bytes!("model.bin");

let model: Model<B> = Model::new(&Default::default());

// Old format can still be loaded before upgrade, but must be re-saved in a forward-compatible format.
let record = BinBytesRecorder::<FullPrecisionSettings, &'static [u8]>::default()
    .load(STATE_ENCODED, &Default::default())
    .expect("Failed to decode state");
let model = model.load_record(record);

model.save_file("model.mpk", &NamedMpkFileRecorder::<FullPrecisionSettings>::new()).unwrap();

The module derive macro has been improved, and the Ignored<T> wrapper is now deprecated. For fields that should not be treated as modules, use #[module(skip)] instead.

pub struct Conv1d<B: Backend> {
-    pub padding: Ignored<PaddingConfig1d>,
+    #[module(skip)]
+    pub padding: PaddingConfig1d,
}
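
In context, a skipped field lives alongside regular module fields. A sketch with illustrative field names:

#[derive(Module, Debug)]
pub struct Conv1d<B: Backend> {
    /// Learnable parameters still participate in the record as before.
    pub weight: Param<Tensor<B, 3>>,
    /// Plain configuration data is excluded from the module tree.
    #[module(skip)]
    pub padding: PaddingConfig1d,
}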

We added support for explicit asymmetric padding. Explicit padding now takes one value per side, so to keep the previous symmetric behavior you must specify the same value for both sides of each pair. Note that PaddingConfig3d does not support asymmetric padding yet.

// Symmetric (left, right)
- PaddingConfig1d::Explicit(1)
+ PaddingConfig1d::Explicit(1, 1)
// Symmetric (top, left, bottom, right)
- PaddingConfig2d::Explicit(1, 1)
+ PaddingConfig2d::Explicit(1, 1, 1, 1)
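
Asymmetric padding can now be passed directly to a layer config. A sketch, assuming a device is in scope; the side ordering follows the comment above:

// Pad only the top and left of a 2D convolution input
// (ordering follows the comment above: top, left, bottom, right).
let conv: Conv2d<B> = Conv2dConfig::new([3, 16], [3, 3])
    .with_padding(PaddingConfig2d::Explicit(1, 1, 0, 0))
    .init(&device);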

The Gelu activation module can now be configured with tanh approximation. This only affects code that instantiated Gelu directly.

- let activation = Gelu;
+ let activation = Gelu::new(); // or Gelu::default()
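
For reference, the tanh variant approximates the exact, erf-based GELU, 0.5 * x * (1 + erf(x / sqrt(2))). A standalone sketch of the approximation, not Burn's implementation:

// Tanh approximation of GELU:
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_tanh(x: f32) -> f32 {
    let c = (2.0 / std::f32::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}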

The position-wise feed-forward module now has a configurable activation function. To keep it backwards compatible with previously saved records, the field is marked as #[module(skip)].

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    // ...
-   /// GELU activation function.
-   pub gelu: Gelu,
+   /// Activation function.
+   #[module(skip)]
+   pub activation: Activation<B>,
}

The Shape fields are now private and some methods have been renamed. ShapeError has been renamed to MetadataError.

- let b = tensor.shape().dims[0];
+ let b = tensor.shape()[0];

- if let Err(ShapeError::RankMismatch{...}) = lhs.broadcast(&rhs) {
+ if let Err(MetadataError::RankMismatch{...}) = lhs.broadcast(&rhs) {

- let shape = shape.swap(1, 2).unwrap();
+ let shape = shape.swapped(1, 2).unwrap();

- let shape = shape.permute(&[0, 2, 1, 3]).unwrap();
+ let shape = shape.permuted(&[0, 2, 1, 3]).unwrap();
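
With the fields private, dimension sizes are read through indexing or Tensor's dims accessor. A minimal sketch, assuming a rank-3 tensor:

// Read dimension sizes without touching Shape internals.
let shape = tensor.shape();
let b = shape[0];
// Or destructure all dimensions of a known-rank tensor at once.
let [batch, seq_len, d_model] = tensor.dims();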

The boolean data type was expanded to include its storage type.

match bool_tensor.dtype() {
-   DType::Bool => todo!(),
+   DType::Bool(BoolStore::Native) => todo!(),
+   DType::Bool(BoolStore::U8) => todo!(),
+   DType::Bool(BoolStore::U32) => todo!(),
    _ => unreachable!(),
}
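
The storage type mostly matters for memory footprint. A hedged helper sketch using the variants from the snippet above; Native's width is backend-defined, so it is reported as unknown:

// Hypothetical helper: bytes per boolean element for a given storage type.
fn bool_elem_bytes(dtype: DType) -> Option<usize> {
    match dtype {
        DType::Bool(BoolStore::U8) => Some(1),
        DType::Bool(BoolStore::U32) => Some(4),
        // Backend-defined width.
        DType::Bool(BoolStore::Native) => None,
        _ => None,
    }
}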

powf is no longer supported for Int tensors, as it previously relied on incorrect implicit truncation. These operations are now only available for Float tensors.

- let tensor_i = tensor_int.powf(tensor_float);
+ let tensor_f = tensor_int.float().powf(tensor_float);

- let tensor_i = tensor_int.powf_scalar(scalar_float);
+ let tensor_f = tensor_int.float().powf_scalar(scalar_float);
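
If an integer result is still needed, make the round trip explicit so the rounding choice is visible. A sketch:

// Explicit int -> float -> int round trip; round before converting back.
let tensor_i = tensor_int.float().powf_scalar(2.0).round().int();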

Backend tensor creation and conversion ops now take an explicit output dtype. This removes backend-specific dtype inference and ensures consistent behavior across backends. (Backend implementors only.)

impl BoolTensorOps<Self> for MyBackend {
-    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
+    fn bool_empty(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
        // use `dtype` instead of inferring internally
    }
-    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
+    fn bool_into_int(tensor: BoolTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
        // use `out_dtype` instead of inferring internally
    }
}

Associated types were moved from Backend to BackendTypes. Prefer the type aliases (Device<B>, FloatTensor<B>, etc.) to avoid type resolution issues.

impl BoolTensorOps<Self> for MyBackend {
-    fn bool_empty(shape: Shape, device: &<Self as Backend>::Device, dtype: BoolDType) -> <Self as Backend>::BoolTensorPrimitive {
+    fn bool_empty(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
    }
}

Module & Tensor

Datasets & Training

Backends

Bug Fixes

Documentation & Examples

Fixes

Enhancements

Refactoring

Miscellaneous

Full Changelog: v0.20.0...v0.21.0
