JAX release v0.2.21 (GitHub: jax-ml/jax, tag jax-v0.2.21)

  • New features:

    • Added an implementation of jax.numpy.insert (#7936).
  • Breaking changes:

    • jax.api has been removed. Functions that were available as jax.api.*
      were aliases for functions in jax.*; please use the functions in
      jax.* instead.
    • jax.partial, jax.lax.partial, and jax.util.partial were accidental
      exports that have now been removed. Use functools.partial from the Python
      standard library instead.
    • Boolean scalar indices now raise a TypeError; previously they silently
      returned incorrect results (#7925).
    • Many more jax.numpy functions now require array-like inputs and raise an
      error if passed a list (#7747, #7802, #7907).
      See #7737 for a discussion of the rationale behind this change.
    • When inside a transformation such as jax.jit, jax.numpy.array always
      stages the array it produces into the traced computation. Previously
      jax.numpy.array would sometimes produce an on-device array, even under
      a jax.jit decorator. This change may break code that used JAX arrays to
      perform shape or index computations that must be known statically; the
      workaround is to perform such computations using classic NumPy arrays
      instead.
    • jnp.ndarray is now a true base class for JAX arrays. In particular, this
      means that for a standard NumPy array x, isinstance(x, jnp.ndarray) will
      now return False (#7927).
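The new jax.numpy.insert follows numpy.insert. A minimal sketch, assuming the NumPy-style signature insert(arr, obj, values, axis=None); the arrays and values are illustrative:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Insert the value 10 before index 1. With axis=None the input is flattened
# first, which for a 1-D array changes nothing.
inserted = jnp.insert(x, 1, 10)

# Insert a column of 5s before column 1 of a 2-D integer array.
grid = jnp.zeros((2, 2), dtype=jnp.int32)
with_column = jnp.insert(grid, 1, 5, axis=1)
```

Because JAX arrays are immutable, jnp.insert returns a new array and leaves x unchanged.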
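Migrating off the removed jax.api and jax.partial aliases is mechanical: call jax.jit rather than jax.api.jit, and build partial applications with functools.partial from the standard library. A minimal sketch; the function power and its static exponent are hypothetical examples, not part of the release:

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical example function. functools.partial replaces the removed
# jax.partial, and jax.jit replaces the removed jax.api.jit alias.
# static_argnums=(1,) marks n as a compile-time constant, so the Python
# loop below is unrolled during tracing.
@functools.partial(jax.jit, static_argnums=(1,))
def power(x, n):
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result


cube = power(jnp.array(2.0), 3)
```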
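For jax.numpy functions that now require array-like inputs, the fix is to convert lists explicitly before the call. A minimal sketch that uses jnp.sum as an assumed example of such a function; the release note does not say which functions each PR changed:

```python
import jax.numpy as jnp

values = [1, 2, 3]

# jnp.sum is an assumed example of a function that rejects Python lists.
try:
    jnp.sum(values)
    list_rejected = False
except TypeError:
    list_rejected = True

# Converting explicitly with jnp.asarray (or jnp.array) is the supported path.
total = jnp.sum(jnp.asarray(values))
```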
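A sketch of the suggested workaround for the jax.numpy.array staging change: keep shape arithmetic in classic NumPy (or plain Python) so it remains static under jax.jit. The function flatten is a hypothetical example:

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def flatten(x):
    # Hypothetical example. x.shape is a tuple of Python ints that is known
    # at trace time, and classic NumPy keeps the size computation concrete.
    # Using jnp.array(x.shape) here would instead stage the shape into the
    # traced computation, and the resulting traced size could not be used
    # as a static reshape argument.
    size = int(np.prod(np.array(x.shape)))
    return jnp.reshape(x, (size,))


flat = flatten(jnp.ones((2, 3)))
```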
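The jnp.ndarray change can be checked directly. Code that relied on isinstance(x, jnp.ndarray) to accept both NumPy and JAX arrays can test against both classes explicitly. A minimal sketch; the helper is_any_array is a hypothetical name:

```python
import numpy as np
import jax.numpy as jnp

np_x = np.zeros(3)
jax_x = jnp.zeros(3)

jax_is_jnp_ndarray = isinstance(jax_x, jnp.ndarray)  # True: a JAX array
np_is_jnp_ndarray = isinstance(np_x, jnp.ndarray)    # False: a NumPy array


def is_any_array(x):
    """Hypothetical helper: accept either a NumPy array or a JAX array."""
    return isinstance(x, (np.ndarray, jnp.ndarray))
```

Listing both classes in the isinstance tuple keeps the intent explicit instead of depending on how jnp.ndarray relates to numpy.ndarray.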
