What's Changed
- Update Optax version to 0.2.7.dev. by @copybara-service[bot] in #1420
- Fix doctest by @copybara-service[bot] in #1424
- Fix piecewise interpolate with 0 first split by @rdyro in #1425
- Expose weight decay as a schedule option in all alias optimizers. by @copybara-service[bot] in #1427
- Skip schedule free tests casting from complex to float. by @copybara-service[bot] in #1428
- Add `monitor` and `measure_with_ema` to Optax transformations. by @copybara-service[bot] in #1430
- Add consistent rms scaling for muon update by @shuningjin in #1435
- Add microbatch transformation to optax/experimental. by @copybara-service[bot] in #1434
- Add experimental aggregators in optax. by @copybara-service[bot] in #1436
- Add missing exports to `optax/__init__.py`. by @carlosgmartin in #1433
- Adding microbatch to the docs by @copybara-service[bot] in #1443
- Add a few sharding-related tests to optax. by @copybara-service[bot] in #1450
- Allow instantiating optimizers before jax has been initialized. by @copybara-service[bot] in #1454
- Adding gradient variance tracking. by @copybara-service[bot] in #1451
- Resolve remaining sharding test failures in optax. by @copybara-service[bot] in #1457
- Add optional "in_axes" and "argnames" kwargs to microbatch, which will allow for natural composition with jax.vmap. by @copybara-service[bot] in #1442
- Honor inject_hyperparams dtypes if passed as jax.Arrays. by @copybara-service[bot] in #1460
- Update reshape_batch_axis to include sharding information when in explicit sharding mode. by @copybara-service[bot] in #1459
- Fix typo in L-BFGS debug information section by @partev in #1465
- Deprecate second order utilities. by @copybara-service[bot] in #1466
- Deprecate optax.global_norm in favor of optax.tree.norm. by @carlosgmartin in #1368
- Add signum optimizer by @copybara-service[bot] in #1463
- Use internal `warn_deprecated_function` instead of the Chex version. by @copybara-service[bot] in #1469
- Define internal `assert_trees_all_{close, equal}` functions and use them in tests instead of the Chex versions. by @copybara-service[bot] in #1481
- Replace (non-test) Chex assertions, usually with ValueErrors. by @copybara-service[bot] in #1483
- Replace `chex.Numeric` with `jax.typing.ArrayLike`. Replace `chex.Scalar` with `float`/`int` as appropriate. by @copybara-service[bot] in #1479
- Update readme by @copybara-service[bot] in #1496
- Enforce keyword-only arguments for optional parameters in Optax losses. Disable or fix existing Pytype bugs that surfaced as a result of this change. by @copybara-service[bot] in #1505
- Remove use of cast_tree in favor of optax.tree.cast by @copybara-service[bot] in #1477
- Minor doc fixes: Include hyperlinks to functions, classes references and correct typos by @rajasekharporeddy in #1511
- [optax] Remove `jax_pmap_shmap_merge=False` flag in `contrib/_complex_valued`. #jax-fixit by @copybara-service[bot] in #1515
- Enforce keyword-only arguments for optional parameters in Optax losses. Disable or fix existing Pytype bugs that surfaced as a result of this change. by @copybara-service[bot] in #1516
- Fix argument order in scale_by_distance_over_gradients by @Aaryan-549 in #1501
- Fix momo crash when loss value is a Python float by @Aaryan-549 in #1502
- Clarify difference between kl_divergence and convex_kl_divergence by @zer-art in #1514
- Fix broken test in numerics_test. by @copybara-service[bot] in #1522
- Add guidelines on contributing AI generated code, the same as JAX's. by @copybara-service[bot] in #1526
- Move _microbatching.py to microbatching.py and update `__init__.py` so we can directly import microbatching without causing pytype errors. by @copybara-service[bot] in #1525
- Fix up microbatching documentation. by @copybara-service[bot] in #1527
- Fix conjugation in Newton-Schulz iterator and update tests for comple… by @eirikfagerbakke in #1506
- Fix microbatching.Accumulator in docs. by @copybara-service[bot] in #1529
- Upgrade GitHub Actions to latest versions by @salmanmkc in #1534
- Add axis and where arguments to smooth_labels function. by @carlosgmartin in #1492
- Fix _projections_test.py: Remove prints and use optax.tree.allclose. by @carlosgmartin in #1439
- updating github actions versions by @copybara-service[bot] in #1540
- Fix typo in alias.py documentation by @partev in #1464
- Added Refined Lion Optimizer and Tests by @raghulchandramouli in #1497
- Add a bias_correction_v flag to scale_by_amsgrad to align with the original AMSGrad paper and Pytorch/tensorflow impl by @vvsvictor in #1423
- fix failing CI by @copybara-service[bot] in #1543
- Optimize tree_sum compile time using tree_reduce_associative by @Aaryan-549 in #1503
- Tests to warmup_cosine_decay_schedule edge cases by @edawite in #1413
- resolve DeprecationWarnings by @rdyro in #1547
- state dtype consistency in multi-step by @copybara-service[bot] in #1554
- Fix optax CI after merging galore. by @copybara-service[bot] in #1556
- Replace unicode escaped characters in ipynb files by @copybara-service[bot] in #1557
- Add an internal definition of `ArrayTree` and use it instead of `chex.ArrayTree`. by @copybara-service[bot] in #1484
- Optional auxiliary learning rate for Adam within Muon by @RaphaelRe in #1565
- Upstream num_real_microbatches to micro_vmap and micro_grad and add unit tests. by @copybara-service[bot] in #1570
- Allow early stopping when num_real_microbatches is dynamic (Tracer). by @copybara-service[bot] in #1571
- Fix `multiply_by_parameter_scale` type in adafactor optimizer. by @copybara-service[bot] in #1504
- Generalize dice_loss with alpha/beta weighting by @aymuos15 in #1458
- adding additional non 2d array test in galore (reference to #1541) by @yash194 in #1574
- optax fix CI by @copybara-service[bot] in #1579
- Raise error on unused extra kwargs in backtracking linesearch by @TanmayThakur2209 in #1559
- Add madgrad optimizer by @divye-joshi in #1581
- Fix optax CI by @copybara-service[bot] in #1586
- Remove TF dependency in `Lookahead Optimizer on MNIST` example by @rajasekharporeddy in #1568
- Add example demonstrating the microbatching api and comparing it against optax.MultiSteps. by @copybara-service[bot] in #1573
- Add an internal definition of `ArrayTree` and use it instead of `chex.ArrayTree`. by @copybara-service[bot] in #1588
- clipping: support default unitwise_norm for 5D params by @staticpayload in #1576
- Fix CI by @copybara-service[bot] in #1592
- Finish removing Chex dependency. by @copybara-service[bot] in #1590
- attempt at fixing docs building by @copybara-service[bot] in #1593
- Release optax 0.2.7 by @copybara-service[bot] in #1596
New Contributors
- @shuningjin made their first contribution in #1435
- @partev made their first contribution in #1465
- @Aaryan-549 made their first contribution in #1501
- @zer-art made their first contribution in #1514
- @eirikfagerbakke made their first contribution in #1506
- @salmanmkc made their first contribution in #1534
- @raghulchandramouli made their first contribution in #1497
- @vvsvictor made their first contribution in #1423
- @edawite made their first contribution in #1413
- @RaphaelRe made their first contribution in #1565
- @yash194 made their first contribution in #1574
- @TanmayThakur2209 made their first contribution in #1559
- @divye-joshi made their first contribution in #1581
- @staticpayload made their first contribution in #1576
Full Changelog: v0.2.6...v0.2.7