- Changes
  - We anticipate that this will be the last release of JAX and jaxlib
    supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
    plugin jaxlib (e.g. `pip install jax[cuda12]`).
  - JAX now requires ml_dtypes version 0.4.0 or newer.
  - Removed backwards-compatibility support for old usage of the
    `jax.experimental.export` API. It is no longer possible to use
    `from jax.experimental.export import export`; instead, use
    `from jax.experimental import export`.
    The removed functionality has been deprecated since 0.4.24.
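This release requires ml_dtypes 0.4.0 or newer. One way to check an existing environment programmatically is sketched below. It is a minimal sketch and not part of JAX. `parse_release` and `meets_minimum` are hypothetical helpers, and for simplicity the parser ignores pre-release suffixes such as `rc1`.

```python
import re
from importlib.metadata import version

# Minimum ml_dtypes release required by this JAX release.
MIN_ML_DTYPES = (0, 4, 0)


def parse_release(v: str) -> tuple:
    """Return the leading numeric release segment, e.g. '0.4.0rc1' -> (0, 4, 0).

    Simplification: pre-release and local-version suffixes are ignored.
    """
    m = re.match(r"\d+(?:\.\d+)*", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(part) for part in m.group(0).split("."))


def meets_minimum(v: str, minimum: tuple = MIN_ML_DTYPES) -> bool:
    parts = parse_release(v)
    # Pad short versions such as '1.0' so they compare as (1, 0, 0).
    parts = parts + (0,) * (len(minimum) - len(parts))
    # Tuples compare lexicographically by integer component,
    # so (0, 10, 0) > (0, 4, 0) as intended.
    return parts >= minimum


installed = version("ml_dtypes")
print(f"ml_dtypes {installed}:", "ok" if meets_minimum(installed) else "too old")
```

Comparing integer tuples avoids the string-comparison pitfall where `"0.10.0" < "0.4.0"`.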
- Deprecations
  - `jax.sharding.XLACompatibleSharding` is deprecated. Please use
    `jax.sharding.Sharding`.
  - `jax.experimental.Exported.in_shardings` has been renamed to
    `jax.experimental.Exported.in_shardings_hlo`. Likewise, `out_shardings`
    has been renamed to `out_shardings_hlo`.
    The old names will be removed after 3 months.
  - Removed a number of previously-deprecated APIs:
    - from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
    - from {mod}`jax.lax`: `tie_in`
    - from {mod}`jax.nn`: `normalize`
    - from {mod}`jax.interpreters.xla`: `backend_specific_translations`,
      `translations`, `register_translation`, `xla_destructure`,
      `TranslationRule`, `TranslationContext`, `XlaOp`.
    - from {mod}
  - The `tol` argument of {func}`jax.numpy.linalg.matrix_rank` is being
    deprecated and will soon be removed. Use `rtol` instead.
  - The `rcond` argument of {func}`jax.numpy.linalg.pinv` is being
    deprecated and will soon be removed. Use `rtol` instead.
  - The deprecated `jax.config` submodule has been removed. To configure JAX,
    use `import jax` and then reference the config object via `jax.config`.
  - {mod}`jax.random` APIs no longer accept batched keys, where previously
    some did unintentionally. Going forward, we recommend explicit use of
    {func}`jax.vmap` in such cases.
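The sketch below shows how code might migrate for three of the deprecations above. It assumes a recent JAX release in which arrays expose a `.sharding` attribute and in which `matrix_rank` and `pinv` accept the `rtol` keyword. `sharding_kind` is a hypothetical helper, and `1e-5` is an arbitrary example tolerance. Check the function documentation for how `rtol` is interpreted, because it need not be numerically interchangeable with an old `tol` or `rcond` value.

```python
import jax
import jax.numpy as jnp


# Annotate with the general `jax.sharding.Sharding` base class rather than the
# deprecated `XLACompatibleSharding`. `sharding_kind` is a hypothetical helper.
def sharding_kind(s: jax.sharding.Sharding) -> str:
    return type(s).__name__


x = jnp.arange(4.0)
print(sharding_kind(x.sharding))

# A rank-1 2x2 matrix: the second row is twice the first.
m = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])

# Was: jnp.linalg.matrix_rank(m, tol=...)
rank = jnp.linalg.matrix_rank(m, rtol=1e-5)

# Was: jnp.linalg.pinv(m, rcond=...)
m_pinv = jnp.linalg.pinv(m, rtol=1e-5)

print(int(rank))  # 1
```

Because `m` has a single nonzero singular value (5), its pseudo-inverse is `m / 25`, and `m @ m_pinv @ m` reproduces `m`.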
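Code that relied on a `jax.random` function accepting a batch of keys can instead map over the keys explicitly with `jax.vmap`, as recommended above. This is a minimal sketch, assuming a JAX release that provides the typed-key constructor `jax.random.key`. `sample_one` is a hypothetical helper, and the batch size of 4 and sample shape `(3,)` are arbitrary.

```python
import jax

# Derive a batch of 4 independent keys from one root key
# (a typed key array of shape (4,)).
keys = jax.random.split(jax.random.key(0), 4)


def sample_one(key):
    # Inside vmap, this receives a single (unbatched) key.
    return jax.random.normal(key, (3,))


# Map explicitly over the leading batch axis of `keys` instead of passing
# the whole batch of keys to a jax.random function.
samples = jax.vmap(sample_one)(keys)
print(samples.shape)  # (4, 3)
```

Each row of `samples` matches what `sample_one` returns for the corresponding key, so the batching is visible in the code rather than implied by the shape of the key argument.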
- New Functionality
  - Added {func}`jax.experimental.Exported.in_shardings_jax` to construct
    shardings that can be used with the JAX APIs from the HloShardings
    that are stored in the `Exported` objects.
- Added {func}