- Breaking changes:
  - `jax.dlpack.from_dlpack` no longer accepts a DLPack capsule. This
    behavior was deprecated and is now removed. The function must be called
    with an array implementing `__dlpack__` and `__dlpack_device__`.
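A minimal migration sketch: callers that used to pass a capsule (the result of `x.__dlpack__()`) should pass the array object `x` itself. Here a JAX array stands in as the producer only so the example needs no other library. The assumption is that any object implementing both `__dlpack__` and `__dlpack_device__` (for example a NumPy array or a PyTorch tensor) is accepted the same way.

```python
import jax
import jax.numpy as jnp

# Producer: any object implementing __dlpack__ and __dlpack_device__.
# A JAX array is used here so the sketch has no extra dependencies.
producer = jnp.arange(6.0, dtype=jnp.float32).reshape(2, 3)

# Old style, no longer accepted: passing a DLPack capsule.
#   capsule = producer.__dlpack__()
#   jax.dlpack.from_dlpack(capsule)

# New style: pass the array object itself; JAX queries the protocol methods.
consumer = jax.dlpack.from_dlpack(producer)

print(consumer.shape, consumer.dtype)
```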
- Changes
  - The minimum supported NumPy version is now 2.0. Since SciPy 1.13 is required
    for NumPy 2.0 support, the minimum supported SciPy version is now 1.13.
  - JAX now represents constants in its internal jaxpr representation as a
    `LiteralArray`, which is a private JAX type that duck types as a
    `numpy.ndarray`. This type may be exposed to users via `custom_jvp` rules,
    for example, and may break code that uses `isinstance(x, np.ndarray)`. If
    this breaks your code, you may convert these arrays to classic NumPy arrays
    using `np.asarray(x)`.
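Because `LiteralArray` is private, the sketch below does not construct one. Instead it uses a hypothetical `DuckArray` class (not part of JAX) that, like `LiteralArray`, is array-like without subclassing `np.ndarray`. It shows why an `isinstance` check fails on such objects and how normalizing with `np.asarray` restores the expected type; `as_classic_numpy` is likewise an illustrative helper name.

```python
import numpy as np


class DuckArray:
    """Hypothetical stand-in for a duck-typed array such as LiteralArray.

    It is NOT an np.ndarray subclass, but NumPy can convert it through the
    __array__ protocol.
    """

    def __init__(self, data):
        self._data = np.asarray(data)

    @property
    def shape(self):
        return self._data.shape

    @property
    def dtype(self):
        return self._data.dtype

    def __array__(self, dtype=None, copy=None):
        # Honor NumPy 2's `copy` keyword: copy=True must return a fresh array.
        if copy:
            return np.array(self._data, dtype=dtype, copy=True)
        return np.asarray(self._data, dtype=dtype)


def as_classic_numpy(x):
    """Normalize array-likes before code that relies on isinstance checks."""
    return np.asarray(x)


x = DuckArray([1.0, 2.0, 3.0])
print(isinstance(x, np.ndarray))                    # False: the fragile check fails
print(isinstance(as_classic_numpy(x), np.ndarray))  # True after np.asarray
```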
- Bug fixes
  - `arr.view(dtype=None)` now returns the array unchanged, matching NumPy's
    semantics. Previously it returned the array with a float dtype.
  - `jax.random.randint` now produces a less-biased distribution for 8-bit and
    16-bit integer types ({jax-issue}`#27742`). To restore the previous biased
    behavior, you may temporarily set the `jax_safer_randint` configuration to
    `False`, but note this is a temporary config that will be removed in a
    future release.
- Deprecations:
  - The parameters `enable_xla` and `native_serialization` for `jax2tf.convert`
    are deprecated and will be removed in a future version of JAX. These were
    used for jax2tf with non-native serialization, which has now been removed.
  - Setting the config state `jax_pmap_no_rank_reduction` to `False` is
    deprecated. By default, `jax_pmap_no_rank_reduction` will be set to `True`
    and `jax.pmap` shards will not have their rank reduced, keeping the same
    rank as their enclosing array.
  - The parameters