github jax-ml/jax jax-v0.8.1
JAX v0.8.1

  • New features:

    • jax.jit now supports the decorator factory pattern; i.e., instead of
      writing
      @functools.partial(jax.jit, static_argnames=['n'])
      def f(x, n):
        ...
      you may write
      @jax.jit(static_argnames=['n'])
      def f(x, n):
        ...

  • Changes:

    • jax.lax.linalg.eigh now accepts an implementation argument to
      select between QR (CPU/GPU), Jacobi (GPU/TPU), and QDWH (TPU)
      implementations. The EighImplementation enum is publicly exported from
      jax.lax.linalg.

    • jax.lax.linalg.svd now implements, on CUDA GPUs, an algorithm based on
      the polar decomposition. On TPUs, this is an alias for the existing
      algorithm.

  • Bug fixes:

    • Fixed a bug introduced in JAX 0.7.2 where eigh failed for large matrices on
      GPU (#33062).

  • Deprecations:

    • jax.sharding.PmapSharding is now deprecated. Please use
      jax.NamedSharding instead.
    • jax.device_put_replicated is now deprecated. Please use jax.device_put
      with the appropriate sharding instead.
    • jax.device_put_sharded is now deprecated. Please use jax.device_put with
      the appropriate sharding instead.
    • The default axis_types of jax.make_mesh will change to
      jax.sharding.AxisType.Explicit in JAX v0.9.0. Leaving axis_types
      unspecified will raise a DeprecationWarning.
    • jax.cloud_tpu_init and its contents are now deprecated. Users have no
      reason to import or use the contents of this module; JAX handles this
      automatically when needed.
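
The release notes do not describe the new polar-decomposition SVD algorithm in detail. The following is therefore only an illustrative sketch of the general idea, not JAX's implementation. It factors a = u_p @ h with jax.scipy.linalg.polar and then diagonalizes the symmetric factor h. The helper name svd_via_polar is hypothetical. Unlike the new built-in algorithm, which targets CUDA GPUs (and TPUs), this sketch runs on any backend.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import polar


def svd_via_polar(a):
    """Hypothetical helper: thin SVD of a tall matrix via its polar decomposition."""
    # Polar decomposition a = u_p @ h, where u_p has orthonormal columns and
    # h is symmetric positive semidefinite.
    u_p, h = polar(a, side="right")
    # Diagonalize the symmetric factor: h = v @ diag(s) @ v.T.
    s, v = jnp.linalg.eigh(h)
    # eigh sorts eigenvalues in ascending order; SVD convention is descending.
    s, v = s[::-1], v[:, ::-1]
    # Substituting gives a = (u_p @ v) @ diag(s) @ v.T.
    return u_p @ v, s, v.T


a = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
u, s, vt = svd_via_polar(a)
print(s)
print(jnp.linalg.svd(a, compute_uv=False))
```

The eigenvalues of the positive semidefinite factor h are exactly the singular values of a. This is why a polar step followed by a symmetric eigensolver yields an SVD.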
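
The deprecated jax.device_put_replicated and jax.device_put_sharded calls, and PmapSharding in general, can be replaced by jax.device_put with a NamedSharding. The sketch below assumes a 1-D mesh over all devices with an arbitrary axis name "x". jax.device_put_replicated returned its input broadcast along a new leading axis of size len(devices), so the closest equivalent stacks copies and shards that axis. If that pmap-style axis is not needed, a fully replicated placement with PartitionSpec() is simpler.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n = len(devices)

# A 1-D mesh over all devices; the axis name "x" is arbitrary.
mesh = Mesh(np.array(devices), ("x",))
# Split the leading array axis across the mesh axis "x".
stacked_sharding = NamedSharding(mesh, P("x"))

x = jnp.arange(3.0)

# Instead of jax.device_put_replicated(x, devices): broadcast x along a new
# leading axis of size n and place one slice on each device.
replicated = jax.device_put(jnp.stack([x] * n), stacked_sharding)

# Instead of jax.device_put_sharded(shards, devices): stack one shard per
# device and split the new leading axis across the mesh.
shards = [jnp.full((3,), float(i)) for i in range(n)]
sharded = jax.device_put(jnp.stack(shards), stacked_sharding)

# If the pmap-style leading axis is not needed, replicate x directly.
plain_replicated = jax.device_put(x, NamedSharding(mesh, P()))

print(replicated.sharding)
print(sharded.shape, plain_replicated.sharding.is_fully_replicated)
```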
