github jax-ml/jax jax-v0.9.0
JAX v0.9.0

  • New features:

    • Added jax.thread_guard, a context manager that detects when devices
      are used by multiple threads in multi-controller JAX.
  • Bug fixes:

    • Fixed a workspace size calculation error for pivoted QR
      (magma_zgeqp3_gpu) in MAGMA 2.9.0 when using use_magma=True and
      pivoting=True (#34145).
  • Deprecations:

    • The flag jax_collectives_common_channel_id was removed.
    • The jax_pmap_no_rank_reduction config state has been removed. The
      no-rank-reduction behavior is now the only supported behavior: a
      function f mapped with jax.pmap sees inputs of the same rank as the
      input to jax.pmap(f). For example, if jax.pmap(f) receives shape
      (8, 128) on 8 devices, then f receives shape (1, 128).
    • Setting the jax_pmap_shmap_merge config state is deprecated in JAX v0.9.0
      and will be removed in JAX v0.10.0.
    • jax.numpy.fix is deprecated, anticipating the deprecation of
      numpy.fix in NumPy v2.5.0. jax.numpy.trunc is a drop-in
      replacement.
  • Changes:

    • jax.export now supports explicit sharding. This required a new export
      serialization format version that records the NamedSharding, including
      its abstract mesh and partition spec. As part of this change, we added
      a restriction on calling exported modules: the abstract mesh in use at
      call time must match the one used at export time, including the axis
      names. Previously, only the number of devices mattered.
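
Migrating off jax.numpy.fix is a rename, because both functions round toward zero. Here is a short check on a few positive and negative values:

```python
import jax.numpy as jnp

x = jnp.array([-2.7, -0.5, 0.5, 2.7], dtype=jnp.float32)

# Before (deprecated in JAX v0.9.0): y = jnp.fix(x)
y = jnp.trunc(x)  # rounds toward zero: -2.7 -> -2.0, 2.7 -> 2.0

print(y)
```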
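
The jax.export workflow affected by this change can be sketched as a single-device round trip: export a jitted function for an abstract input, serialize it, deserialize it, and call it. The sketch uses no mesh, so it does not exercise explicit sharding or the new mesh-matching check. It only shows the jax.export.export, Exported.serialize, jax.export.deserialize, and Exported.call steps. The function f is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def f(x):
    return jnp.sin(x) * 2.0

# Export for an abstract argument (shape and dtype only, no data).
exp = export.export(jax.jit(f))(jax.ShapeDtypeStruct((4,), jnp.float32))

# Serialize to a bytearray (e.g. to store on disk), then load it back.
blob = exp.serialize()
restored = export.deserialize(blob)

# The exported module records how many devices it was lowered for.
print(restored.nr_devices)

# Call the deserialized module on concrete data of the exported shape/dtype.
x = jnp.arange(4, dtype=jnp.float32)
y = restored.call(x)
```

If the export had been done under an explicit-sharding mesh, the restriction in the notes would require the caller to have an abstract mesh with the same axis names in scope, not merely the same device count.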
