github google/flax v0.11.0


v0.11.0 - Pytrees, MutableArrays, and more!

This version of Flax improves interoperability with native JAX and adds support for the new jax.experimental.MutableArray (more on this soon). Aligning with the JAX way of doing things required some breaking changes. Most code should keep working, but the following changes deviate from the previous behavior:

  • Rngs in standard layers: standard layers no longer hold a shared reference to the Rngs object passed to their constructor. Instead, each layer keeps a forked copy of the Rngs or RngStream objects. This affects the migration guide topics "Using Rngs in NNX Transforms" and "Loading Checkpoints with RNGs".
  • Optimizer Updates: the Optimizer abstraction no longer holds a reference to the model, to avoid reference sharing. Instead, the model must be passed as the first argument to update.
  • Modules as Pytrees: Modules are now pytrees! This avoids unnecessary use of split and merge when interacting trivially with raw JAX transforms (state must still be propagated manually if not using MutableArrays, and referential transparency is still an issue). This affects code that operates on pytrees containing NNX Objects with the jax.tree.* APIs.
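The difference between a shared reference and a forked copy can be illustrated with a toy RNG stream in plain JAX. This is a minimal sketch, not Flax's implementation: `KeyStream`, its `fork` method, and `Layer` are hypothetical stand-ins for Rngs/RngStream and a standard layer. With a shared stream, the key one layer receives depends on how many draws other layers made first; with forked copies it does not.

```python
import jax
import jax.numpy as jnp


class KeyStream:
    """Hypothetical RNG stream (not Flax's RngStream): a base key plus a draw counter."""

    def __init__(self, key):
        self.key = key
        self.count = 0

    def __call__(self):
        # Each draw folds the current count into the base key, then advances the counter.
        new_key = jax.random.fold_in(self.key, self.count)
        self.count += 1
        return new_key

    def fork(self):
        # An independent stream seeded from a single draw of this stream.
        return KeyStream(self())


class Layer:
    """Hypothetical layer that draws keys from the stream it was given."""

    def __init__(self, stream):
        self.stream = stream

    def next_key(self):
        return self.stream()


def second_layer_key(fork, call_first_layer):
    root = KeyStream(jax.random.PRNGKey(0))
    if fork:
        # Each layer gets its own forked copy of the stream.
        a, b = Layer(root.fork()), Layer(root.fork())
    else:
        # Both layers hold the same stream object.
        a, b = Layer(root), Layer(root)
    if call_first_layer:
        a.next_key()
    return b.next_key()


# Shared stream: the second layer's key changes if the first layer draws before it.
shared_changes = not jnp.array_equal(
    second_layer_key(fork=False, call_first_layer=False),
    second_layer_key(fork=False, call_first_layer=True),
)
# Forked streams: the second layer's key is the same either way.
forked_stable = jnp.array_equal(
    second_layer_key(fork=True, call_first_layer=False),
    second_layer_key(fork=True, call_first_layer=True),
)
print(shared_changes, forked_stable)
```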
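Because a pytree module can be passed straight to `jax.grad` and `jax.tree_util.tree_map`, an optimizer that receives the model on every call needs no stored model reference. The sketch below is plain JAX under stated assumptions: `Linear` and `SGD` are hypothetical classes, not Flax APIs, and unlike NNX's `Optimizer` (which updates state in place) this `update` returns a new model.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Linear:
    """Hypothetical module registered as a pytree: its arrays are the leaves."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return x @ self.w + self.b

    def tree_flatten(self):
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


class SGD:
    """Hypothetical optimizer that keeps no reference to the model."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def update(self, model, grads):
        # The model is passed in explicitly as the first argument; a new model is returned.
        return tree_util.tree_map(lambda p, g: p - self.learning_rate * g, model, grads)


def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)


model = Linear(jnp.zeros((2, 1)), jnp.zeros((1,)))
optimizer = SGD(learning_rate=0.01)
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[1.0], [2.0]])

initial_loss = loss_fn(model, x, y)
for _ in range(10):
    # jax.grad accepts the module directly because it is a pytree.
    grads = jax.grad(loss_fn)(model, x, y)
    model = optimizer.update(model, grads)
print(initial_loss, loss_fn(model, x, y))
```

Passing the model explicitly keeps the optimizer free of references into the model, which is the reference-sharing issue the Optimizer change addresses.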

Check out the full NNX 0.10 to NNX 0.11 migration guide.

In the near future we'll share more information about new ways of using NNX with JAX transforms directly by leveraging the new Pytree and MutableArray support. Stay tuned!
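The caveat that pytree modules still need manual state propagation when MutableArrays are not used follows from how JAX transforms treat pytrees: a jitted function receives a copy rebuilt from its inputs' leaves, so in-place attribute changes are lost unless the module is returned. This minimal plain-JAX sketch uses a hypothetical `Counter` class (not a Flax API, and not using MutableArray) to show both cases.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Counter:
    """Hypothetical stateful module registered as a pytree."""

    def __init__(self, count):
        self.count = count

    def tree_flatten(self):
        return (self.count,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def step_discarding_state(counter):
    # Inside jit, `counter` is a fresh object rebuilt from traced leaves,
    # so this mutation never reaches the caller's object.
    counter.count = counter.count + 1
    return None


@jax.jit
def step_returning_state(counter):
    counter.count = counter.count + 1
    # Returning the module propagates the new state out of the transform.
    return counter


counter = Counter(jnp.array(0))
step_discarding_state(counter)
unchanged = int(counter.count)

counter = step_returning_state(counter)
updated = int(counter.count)
print(unchanged, updated)  # prints: 0 1
```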

What's Changed

  • [nnx] mutable array p3 by @cgarciae in #4755
  • [nnx] allow method calls in ToLinen by @cgarciae in #4808
  • Internal change by @copybara-service[bot] in #4807
  • Preserve sharding information in axes_scan by @copybara-service[bot] in #4806
  • Deduplicate contributing and philosophy and move to main site by @IvyZX in #4809
  • Fixed nnx.remat docstring rendering by @vfdev-5 in #4790
  • Added a note to gemma guide about model's license consent on kaggle by @vfdev-5 in #4776
  • [nnx] ToLinen add abstract_init flag by @cgarciae in #4813
  • Modify NNX to use id(variable) instead of nnx.Variables as dictionary by @divyashreepathihalli in #4814
  • Allow using LazyRngs for flax init/apply. by @copybara-service[bot] in #4818
  • [nnx] remove VariableState by @cgarciae in #4800
  • Fix failing CI jobs: trailing whitespace, deprecated .type usage by @vfdev-5 in #4823
  • [nnx] fix Rngs dtype check by @cgarciae in #4820
  • refactor: move usages of .value to [...] in modules_test.py by @lukeyeh in #4815
  • Added training script for Gemma model by @vfdev-5 in #4822
  • [nnx] add flax_pytree_module flag by @cgarciae in #4811
  • create ModelAndOptimizer symbol by @copybara-service[bot] in #4849
  • [nnx] remove Optimizer.model attribute by @cgarciae in #4842
  • [nnx] add mutable array support in update by @cgarciae in #4851
  • Migrate transforms_test.py from .value to [...] by @lukeyeh in #4841
  • 0.11.0 migration guide by @cgarciae in #4854


Full Changelog: v0.10.7...v0.11.0
