New features:
- Added `jax.numpy.insert` implementation (#7936).
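A minimal sketch of the new function, assuming it mirrors `numpy.insert` as its name suggests: `jnp.insert(arr, obj, values, axis=None)` returns a new array and leaves `arr` unchanged, because JAX arrays are immutable.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Insert the value 10 before index 1 of a 1-D array.
y = jnp.insert(x, 1, 10)  # y holds [1, 10, 2, 3]; x is unchanged

# With an explicit axis, the scalar is broadcast to fill a whole new row.
m = jnp.zeros((2, 2), dtype=jnp.int32)
m2 = jnp.insert(m, 1, 5, axis=0)  # m2 holds [[0, 0], [5, 5], [0, 0]]
```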

Breaking changes:
- `jax.api` has been removed. Functions that were available as `jax.api.*`
  were aliases for functions in `jax.*`; please use the functions in
  `jax.*` instead.
- `jax.partial`, `jax.lax.partial`, and `jax.util.partial` were accidental
  exports that have now been removed. Use `functools.partial` from the Python
  standard library instead.
- Boolean scalar indices now raise a `TypeError`; previously they silently
  returned wrong results (#7925).
- Many more `jax.numpy` functions now require array-like inputs and will raise
  an error if passed a list (#7747, #7802, #7907).
  See #7737 for a discussion of the rationale behind this change.
- When inside a transformation such as `jax.jit`, `jax.numpy.array` always
  stages the array it produces into the traced computation. Previously,
  `jax.numpy.array` would sometimes produce an on-device array, even under
  a `jax.jit` decorator. This change may break code that used JAX arrays to
  perform shape or index computations that must be known statically; the
  workaround is to perform such computations using classic NumPy arrays
  instead.
- `jnp.ndarray` is now a true base class for JAX arrays. In particular, this
  means that for a standard NumPy array `x`, `isinstance(x, jnp.ndarray)` will
  now return `False` (#7927).