JAX v0.7.1

  • New features

    • JAX now ships Python 3.14 and 3.14t wheels.
    • JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only
      offered free-threading builds on Linux.
  • Changes

    • Exposed jax.set_mesh, which acts both as a global setter and as a context
      manager. Removed jax.sharding.use_mesh in favor of jax.set_mesh.
    • JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain
      supported.
    • jax.lax.dot now implements the general dot product via the optional
      dimension_numbers argument.
  • Deprecations

    • jax.lax.zeros_like_array is deprecated. Please use
      jax.numpy.zeros_like instead.
    • Attempting to import jax.experimental.host_callback now results in
      a DeprecationWarning, and will result in an ImportError starting in JAX
      v0.8.0. Its APIs have raised NotImplementedError since JAX version 0.4.35.
    • In jax.lax.dot, passing the precision and preferred_element_type
      arguments by position is deprecated. Pass them as keyword arguments instead.
    • Several dozen internal APIs have been deprecated from jax.interpreters.ad,
      jax.interpreters.batching, and jax.interpreters.partial_eval; they
      are rarely, if ever, used outside JAX itself, and most are deprecated
      without any public replacement.
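Migrating away from the deprecated jax.lax.zeros_like_array is a one-line change; a minimal sketch with an arbitrary bfloat16 input:

```python
import jax.numpy as jnp

x = jnp.ones((2, 3), dtype=jnp.bfloat16)

# Previously: jax.lax.zeros_like_array(x)
z = jnp.zeros_like(x)  # same shape and dtype as x, filled with zeros
```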
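The host_callback entry does not name a replacement. As an assumption about a common use case, code that used host_callback only to observe values on the host can often switch to jax.debug.callback; jax.pure_callback and jax.experimental.io_callback cover callbacks that return values. The record helper below is hypothetical.

```python
import jax
import jax.numpy as jnp

seen = []

def record(x):
    # Runs on the host with a concrete array; here it just logs the value.
    seen.append(float(x))

@jax.jit
def step(x):
    y = x * 2.0
    jax.debug.callback(record, y)  # side effect only, returns nothing
    return y + 1.0

result = step(jnp.float32(3.0))
jax.effects_barrier()  # block until pending callbacks have run
print(seen, float(result))
```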
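For the jax.lax.dot keyword deprecation, a minimal sketch of the keyword form, assuming bfloat16 inputs with a float32 result as an example of why preferred_element_type is used:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3), dtype=jnp.bfloat16)
b = jnp.ones((3, 4), dtype=jnp.bfloat16)

# Deprecated positional form: lax.dot(a, b, lax.Precision.HIGHEST, jnp.float32)
out = lax.dot(
    a,
    b,
    precision=lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,  # accumulate and return float32
)
print(out.dtype, out.shape)
```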
