- New features:
  - jax.jit now supports the decorator factory pattern; i.e., instead of
    writing

        @functools.partial(jax.jit, static_argnames=['n'])
        def f(x, n): ...

    you may write

        @jax.jit(static_argnames=['n'])
        def f(x, n): ...
- Changes:
  - jax.lax.linalg.eigh now accepts an implementation argument to
    select between QR (CPU/GPU), Jacobi (GPU/TPU), and QDWH (TPU)
    implementations. The EighImplementation enum is publicly exported from
    jax.lax.linalg.
  - jax.lax.linalg.svd now implements an algorithm that uses the polar
    decomposition on CUDA GPUs. On TPUs, this algorithm is an alias for the
    existing one.
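To illustrate the idea behind a polar-decomposition SVD, here is a minimal sketch built from public pieces: jax.scipy.linalg.polar (QDWH by default) followed by a symmetric eigendecomposition with jax.lax.linalg.eigh. It is not JAX's internal svd code path; the helper name svd_via_polar is hypothetical. eigh is called with its default implementation, and the new implementation argument is deliberately left out rather than guessing EighImplementation member names.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import polar


def svd_via_polar(a):
    # Step 1: right polar decomposition, a == u_p @ h, with u_p orthogonal
    # and h symmetric positive semidefinite.
    u_p, h = polar(a)
    # Step 2: symmetric eigendecomposition h == v @ diag(w) @ v.T.
    # jax.lax.linalg.eigh returns (eigenvectors, eigenvalues).
    v, w = jax.lax.linalg.eigh(h)
    # Step 3: the eigenvalues of h are the singular values of a; reorder
    # them in descending order to match the usual SVD convention.
    order = jnp.argsort(w)[::-1]
    s = w[order]
    v = v[:, order]
    # Step 4: a == (u_p @ v) @ diag(s) @ v.T, so the left singular vectors are u_p @ v.
    return u_p @ v, s, v.T


# A small, well-conditioned square test matrix.
a = jax.random.normal(jax.random.key(0), (4, 4)) + 4.0 * jnp.eye(4)
u, s, vt = svd_via_polar(a)
```

The design choice is that the expensive, iterative part (QDWH) only produces an orthogonal factor, while the spectral information comes from a symmetric eigensolver, which is why eigh and the polar-based svd algorithm are closely related.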
- Bug fixes:
  - Fixed a bug introduced in JAX 0.7.2 where eigh failed for large matrices on
    GPU (#33062).
- Deprecations:
  - jax.sharding.PmapSharding is now deprecated. Please use
    jax.NamedSharding instead.
  - jax.device_put_replicated is now deprecated. Please use jax.device_put
    with the appropriate sharding instead.
  - jax.device_put_sharded is now deprecated. Please use jax.device_put with
    the appropriate sharding instead.
  - The default axis_types of jax.make_mesh will change in JAX v0.9.0 to
    jax.sharding.AxisType.Explicit. Leaving axis_types unspecified will raise a
    DeprecationWarning.
  - jax.cloud_tpu_init and its contents were deprecated. There is no reason for
    a user to import or use the contents of this module; JAX handles this for
    you automatically if needed.