github jax-ml/jax jax-v0.7.2
JAX v0.7.2

13 hours ago
  • Breaking changes:

    • jax.dlpack.from_dlpack no longer accepts a DLPack capsule. This
      behavior was deprecated and is now removed. The function must be called
      with an array implementing __dlpack__ and __dlpack_device__.
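
As a minimal sketch of the supported calling convention, the array object itself is passed to jax.dlpack.from_dlpack rather than a capsule. The NumPy source array and the variable names here are illustrative assumptions, not part of the release notes.

```python
import numpy as np
import jax
import jax.dlpack

# A NumPy array implements both __dlpack__ and __dlpack_device__.
x_np = np.arange(4, dtype=np.float32)

# Supported: pass the array object directly.
x_jax = jax.dlpack.from_dlpack(x_np)

# Removed in this release: passing a capsule, e.g.
#   jax.dlpack.from_dlpack(x_np.__dlpack__())
print(type(x_jax), x_jax.shape, x_jax.dtype)
```

Code that still holds a raw capsule would need to be changed to keep a reference to the producing array instead.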
  • Changes:

    • The minimum supported NumPy version is now 2.0. Since SciPy 1.13 is required
      for NumPy 2.0 support, the minimum supported SciPy version is now 1.13.

    • JAX now represents constants in its internal jaxpr representation as a
      LiteralArray, which is a private JAX type that duck types as a
      numpy.ndarray. This type may be exposed to users via custom_jvp rules,
      for example, and may break code that uses isinstance(x, np.ndarray). If
      this breaks your code, you may convert these arrays to classic NumPy arrays
      using np.asarray(x).
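
The np.asarray(x) workaround can be sketched as follows. LiteralArray is private and is not constructed here; DuckArray is a hypothetical stand-in for any object that behaves like an array without being an np.ndarray, and describe is an illustrative helper.

```python
import numpy as np


class DuckArray:
    """Hypothetical array-like that is NOT an np.ndarray subclass."""

    def __init__(self, data):
        self._data = np.array(data)

    def __array__(self, dtype=None, copy=None):
        # Called by np.asarray to obtain a real ndarray.
        return self._data if dtype is None else self._data.astype(dtype)


def describe(x):
    # Normalize first instead of branching on isinstance(x, np.ndarray).
    arr = np.asarray(x)
    return arr.shape, arr.dtype


x = DuckArray([1.0, 2.0, 3.0])
print(isinstance(x, np.ndarray))              # the isinstance check fails
print(isinstance(np.asarray(x), np.ndarray))  # the converted value passes
print(describe(x))
```

Converting at the boundary keeps the rest of the function independent of which concrete array type JAX passes in.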

  • Bug fixes:

    • arr.view(dtype=None) now returns the array unchanged, matching NumPy's
      semantics. Previously it returned the array with a float dtype.
    • jax.random.randint now produces a less-biased distribution for 8-bit and
      16-bit integer types (JAX issue #27742). To restore the previous biased
      behavior, you may temporarily set the jax_safer_randint configuration to
      False, but note this is a temporary config that will be removed in a
      future release.
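
A short sketch of sampling a narrow integer dtype with jax.random.randint, the case affected by the fix above. The seed, shape, and bounds are arbitrary choices for illustration; the config name in the comment is taken from the note above and is left commented out because it is temporary.

```python
import jax
import jax.numpy as jnp

# Temporary opt-out described in the note above (not enabled here):
#   jax.config.update("jax_safer_randint", False)

key = jax.random.key(0)
samples = jax.random.randint(
    key, shape=(1000,), minval=0, maxval=100, dtype=jnp.int8
)

# Samples lie in the half-open interval [minval, maxval).
print(samples.dtype, int(samples.min()), int(samples.max()))
```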
  • Deprecations:

    • The parameters enable_xla and native_serialization for jax2tf.convert
      are deprecated and will be removed in a future version of JAX. These were
      used for jax2tf with non-native serialization, which has now been removed.
    • Setting the config state jax_pmap_no_rank_reduction to False is
      deprecated. By default, jax_pmap_no_rank_reduction will be set to True
      and jax.pmap shards will not have their rank reduced, keeping the same
      rank as their enclosing array.
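
The effect of jax_pmap_no_rank_reduction can be inspected with a sketch like the one below, which sets the flag explicitly and compares the rank of one shard with the rank of the full result. The input shape and the doubling function are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Set the flag explicitly so the behavior does not depend on the default.
jax.config.update("jax_pmap_no_rank_reduction", True)

n = jax.local_device_count()
x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

# Map over the leading axis, one slice per local device.
y = jax.pmap(lambda v: v * 2)(x)

# With no rank reduction, each shard keeps the rank of the enclosing array.
shard = y.addressable_shards[0].data
print("result shape:", y.shape, "shard shape:", shard.shape)
```

Code that indexes shard data assuming the mapped axis has been dropped may need to account for the retained leading dimension.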
