- Changes
- Support for Python 3.7 has been dropped, in accordance with JAX's
{ref}`version-support-policy`.
- We introduce `jax.Array`, a unified array type that subsumes the
`DeviceArray`, `ShardedDeviceArray`, and `GlobalDeviceArray` types in JAX.
The `jax.Array` type helps make parallelism a core feature of JAX,
simplifies and unifies JAX internals, and allows us to unify `jit` and
`pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some
breaking changes to the `pjit` API. The jax.Array migration guide can
help you migrate your codebase to `jax.Array`. You can also look at the
Distributed arrays and automatic parallelization tutorial to understand
the new concepts.
- `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints
are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
`jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
deprecated and will be removed in 3 months.
- The new public endpoint for `with_sharding_constraint` is
`jax.lax.with_sharding_constraint`.
- If using ABSL flags together with `jax.config`, the ABSL flag values are no
longer read or written after the JAX configuration options are initially
populated from the ABSL flags. This change improves performance of reading
`jax.config` options, which are used pervasively in JAX.
- For TF lowering, the `jax2tf.call_tf` function now uses the first TF
device of the same platform as the embedding JAX computation.
Previously, it used the 0th device of the JAX default backend.
- A number of `jax.numpy` functions now have their arguments marked as
positional-only, matching NumPy.
- `jnp.msort` is now deprecated, following the deprecation of `np.msort`
in NumPy 1.24. It will be removed in a future release, in accordance with the
{ref}`api-compatibility` policy. It can be replaced with `jnp.sort(a, axis=0)`.
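
The endpoints named in the changes above can be tried together in a short sketch. This is illustrative only, not taken from the JAX documentation: it assumes a recent JAX release, and `jax.sharding.NamedSharding`, the mesh axis name `"data"`, and the helper `scale` are assumptions introduced here rather than names from these notes. It runs on a single CPU device as well as on multiple devices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over whatever devices are available.
# The axis name "data" is an arbitrary choice for this sketch.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading axis across the "data" mesh axis. The row count is a
# multiple of the device count so the rows split evenly.
rows = 4 * devices.size
x = jnp.arange(rows * 2, dtype=jnp.float32).reshape(rows, 2)
sharding = NamedSharding(mesh, PartitionSpec("data", None))  # assumed helper type
x = jax.device_put(x, sharding)

@jax.jit
def scale(v):
    # Request that this value keep the same sharding inside the jitted function.
    v = jax.lax.with_sharding_constraint(v, sharding)
    return v * 2.0

y = scale(x)

# Arrays placed with device_put and jit outputs are both jax.Array instances.
is_unified = isinstance(x, jax.Array) and isinstance(y, jax.Array)

# jnp.sort(a, axis=0) is the replacement these notes give for jnp.msort(a):
# it sorts each column independently.
a = jnp.array([[3, 1], [2, 4], [1, 0]])
sorted_cols = jnp.sort(a, axis=0)
```

Code that still imports `Mesh` or `PartitionSpec` from the experimental modules would switch to the `jax.sharding` imports shown at the top of the sketch.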
- Support for Python 3.7 has been dropped, in accordance with JAX's