New features:
- Added jax.thread_guard, a context manager that detects when devices
are used by multiple threads in multi-controller JAX.
Bug fixes:
- Fixed a workspace size calculation error for pivoted QR
(magma_zgeqp3_gpu) in MAGMA 2.9.0 when using use_magma=True and
pivoting=True (#34145).
Deprecations:
- The flag jax_collectives_common_channel_id was removed.
- The jax_pmap_no_rank_reduction config state has been removed. The
no-rank-reduction behavior is now the only supported behavior: a
function f mapped with jax.pmap sees inputs of the same rank as the
input to jax.pmap(f). For example, if jax.pmap(f) receives shape
(8, 128) on 8 devices, then f receives shape (1, 128).
- Setting the jax_pmap_shmap_merge config state is deprecated in JAX
v0.9.0 and will be removed in JAX v0.10.0.
- jax.numpy.fix is deprecated, anticipating the deprecation of
numpy.fix in NumPy v2.5.0. jax.numpy.trunc is a drop-in replacement.
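The shape rule in the jax_pmap_no_rank_reduction entry is plain shape arithmetic, so it can be checked without running JAX at all. This is a sketch only: pmap_block_shape and legacy_block_shape are hypothetical helpers, not JAX functions, and they assume, as in the example above, that the mapped leading axis has exactly one entry per device.

```python
from typing import Tuple


def _check(global_shape: Tuple[int, ...], num_devices: int) -> None:
    # Simplifying assumption of this sketch: the leading (mapped) axis
    # holds one slice per device.
    if not global_shape or global_shape[0] != num_devices:
        raise ValueError("leading axis must equal the number of devices")


def pmap_block_shape(global_shape: Tuple[int, ...], num_devices: int) -> Tuple[int, ...]:
    """Hypothetical helper: per-device shape under no-rank-reduction.

    The mapped axis is kept with size 1, so f sees the same rank as the
    global input.
    """
    _check(global_shape, num_devices)
    return (1,) + tuple(global_shape[1:])


def legacy_block_shape(global_shape: Tuple[int, ...], num_devices: int) -> Tuple[int, ...]:
    """Hypothetical helper: the old rank-reducing behavior, which dropped the mapped axis."""
    _check(global_shape, num_devices)
    return tuple(global_shape[1:])


print(pmap_block_shape((8, 128), 8))    # (1, 128)
print(legacy_block_shape((8, 128), 8))  # (128,)
```

Code inside f that relied on the old rank (for example, indexing the first axis as if it were the feature axis) would need to account for the extra leading axis of size 1.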
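A minimal migration sketch for the jax.numpy.fix deprecation, assuming a JAX installation where jax.numpy.trunc is available. Both functions round toward zero, so the call can be swapped one-for-one; the input values here are arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([-2.7, -0.5, 0.5, 2.7])

# Before (deprecated as of JAX v0.9.0): y = jnp.fix(x)
# jnp.trunc also rounds toward zero, so it replaces jnp.fix directly.
y = jnp.trunc(x)

# jnp.floor is NOT a substitute: it rounds negative inputs down, e.g. -2.7 -> -3.0.
z = jnp.floor(x)
print(y, z)
```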
Changes:
- jax.export now supports explicit sharding. This required a new
export serialization format version that records the NamedSharding,
including its abstract mesh and partition spec. As part of this change
we have added a restriction on the use of exported modules: when
calling them, the abstract mesh must match the one used at export
time, including the axis names. Previously, only the number of devices
mattered.
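The new calling restriction can be modeled as a comparison between the export-time and call-time meshes. This is a toy model, not JAX's API: MeshSpec, compatible_before and compatible_now are hypothetical names, a mesh is reduced to ordered (axis name, axis size) pairs, and JAX's own abstract mesh may carry more information than this sketch tracks.

```python
from math import prod
from typing import Tuple

# Toy model (hypothetical): an abstract mesh as ordered (axis_name, axis_size) pairs.
MeshSpec = Tuple[Tuple[str, int], ...]


def compatible_before(exported: MeshSpec, calling: MeshSpec) -> bool:
    """Old rule: only the total number of devices had to agree."""
    return prod(size for _, size in exported) == prod(size for _, size in calling)


def compatible_now(exported: MeshSpec, calling: MeshSpec) -> bool:
    """New rule in this model: axis names and sizes must match exactly."""
    return exported == calling


export_mesh: MeshSpec = (("data", 4), ("model", 2))
renamed: MeshSpec = (("batch", 4), ("model", 2))   # same sizes, different axis name
reshaped: MeshSpec = (("data", 2), ("model", 4))   # same 8 devices, different split

print(compatible_before(export_mesh, renamed))                    # True
print(compatible_now(export_mesh, renamed))                       # False
print(compatible_now(export_mesh, reshaped))                      # False
print(compatible_now(export_mesh, (("data", 4), ("model", 2))))   # True
```

In this model, renaming an axis or redistributing the same 8 devices across differently sized axes passed the old device-count check but fails the new comparison.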