github google/flax v0.5.0
Version 0.5.0


New features:

  • Added flax.jax_utils.pad_shard_unpad() by @lucasb-eyer
  • Implemented the default dtype FLIP.
    The default dtype is now inferred from a layer's inputs and params instead of being hard-coded to float32.
    This matters most for complex numbers: the standard Modules no longer truncate complex inputs
    to their real component by default, and the complex dtype is preserved instead.
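
The helper added to flax.jax_utils targets the usual jax.pmap batching pattern: pad a batch to a multiple of the device count, shard it across devices, run the function per device, then merge the shards and trim the padding. The sketch below only imitates that pattern in NumPy so it runs on a single host. `pad_shard_unpad_sketch` and `num_devices` are hypothetical names, the per-device call is a plain Python loop rather than `pmap`, and the real helper's signature and argument handling are not reproduced here.

```python
import numpy as np


def pad_shard_unpad_sketch(fn, num_devices):
    """Hypothetical stand-in for the pad -> shard -> apply -> unshard -> unpad pattern.

    `num_devices` plays the role of the local device count; the per-device call
    is simulated by applying `fn` to each shard in a loop instead of using pmap.
    """

    def wrapped(batch):
        n = batch.shape[0]
        # Rows needed so the batch divides evenly across devices.
        pad = (-n) % num_devices
        filler = np.zeros((pad,) + batch.shape[1:], dtype=batch.dtype)
        padded = np.concatenate([batch, filler])
        # Shard: (num_devices, per_device_batch, ...).
        shards = padded.reshape((num_devices, -1) + batch.shape[1:])
        # Simulated per-device computation.
        outputs = np.stack([fn(shard) for shard in shards])
        # Un-shard back to one leading batch axis, then drop the padded rows.
        merged = outputs.reshape((-1,) + outputs.shape[2:])
        return merged[:n]

    return wrapped


square = pad_shard_unpad_sketch(lambda x: x ** 2, num_devices=4)
batch = np.arange(10, dtype=np.float32).reshape(10, 1)  # 10 is not divisible by 4
out = square(batch)
print(out.shape)  # (10, 1)
```

Because padding is removed after the per-device call, the caller sees outputs for exactly the rows it passed in, regardless of whether the batch size divides the device count.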

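The effect of inferring the dtype can be illustrated with a toy dense layer. This is a minimal sketch, assuming that "inferred from inputs and params" means ordinary dtype promotion over the inputs and params whenever no explicit `dtype` is given. `infer_dtype` and `dense` are hypothetical helpers, not Flax APIs, and NumPy's promotion rules differ from JAX's for some integer/float combinations.

```python
import warnings

import numpy as np


def infer_dtype(*arrays, dtype=None):
    """Hypothetical helper: an explicit dtype wins, otherwise promote all inputs/params."""
    if dtype is not None:
        return np.dtype(dtype)
    return np.result_type(*(a.dtype for a in arrays))


def dense(x, kernel, dtype=None):
    """Toy dense layer (hypothetical): cast inputs and params to one dtype, then matmul."""
    compute_dtype = infer_dtype(x, kernel, dtype=dtype)
    with warnings.catch_warnings():
        # Casting complex -> float emits a ComplexWarning and drops the imaginary part.
        warnings.simplefilter("ignore")
        return x.astype(compute_dtype) @ kernel.astype(compute_dtype)


x = np.array([[1 + 2j, 3 - 1j]], dtype=np.complex64)  # complex input
kernel = np.ones((2, 1), dtype=np.float32)            # float32 params

old = dense(x, kernel, dtype=np.float32)  # hard-coded float32: imaginary part lost
new = dense(x, kernel)                    # inferred dtype: complex64 preserved

print(old.dtype, new.dtype)
```

In the sketch, passing an explicit `dtype` still pins the computation dtype, so the old float32 behavior remains available when it is actually wanted.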
Bug fixes:

  • Fixed support for JAX's experimental_name_stack.

Breaking changes:

  • In rare cases, the dtype of a layer can change because of the default dtype FLIP. See the "Backward compatibility" section of the FLIP for details.
