github google/flax v0.11.2


What's Changed

nnx.merge no longer copies the Variables in the incoming states by default, so the newly merged structure holds references to the incoming Variables. This enables new patterns. For example, it's now possible to create models that share the same state but have different runtime behavior:

model = SomeModel(...)
# create eval model
eval_model = nnx.merge(*nnx.split(model))  # same Variables, different structure
eval_model.eval()

model and eval_model share the same Variables and are therefore kept in sync, but they have different runtime behavior. This avoids having to repeatedly switch a single model back and forth between runtime modes, which can be error-prone and can cause unwanted recompilation.

To keep the old behavior, use nnx.merge(..., copy=True).

PRs

  • add Rngs random helpers by @cgarciae in #4876
  • Fix re-export and docs for identity by @jlperla in #4850
  • Fix ToLinen docstring return description by @mohsinm-dev in #4852
  • Update doc build instructions and clean up unused packages by @IvyZX in #4885
  • Improve docs related with dataclasses by @IvyZX in #4884
  • Fix broken contributing documentation link by @mohsinm-dev in #4855
  • Internal change by @copybara-service[bot] in #4886
  • Fix string key preservation in replace_by_pure_dict by @mohsinm-dev in #4860
  • Remove the need for Conv and ConvTranspose to know the precise batch size. by @copybara-service[bot] in #4877
  • call jax's source_info_util.register_exclusion in flax's traceback_util.register_exclusion by @copybara-service[bot] in #4887
  • Update typo in nnx.Optimizer by @codinfox in #4880
  • Exposed split_rngs docstring in the docs_nnx by @vfdev-5 in #4846
  • Pin sentencepiece version to 0.2.0 to fix head by @IvyZX in #4892
  • Relax duplicate check to exclude non-string values such as PartitionSpec.UNCONSTRAINED, since those can be repeated. by @copybara-service[bot] in #4881
  • add find_duplicates by @cgarciae in #4894
  • Sharding API improvements (non breaking) by @IvyZX in #4893
  • document jax.random shorthand methods by @cgarciae in #4899
  • Optimiser was already instantiated using the model - 05_vae.py by @nenuadrian in #4857
  • revert is_leaf logic in _check_carry_same_references by @copybara-service[bot] in #4903
  • Doc fix: remove outdated advice on flax v0.6.10; it was released two years ago. by @copybara-service[bot] in #4910
  • Fix bug when raising ScopeParamNotFoundError. by @copybara-service[bot] in #4898
  • fix mypy on main by @cgarciae in #4909
  • merge no copy Variables by @cgarciae in #4912
  • update version to 0.11.2 by @copybara-service[bot] in #4915

New Contributors

Full Changelog: v0.11.1...v0.11.2
