- Changes:
  - `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
    as shortcuts of the corresponding `tree_util` functions.
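  Here is a minimal sketch of the two new shortcuts, assuming a JAX release that includes them. The `params` dictionary and the `zero_biases` helper are hypothetical illustrations and are not part of the API. `flatten_with_path` returns `(key_path, leaf)` pairs plus a treedef. `map_with_path` passes each leaf's key path as the first argument to the mapped function.

  ```python
  import jax
  import jax.numpy as jnp
  from jax.tree_util import keystr

  # Hypothetical parameter pytree used only for illustration.
  params = {"w": jnp.ones((2, 3)), "b": jnp.ones(3)}

  # Equivalent to jax.tree_util.tree_flatten_with_path(params):
  # a list of (key_path, leaf) pairs plus the tree structure.
  leaves_with_paths, treedef = jax.tree.flatten_with_path(params)
  # Dict keys are visited in sorted order, so "b" comes before "w".
  path_strings = [keystr(path) for path, _ in leaves_with_paths]

  def zero_biases(path, leaf):
      # path[0] is a DictKey; zero out the leaf stored under "b".
      return jnp.zeros_like(leaf) if path[0].key == "b" else leaf

  # Equivalent to jax.tree_util.tree_map_with_path(zero_biases, params).
  new_params = jax.tree.map_with_path(zero_biases, params)
  ```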
- Deprecations
  - A number of APIs in the internal `jax.core` namespace have been deprecated.
    Most were no-ops, were little-used, or can be replaced by APIs of the same
    name in `jax.extend.core`; see the documentation for {mod}`jax.extend`
    for information on the compatibility guarantees of these semi-public extensions.
  - Several previously-deprecated APIs have been removed, including:
    - from `jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
      `non_negative_dim`.
    - from `jax.lib.xla_bridge`: `xla_client` and `default_backend`.
    - from `jax.lib.xla_client`: `_xla` and `bfloat16`.
    - from `jax.numpy`: `round_`.
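  Here is a migration sketch for a few of the removed names. It assumes that the public equivalents shown are acceptable substitutes for your use case. This pairing is an assumption, not an official mapping: `jnp.round` for `jax.numpy.round_`, `jax.default_backend()` for `jax.lib.xla_bridge.default_backend`, and `jnp.bfloat16` for `jax.lib.xla_client.bfloat16`.

  ```python
  import jax
  import jax.numpy as jnp

  x = jnp.array([1.24, 2.56])

  # Instead of the removed jax.numpy.round_ (assumed replacement):
  rounded = jnp.round(x, 1)

  # Instead of the removed jax.lib.xla_bridge.default_backend (assumed replacement):
  backend = jax.default_backend()  # a platform name string such as "cpu"

  # Instead of the removed jax.lib.xla_client.bfloat16 (assumed replacement):
  y = jnp.asarray(x, dtype=jnp.bfloat16)
  ```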
- New Features
  - {func}`jax.export.export` can be used for device-polymorphic export with
    shardings constructed with {func}`jax.sharding.AbstractMesh`.
    See the `jax.export` documentation.
  - Added {func}`jax.lax.split`. This is a primitive version of
    {func}`jax.numpy.split`, added because it yields a more compact
    transpose during automatic differentiation.
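  Here is a short sketch that compares `jax.lax.split` with `jax.numpy.split`, assuming the `jax.lax.split(operand, sizes, axis=0)` signature. The `lax` version takes explicit chunk sizes, and the `jnp` version takes split indices. The function `f` is a hypothetical example used to differentiate through a split.

  ```python
  import jax
  import jax.numpy as jnp

  x = jnp.arange(10.0)

  # lax.split takes chunk sizes, which must sum to x.shape[axis].
  a, b, c = jax.lax.split(x, (2, 3, 5), axis=0)

  # jnp.split takes split indices; [2, 5] yields the same three pieces.
  a2, b2, c2 = jnp.split(x, [2, 5])

  def f(x):
      # Differentiate through a split into pieces of sizes 4 and 6.
      p, q = jax.lax.split(x, (4, 6))
      return (p ** 2).sum() + q.sum()

  # Gradient is 2*x on the first four entries and 1 on the remaining six.
  g = jax.grad(f)(x)
  ```

  Both calls produce identical pieces here. The difference described in the entry is the shape of the transpose under autodiff, which this sketch does not inspect.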