- Breaking changes:
  - JAX is changing the default `jax.pmap` implementation to one implemented in
    terms of `jax.jit` and `jax.shard_map`. `jax.pmap` is in maintenance mode,
    and we encourage all new code to use `jax.shard_map` directly. See the
    migration guide for more information.
  - The `auto=` parameter of `jax.experimental.shard_map.shard_map` has been
    removed. This means that `jax.experimental.shard_map.shard_map` no longer
    supports nesting. If you want to nest `shard_map` calls, please use
    `jax.shard_map`.
  - JAX no longer allows passing objects that support `__jax_array__` directly
    to, e.g., `jit`-ed functions. Call `jax.numpy.asarray` on them first.
  - `jax.numpy.cov` now returns NaN for empty arrays ({jax-issue}`#32305`),
    and matches NumPy 2.2 behavior for single-row design matrices
    ({jax-issue}`#32308`).
  - JAX no longer accepts `Array` values where a `dtype` value is expected.
    Call `.dtype` on these values first.
  - The deprecated function `jax.interpreters.mlir.custom_call` was removed.
  - The `jax.util`, `jax.extend.ffi`, and `jax.experimental.host_callback`
    modules have been removed. All public APIs within these modules were
    deprecated and removed in v0.7.0 or earlier.
  - The deprecated symbol `jax.custom_derivatives.custom_jvp_call_jaxpr_p` was
    removed.
  - `jax.experimental.multihost_utils.process_allgather` raises an error when
    the input is a `jax.Array` that is not fully addressable and `tiled=False`.
    To fix this, pass `tiled=True` to your `process_allgather` invocation.
  - From `jax.experimental.compilation_cache`, the deprecated symbols
    `is_initialized` and `initialize_cache` were removed.
  - The deprecated function `jax.interpreters.xla.canonicalize_dtype` was
    removed.
  - `jaxlib.hlo_helpers` has been removed. Use `jax.ffi` instead.
  - The option `jax_cpu_enable_gloo_collectives` has been removed. Use
    `jax_cpu_collectives_implementation` instead.
  - The previously deprecated `interpolation` argument to
    `jax.numpy.percentile` and `jax.numpy.quantile` has been removed; use
    `method` instead.
  - The JAX-internal `for_loop` primitive was removed. Its functionality,
    reading from and writing to refs in the loop body, is now directly
    supported by `jax.lax.fori_loop`. If you need help updating your code,
    please file a bug.
  - `jax.numpy.trim_zeros` now errors for non-1D input.
  - The `where` argument to `jax.numpy.sum` and other reductions is now
    required to be boolean. Non-boolean values have resulted in a
    `DeprecationWarning` since JAX v0.5.0.
  - The deprecated functions in `jax.dlpack`, `jax.errors`,
    `jax.lib.xla_bridge`, `jax.lib.xla_client`, and `jax.lib.xla_extension`
    were removed.
  - `jax.interpreters.mlir.dense_bool_array` was removed. Use MLIR APIs to
    construct attributes instead.
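A minimal migration sketch for a few of the breaking changes above, assuming a recent JAX release. `Wrapper` is a hypothetical user type that implements `__jax_array__`, and the arrays and values are illustrative only. It shows converting with `jax.numpy.asarray` before calling a jitted function, passing `.dtype` instead of an array, using a boolean `where` mask, and using `method=` in place of the removed `interpolation=` argument.

```python
import jax
import jax.numpy as jnp


class Wrapper:
    """Hypothetical user type that implements the __jax_array__ protocol."""

    def __init__(self, value):
        self.value = value

    def __jax_array__(self):
        return jnp.asarray(self.value)


@jax.jit
def double(x):
    return 2 * x


w = Wrapper([1.0, 2.0, 3.0])
# Convert explicitly instead of passing `w` straight into the jitted function.
doubled = double(jnp.asarray(w))

x = jnp.arange(4, dtype=jnp.float32)
# Pass `x.dtype`, not the array `x`, wherever a dtype is expected.
zeros = jnp.zeros(3, dtype=x.dtype)

# `where` must be a boolean mask; build one explicitly rather than passing 0/1 ints.
mask = x > 1
masked_sum = jnp.sum(x, where=mask)

# `method=` replaces the removed `interpolation=` argument.
median = jnp.percentile(x, 50.0, method="linear")
```

For `x = [0, 1, 2, 3]`, the masked sum keeps only the last two elements, and the linear 50th percentile is the midpoint of the two middle values.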
- Changes
  - `jax.numpy.linalg.eig` now returns a namedtuple (with attributes
    `eigenvalues` and `eigenvectors`) instead of a plain tuple.
  - `jax.grad` and `jax.vjp` will now always round primals to `float32` if
    `float64` mode is not enabled.
  - `jax.dlpack.from_dlpack` now accepts arrays with non-default layouts, for
    example, transposed arrays.
  - The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses
    cuSOLVER. The MAGMA and LAPACK implementations are still available via the
    new `implementation` argument to `jax.lax.linalg.eig`
    ({jax-issue}`#27265`). The `use_magma` argument is now deprecated in favor
    of `implementation`.
  - `jax.numpy.trim_zeros` now follows NumPy 2.2 in supporting
    multi-dimensional inputs.
- Deprecations
  - `jax.experimental.enable_x64` and `jax.experimental.disable_x64` are
    deprecated in favor of the new non-experimental context manager
    `jax.enable_x64`.
  - `jax.experimental.shard_map.shard_map` is deprecated; going forward, use
    `jax.shard_map`.
  - `jax.experimental.pjit.pjit` is deprecated; going forward, use `jax.jit`.