What's Changed
- Make `Sequential()` be identity by @SobhanMP in #4796
- Add a JAX/Flax key concepts doc by @IvyZX in #4795
- miscellaneous improvements by @cgarciae in #4859
- Replace `jax.sharding.use_mesh` with `jax.set_mesh`. `jax.set_mesh` can act as a global setter or a context manager. by @copybara-service[bot] in #4862
- Pytree and ArrayRef refactor by @cgarciae in #4863
- Add old property attributes for object->pytree rename. by @copybara-service[bot] in #4864
- Add BatchNorm layers to CNN in MNIST tutorial for improved training stability by @sanepunk in #4773
- Description by @copybara-service[bot] in #4866
- update and pop for dict by @cgarciae in #4869
- simplify nnx_basics by @cgarciae in #4868
- updates to version 0.11.1 by @cgarciae in #4878
New Contributors
Full Changelog: v0.11.0...v0.11.1