v0.11.0 - Pytrees, MutableArrays, and more!
This version of Flax introduces some changes to improve interop with native JAX and adds support for the new jax.experimental.MutableArray. More on this soon! Some breaking changes were necessary to align with the JAX way of doing things. Most code should remain intact, but the following changes deviate from the previous behavior:
- Rngs in standard layers: all standard layers no longer hold a shared reference to the rngs object given in the constructor, instead they now keep a fork-ed copy of the Rngs or RngStream objects. This impacts Using Rngs in NNX Transforms and Loading Checkpoints with RNGs.
- Optimizer Updates: the Optimizer abstraction no longer holds a reference to the model to avoid reference sharing, instead the model must be provided as the first argument to update.
- Modules as Pytrees: Modules are now pytrees! This avoids unnecessary use of split and merge when interacting trivially with raw JAX transforms (state must still be manually propagated if not using MutableArrays, and referential transparency is still an issue). This affects code that operates on Pytrees containing NNX Objects with jax.tree.* APIs.
Check out the full NNX 0.10 to NNX 0.11 migration guide.
In the near future we'll share more information about new ways of using NNX with JAX transforms directly by leveraging the new Pytree and MutableArray support. Stay tuned!
What's Changed
- [nnx] mutable array p3 by @cgarciae in #4755
- [nnx] allow method calls in ToLinen by @cgarciae in #4808
- Internal change by @copybara-service[bot] in #4807
- Preserve sharding information in axes_scan by @copybara-service[bot] in #4806
- Deduplicate contributing and philosophy and move to main site by @IvyZX in #4809
- Fixed nnx.remat docstring rendering by @vfdev-5 in #4790
- Added a note to gemma guide about model's license consent on kaggle by @vfdev-5 in #4776
- [nnx] ToLinen add abtract_init flag by @cgarciae in #4813
- Modify NNX to use id(variable) instead of nnx.Variables as dictionary by @divyashreepathihalli in #4814
- Allow using LazyRngs for flax init/apply. by @copybara-service[bot] in #4818
- [nnx] remove VariableState by @cgarciae in #4800
- Fix failing CI jobs: trailing whitespace, deprecated .type usage by @vfdev-5 in #4823
- [nnx] fix Rngs dtype check by @cgarciae in #4820
- refactor: move usages of .value to [...] in modules_test.py by @lukeyeh in #4815
- Added training script for Gemma model by @vfdev-5 in #4822
- [nnx] add flax_pytree_module flag by @cgarciae in #4811
- create ModelAndOptimizer symbol by @copybara-service[bot] in #4849
- [nnx] remove Optimizer.model attribute by @cgarciae in #4842
- [nnx] add mutable array support in update by @cgarciae in #4851
- Migrate transforms_test.py from .value to [...] by @lukeyeh in #4841
- 0.11.0 migration guide by @cgarciae in #4854
New Contributors
- @divyashreepathihalli made their first contribution in #4814
- @lukeyeh made their first contribution in #4815
Full Changelog: v0.10.7...v0.11.0