- New features
  - JAX now ships Python 3.14 and 3.14t wheels.
  - JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only
    offered free-threading builds on Linux.
- Changes
  - Exposed `jax.set_mesh`, which acts as both a global setter and a context
    manager. Removed `jax.sharding.use_mesh` in favor of `jax.set_mesh`.
  - JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain
    supported.
  - `jax.lax.dot` now implements the general dot product via the optional
    `dimension_numbers` argument.
- Deprecations
  - `jax.lax.zeros_like_array` is deprecated. Please use `jax.numpy.zeros_like`
    instead.
  - Attempting to import `jax.experimental.host_callback` now results in a
    `DeprecationWarning`, and will result in an `ImportError` starting in JAX
    v0.8.0. Its APIs have raised `NotImplementedError` since JAX version 0.4.35.
  - In `jax.lax.dot`, passing the `precision` and `preferred_element_type`
    arguments by position is deprecated. Pass them by explicit keyword instead.
  - Several dozen internal APIs have been deprecated from `jax.interpreters.ad`,
    `jax.interpreters.batching`, and `jax.interpreters.partial_eval`. They are
    rarely, if ever, used outside JAX itself, and most are deprecated without any
    public replacement.
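Two of the call-site migrations in this section can be sketched as follows: replacing `jax.lax.zeros_like_array`, and passing `precision` and `preferred_element_type` to `jax.lax.dot` by keyword. The arrays are placeholders, and `lax.Precision.HIGHEST` and `jnp.float32` are example values only.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 3), dtype=jnp.float32)
w = jnp.ones((3, 4), dtype=jnp.float32)

# Deprecated: lax.zeros_like_array(x)
zeros = jnp.zeros_like(x)

# Deprecated positional form: lax.dot(x, w, lax.Precision.HIGHEST, jnp.float32)
# Pass both options by explicit keyword instead.
out = lax.dot(x, w, precision=lax.Precision.HIGHEST,
              preferred_element_type=jnp.float32)
```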