Possibly breaking changes:
- When calling `init` the 'intermediates' collection is no longer mutable. Therefore, intermediates will no longer be returned from initialization by default (see the sketch after this list).
- Don't update batch statistics during initialization.
- When not using any non-determinism (e.g., dropout), it is no longer necessary to specify the `deterministic` argument in `MultiHeadDotProductAttention` (see the sketch after this list).
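
A minimal sketch of the new `init` behavior (the toy module, layer sizes, and variable names below are illustrative, not part of the release): values recorded with `self.sow` into the 'intermediates' collection are no longer returned by `init`; they can still be collected explicitly by marking the collection as mutable in `apply` (or via the new `mutable` argument on `init`, noted under "Other changes").

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    h = nn.Dense(4)(x)
    # Recorded only when the 'intermediates' collection is mutable.
    self.sow('intermediates', 'hidden', h)
    return nn.Dense(1)(h)

model = MLP()
x = jnp.ones((2, 3))

# 'intermediates' is no longer mutable during init, so the returned
# variables contain only 'params'.
variables = model.init(jax.random.PRNGKey(0), x)

# To collect intermediates, mark the collection as mutable explicitly.
y, state = model.apply(variables, x, mutable=['intermediates'])
hidden = state['intermediates']['hidden']  # a tuple, one entry per sow() call
```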
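And a sketch of the relaxed `deterministic` requirement (the layer sizes and the two-positional-argument call are assumptions for illustration): when the attention module is configured without dropout, it can be initialized and applied without passing `deterministic`.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=4)  # no dropout configured
x = jnp.ones((2, 5, 16))  # [batch, length, features]

variables = attn.init(jax.random.PRNGKey(0), x, x)
y = attn.apply(variables, x, x)  # no `deterministic=` needed without dropout
```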
Other changes:
- Rewrote various examples to use Optax instead of Flax optimizers (e.g., Imagenet, SST2).
- Added an NLP text classification example (on the SST-2 dataset) to `examples/sst2` that uses a bidirectional LSTM (BiLSTM) to encode the input text.
- Added `flax.training.train_state` to simplify using Optax optimizers (see the sketch after this list).
- `mutable` argument is now available on `Module.init` and `Module.init_with_outputs`.
- Bug fix: Correctly handle non-default parameters of Linen Modules with nested inheritance.
- Expose `dot_product_attention_weights`, allowing access to attention weights (see the sketch after this list).
- `BatchNorm` instances will behave correctly during init when called multiple times.
- Added a more extensive "how to contribute" guide in `contributing.md`.
- Add proper cache behavior for `lift.jit`, fixing cache misses.
- Fix bug in Embed layer: make sure it behaves correctly when embedding is `np.array`.
- Fix `linen.Module` for deep inheritance chains.
- Fix bug in DenseGeneral: correctly expand bias to account for batch & non-contracting dimensions.
- Allow Flax lifted transforms to work on partially applied Modules.
- Make `MultiOptimizer` use `apply_gradient` instead of `apply_param_gradient`.
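
A sketch of the `flax.training.train_state` addition mentioned above (the toy model, learning rate, and loss are made up for illustration): `TrainState` bundles the apply function, the parameters, and an Optax optimizer, and `apply_gradients` updates both the parameters and the optimizer state in one step.

```python
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

model = nn.Dense(features=1)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))['params']

# Bundle apply_fn, params, and the Optax optimizer into one object.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.sgd(learning_rate=0.1),
)

def loss_fn(params, x, y):
  pred = state.apply_fn({'params': params}, x)
  return jnp.mean((pred - y) ** 2)

x, y = jnp.ones((4, 3)), jnp.zeros((4, 1))
grads = jax.grad(loss_fn)(state.params, x, y)
state = state.apply_gradients(grads=grads)  # updates params and optimizer state
```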
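And a sketch of the newly exposed `dot_product_attention_weights` (importing it from `flax.linen`, the array shapes, and the `[batch, length, heads, head_dim]` layout are assumptions based on the rest of `flax.linen.attention`): it returns the softmax-normalized attention weights rather than the attended values.

```python
import jax
from flax.linen import dot_product_attention_weights

q_rng, k_rng = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(q_rng, (2, 5, 4, 8))  # [batch, q_length, heads, head_dim]
key = jax.random.normal(k_rng, (2, 7, 4, 8))    # [batch, kv_length, heads, head_dim]

weights = dot_product_attention_weights(query, key)
print(weights.shape)  # assumed layout: [batch, heads, q_length, kv_length]
```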