Deprecations
Variable.value
Variable.value is now deprecated. Consider the following example:
import jax.numpy as jnp
import jax
from flax import nnx
my_param = nnx.Param({'a': 0.0})
@nnx.jit
def f(m):
  m.value['a'] = 1.0
  return m

Running f(my_param) produces Param(value={'a': 0.0}), not Param(value={'a': 1.0}) as before. This is because reading the value attribute now returns a copy of pytree values (such as dict or list), so the in-place update only modifies that copy. Instead, use the __setitem__ method to update the value:
@nnx.jit
def f(m):
  m['a'] = 1.0
  return m

nnx.Data and nnx.Static
The nnx.Data and nnx.Static annotations are now deprecated. To create nnx.Pytree or nnx.Module dataclasses, use the new nnx.dataclass decorator with nnx.data and nnx.static as field descriptors.
# old
@dataclasses.dataclass
class Foo(nnx.Pytree):
  a: nnx.Data[int]
  b: nnx.Static[str]

# new
@nnx.dataclass
class Foo(nnx.Pytree):
  a: int = nnx.data()
  b: str = nnx.static()

Pull Requests
- Clarify *Norm layer docstrings: axis_index_groups is unused under SPMD jit. by @copybara-service[bot] in #4940
- Move ArrayRef creation to the end of Variable creation by @IvyZX in #4980
- clean up jax.Ref-related names by @copybara-service[bot] in #4988
- Add compute_flops and compute_vjp_flops options to nnx.tabulate by @samanklesaria in #4948
- Fix nnx.tabulate crash with empty dict/None values (fixes #4889) by @mohsinm-dev in #4891
- Future-proof imports of jax.new_ref / jax.Ref. by @copybara-service[bot] in #4986
- Use jnp.stack instead of np.stack in flax.training.common_utils.stack_forest by @vfdev-5 in #4991
- Fixed broken nnx.statelib.diff by @vfdev-5 in #4992
- Implemented spectral norm in NNX by @mattbahr in #4623
- Improve Variable.{get,set}_metadata by @cgarciae in #4985
- Move iter_children and iter_modules to functions by @samanklesaria in #4961
- Avoid install, import, or tests with tensorflow-text under Python 3.13+. by @jburnim in #5001
- disallow setting metadata through setattr by @cgarciae in #4993
- Use sphinx 6.2+ for docs, which works with Python 3.13. by @jburnim in #5009
- Removed kernel_init/bias_init attributes from popular layers by @vfdev-5 in #4998
- Migrate from jax.experimental.enable_x64 to jax.enable_x64. by @copybara-service[bot] in #5011
- Add Rngs KeylessInitializers by @cgarciae in #5017
- optimize scan transpositions by @cgarciae in #5015
- Variable refactor by @cgarciae in #5006
- Remove invalid gymnasium dependency in pyproject.toml by @IvyZX in #5016
- Use jax.shard_map in flax by @copybara-service[bot] in #5020
- use jax.shard_map by @copybara-service[bot] in #5018
- Fix formatting in PR template checklist by @rapsealk in #5024
- Fixed attribute visualization in treescope_repr by @vfdev-5 in #5022
- feat: add nnx.set_metadata to in-place change metadata of the state variables of nnx.Modules by @pfackeldey in #5007
- Update README to use fully qualified nnx.Linear in example by @rapsealk in #5023
- Fix nnx tabulate variable hooks by @mohsinm-dev in #5008
- python 3.13 support by @cgarciae in #4987
- Added a note in nnx.jit about arg donation by @vfdev-5 in #5031
- Add flip doc link to eager sharding error message by @IvyZX in #5033
- fix reseed for abstract values by @cgarciae in #5034
- Deduplicate Variable nodes in iter_graph and eliminate recursion. by @copybara-service[bot] in #5035
- Support for python 3.14 by @vfdev-5 in #5032
- [docs] Exposed more helper functions/classes in state.rst by @vfdev-5 in #5037
- Copybara import of the project: by @copybara-service[bot] in #5041
- Internal change by @copybara-service[bot] in #5048
- filter grad state in nnx.Optimizer by @copybara-service[bot] in #5049
- Add NNX WeightNorm (update of #4568) by @samanklesaria in #5043
- Fix shard_map documentation link in compilation.py by @vfdev-5 in #5038
- Fix ValueError when nnx.jit is used with nnx.custom_vjp by @samanklesaria in #5045
- Recursive map by @chapman20j in #5042
- Convert linen pytorch guide to nnx by @samanklesaria in #4999
- Set Mode with Tests by @chapman20j in #5056
- Fixing Optimizer docstring - fixing #5060 by @Lucas-Fernandes-Martins in #5061
- Update tutorial examples to thread explicit RNGs by @samanklesaria in #4975
- Fix NNX jit static args with in_shardings issue #4989 by @mohsinm-dev in #4996
- support explicit sharding in eager sharding by @cgarciae in #5070
- Added missing LayerNorm test case into TestLayersSameGraph by @vfdev-5 in #5076
- fix main by @cgarciae in #5081
- docs: Document allow_duplicates argument of nnx.to_arrays. by @dan-zheng in #5083
- add promote_dtype to all standard layers by @cgarciae in #5080
- add nnx.dataclass by @cgarciae in #5066
- Expand ConvTranspose padding documentation by @samanklesaria in #4990
- Added kernel_metadata/bias_metadata args to nnx layers by @vfdev-5 in #5074
- Add nnx.use_eager_sharding context manager by @samanklesaria in #5079
- fix main by @cgarciae in #5090
- Adding set_mode_info by @chapman20j in #5071
- Fixed nnx.scan with carry as pytree and sow by @vfdev-5 in #5073
- Fix bound method auto-unbinding for NNX transforms by @mohsinm-dev in #5055
- deprecate Variable.value by @cgarciae in #5052
- Add eq for variables by @samanklesaria in #5084
- Fixed deprecated .value usage failing CI tests by @vfdev-5 in #5097
- update jax minver to 0.8.1 by @cgarciae in #5095
New Contributors
- @samanklesaria made their first contribution in #4948
- @jburnim made their first contribution in #5001
- @rapsealk made their first contribution in #5024
- @pfackeldey made their first contribution in #5007
- @chapman20j made their first contribution in #5042
- @Lucas-Fernandes-Martins made their first contribution in #5061
Full Changelog: v0.12.0...v0.12.1