What's Changed
- Following the JAX 0.9.2 release, the jax_pmap_shmap_merge config flag was removed, so the jax.pmap implementation is always based on jax.jit and jax.shard_map, and opting into the old jax.pmap behavior is no longer an option. Optax had opted into the old behavior to give users time to migrate; as of Optax 0.2.8 this is no longer supported. This change shouldn't impact most users, but if you experience errors or performance regressions as a result of it, you can update your code following JAX's migration guide (or use JAX 0.9.2 or earlier and set jax.config.update("jax_pmap_shmap_merge", False)).
- Explicitly specify the dtype of the gradient accumulator in the MultiStep transform. by @copybara-service[bot] in #1605
- feat: add preconditioning and coef presets to muon by @massena-t in #1602
- Backwards compatibility export for the newton schulz iterator by @copybara-service[bot] in #1608
- Remove TensorFlow dependency in Adversarial training example by @rajasekharporeddy in #1609
- Improve lookahead docstrings with example and usage notes by @rdyro in #1619
- Make sure inject_hyperparams uses the dtype inferred from parameters... by @copybara-service[bot] in #1615
- Memory-optimization for microbatching. by @copybara-service[bot] in #1623
- Remove TensorFlow dependency and migrate mlp_mnist to Flax NNX by @selamw1 in #1536
- Let inject use the highest dtype found in the params as the default dtype of params. by @copybara-service[bot] in #1628
- Support scheduling alpha for AdEMAmix by @copybara-service[bot] in #1630
- [JAX] Suppress type errors found by pytype after correcting definition of jax.typing.ArrayLike. by @copybara-service[bot] in #1629
- [JAX] Suppress type errors found by pytype after correcting definition of jax.typing.ArrayLike. by @copybara-service[bot] in #1633
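For users affected by the jax.pmap change, the migration usually means replacing a jax.pmap-wrapped step with a jax.jit-compiled jax.shard_map over a device mesh. The snippet below is a hypothetical sketch, not Optax code and not copied from JAX's migration guide. It assumes a recent JAX that exposes jax.shard_map at the top level (falling back to jax.experimental.shard_map on older versions), and the names local_loss, sharded_loss and data_parallel_grad are made up for illustration. It computes a data-parallel gradient of a least-squares loss with the batch split along a "batch" mesh axis.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    shard_map = jax.shard_map  # top-level API in recent JAX releases
except AttributeError:
    # Assumption: older JAX versions only ship the experimental module.
    from jax.experimental.shard_map import shard_map

# 1-D mesh over all available devices with a single "batch" axis.
mesh = Mesh(np.array(jax.devices()), ("batch",))


def local_loss(w, x, y):
    # Hypothetical least-squares loss on this device's slice of the batch,
    # averaged over the "batch" mesh axis to give the global mean loss.
    return jax.lax.pmean(jnp.mean((x @ w - y) ** 2), axis_name="batch")


# A typical jax.pmap version reshaped inputs to (n_devices, per_device, ...),
# computed per-device gradients and averaged them with lax.pmean. Here the
# inputs stay global arrays and in_specs says how to split them.
sharded_loss = shard_map(
    local_loss,
    mesh=mesh,
    in_specs=(P(), P("batch"), P("batch")),  # replicate w, split the batch
    out_specs=P(),  # the averaged loss is the same on every device
)

# Differentiate the sharded loss from outside and compile it with jax.jit.
data_parallel_grad = jax.jit(jax.grad(sharded_loss))

# Usage: the global batch size must be divisible by the device count.
n = 4 * jax.device_count()
w = jnp.ones((3,))
x = jnp.arange(3.0 * n).reshape(n, 3)
y = jnp.ones((n,))
g = data_parallel_grad(w, x, y)
```

Taking jax.grad outside the shard_map, rather than averaging gradients by hand inside it, lets JAX derive the cross-device reduction from the pmean in the loss; with equal-sized shards the result should match the gradient of the full-batch mean loss computed on a single device.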
New Contributors
- @massena-t made their first contribution in #1602
- @selamw1 made their first contribution in #1536
Full Changelog: v0.2.7...v0.2.8