- GitHub commits.
- Changes
  - {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
  - {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
  - {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
  - {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
  - {func}`jax.scipy.cluster.vq.vq` has been added.
  - `jax.experimental.maps.mesh` has been deleted. Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
  - {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`).
  - {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing `mode="clip"`.
  - {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns invalid values (e.g., NaN) for out-of-bounds indices.
  - Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics. This has no effect on the scatter operation itself, but it means that, when differentiated, the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously, out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
  - {func}`jax.numpy.take_along_axis` now raises a `TypeError` if its indices are not of an integer type, matching the behavior of {func}`numpy.take_along_axis`. Previously, non-integer indices were silently cast to integers.
  - {func}`jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument is not of an integer type, matching the behavior of {func}`numpy.ravel_multi_index`. Previously, non-integer `dims` was silently cast to integers.
  - {func}`jax.numpy.split` now raises a `TypeError` if its `axis` argument is not of an integer type, matching the behavior of {func}`numpy.split`. Previously, non-integer `axis` was silently cast to integers.
  - {func}`jax.numpy.indices` now raises a `TypeError` if its dimensions are not of an integer type, matching the behavior of {func}`numpy.indices`. Previously, non-integer dimensions were silently cast to integers.
  - {func}`jax.numpy.diag` now raises a `TypeError` if its `k` argument is not of an integer type, matching the behavior of {func}`numpy.diag`. Previously, non-integer `k` was silently cast to integers.
  - Added {func}`jax.random.orthogonal`.
- Deprecations
  - Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`, `format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and `_default_tolerance` ({jax-issue}`#10389`). These, along with the previously deprecated `JaxTestCase`, `JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found in, e.g., {mod}`unittest`, {mod}`absl.testing`, and {mod}`numpy.testing`. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as {func}`jax.devices`. Many of the deprecated utilities will still exist in {mod}`jax._src.test_util`, but these are not public APIs and as such may be changed or removed without notice in future releases.
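The `jax.numpy.linalg` entries above (complex input to `cond`, `pinv`, and `matrix_rank`, and the SVD solver) can be exercised with a small complex example. This is a minimal sketch that runs on whatever backend is available; on CPU or GPU it calls the same public functions but does not exercise the TPU-specific code paths, such as the qdwh-svd solver, that the entries refer to. The matrices are arbitrary illustrative choices.

```python
import jax.numpy as jnp

# A rank-1 complex matrix: the second row is all zeros.
a = jnp.array([[1 + 1j, 2 - 1j], [0, 0]], dtype=jnp.complex64)
rank = jnp.linalg.matrix_rank(a)

# A full-rank diagonal complex matrix; its singular values are
# |2 - 1j| = sqrt(5) and |1 + 1j| = sqrt(2).
b = jnp.array([[1 + 1j, 0], [0, 2 - 1j]], dtype=jnp.complex64)
b_pinv = jnp.linalg.pinv(b)                # pseudo-inverse; here the true inverse
cond_b = jnp.linalg.cond(b)                # ratio of largest to smallest singular value
s = jnp.linalg.svd(b, compute_uv=False)    # singular values, in descending order
```

For this diagonal `b`, the condition number works out to sqrt(5) / sqrt(2).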
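As an illustration of the newly added `jax.scipy.cluster.vq.vq`, here is a minimal sketch that assigns each observation to its nearest code-book entry. The observations and centroids are made-up values chosen so the nearest centroid is obvious.

```python
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq

# Three 2-D observations and a code book of two centroids.
obs = jnp.array([[0.0, 0.0], [1.0, 1.0], [9.0, 9.0]])
code_book = jnp.array([[0.0, 0.0], [10.0, 10.0]])

# vq returns, for each observation, the index of its nearest centroid
# and the Euclidean distance to that centroid.
codes, dists = vq(obs, code_book)
```

Here the first two observations map to centroid 0 and the last to centroid 1, with distances 0, sqrt(2), and sqrt(2).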
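The `jax.scipy.linalg.qr` entry above changes what callers get back when `mode='r'`: a length-1 tuple, as in SciPy, rather than a bare array. A minimal sketch of code written against the tuple form follows; the input matrix is arbitrary.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# With mode="r", the result is a length-1 tuple (matching scipy.linalg.qr),
# so unpack it rather than using the return value directly as an array.
result = jax.scipy.linalg.qr(a, mode="r")
(r,) = result

# R is upper triangular, so everything below the diagonal is zero.
below_diag = jnp.tril(r, k=-1)
```

Unpacking with `(r,) = ...` also fails loudly if the return type ever stops being a one-element tuple, which makes the dependency on this behavior explicit.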
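The out-of-bounds defaults described above for `jax.numpy.take_along_axis` and `jax.numpy.take` can be seen in a minimal sketch, assuming a JAX version that includes these changes. Index `5` is deliberately out of bounds for a length-3 array: the default `"fill"` mode yields NaN for floating-point inputs, while `mode="clip"` restores the older clamping.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 2, 5])  # index 5 is out of bounds for a length-3 array

# take_along_axis: by default, out-of-bounds positions are filled with NaN.
filled = jnp.take_along_axis(x, idx, axis=0)                # 10, 30, NaN
# mode="clip" restores the previous behavior of clamping into range.
clipped = jnp.take_along_axis(x, idx, axis=0, mode="clip")  # 10, 30, 30

# take: now also defaults to mode="fill".
taken = jnp.take(x, idx)                                    # 10, 30, NaN
taken_clipped = jnp.take(x, idx, mode="clip")               # 10, 30, 30
```

Code that relied on silent clamping can pass `mode="clip"` explicitly; code that would rather detect bad indices can check the result with `jnp.isnan`.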
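The `"drop"` scatter semantics described above matter mainly for gradients. Below is a minimal sketch, using a hypothetical helper `scatter_sum`, of a scatter with one out-of-bounds index. The forward pass drops that update, and with drop semantics its gradient entry is zero rather than a cotangent taken from a clamped, in-range position.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: scatters two update values into a length-3 buffer
# and sums the result. Index 5 is out of bounds, so the second update is dropped.
def scatter_sum(updates):
    idx = jnp.array([0, 5])
    y = jnp.zeros(3).at[idx].set(updates)
    return y.sum()

updates = jnp.array([1.0, 2.0])
value = scatter_sum(updates)           # only the in-bounds update contributes: 1.0
grad = jax.grad(scatter_sum)(updates)  # zero cotangent for the dropped update: [1.0, 0.0]
```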
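The stricter integer checks listed above can be seen in a minimal sketch using two of the affected functions, `jax.numpy.take_along_axis` and `jax.numpy.diag`. The helper `raises_type_error` is hypothetical and only wraps a call in `try`/`except`. Note that an integral-valued float such as `1.0` is rejected as well, since it is still not of an integer type.

```python
import jax.numpy as jnp

# Hypothetical helper: True if calling fn() raises TypeError.
def raises_type_error(fn):
    try:
        fn()
    except TypeError:
        return True
    return False

x = jnp.arange(4.0)

# Float indices are rejected instead of being silently cast to integers.
bad_indices = raises_type_error(
    lambda: jnp.take_along_axis(x, jnp.array([0.0, 1.0]), axis=0))

# A non-integer k for diag is rejected, even when its value is integral.
bad_k = raises_type_error(lambda: jnp.diag(jnp.eye(3), k=1.0))

# Integer arguments keep working as before.
ok = jnp.take_along_axis(x, jnp.array([3, 0]), axis=0)  # values 3.0, 0.0
```

Callers that previously passed float indices can convert them explicitly, for example with `.astype(jnp.int32)`, so the truncation becomes visible in the code.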
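For the newly added `jax.random.orthogonal`, here is a minimal sketch that draws random orthogonal matrices and checks orthogonality numerically. The key seed and sizes are arbitrary; the `shape` argument adds leading batch dimensions.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# A single random 3x3 orthogonal matrix.
q = jax.random.orthogonal(key, 3)
# For an orthogonal matrix, Q @ Q.T is the identity (up to rounding).
gram = q @ q.T

# A batch of four 3x3 orthogonal matrices.
qs = jax.random.orthogonal(key, 3, shape=(4,))
```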
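Following the deprecation note above, here is a minimal sketch of a test that uses only public APIs: `numpy.testing` in place of `check_close`, and `jax.devices()` in place of `device_under_test` or `skip_on_devices`. The test class, its test methods, and the `run_suite` helper are hypothetical, and this is one possible replacement rather than an official migration recipe.

```python
import io
import unittest

import numpy as np
import jax
import jax.numpy as jnp


class SoftplusTest(unittest.TestCase):
    def test_matches_numpy(self):
        x = np.linspace(-3.0, 3.0, 7, dtype=np.float32)
        # numpy.testing stands in for the deprecated check_close.
        np.testing.assert_allclose(
            jax.nn.softplus(x), np.log1p(np.exp(x)), rtol=1e-5)

    def test_device_check(self):
        # jax.devices() stands in for device_under_test / skip_on_devices.
        if jax.devices()[0].platform == "tpu":
            self.skipTest("illustrative skip on TPU")
        self.assertEqual(int(jnp.add(1, 2)), 3)


def run_suite():
    # Run the tests programmatically, discarding the runner's text output.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(SoftplusTest)
    return unittest.TextTestRunner(stream=io.StringIO(), verbosity=0).run(suite)
```

In a standalone test file you would more typically run this through `python -m unittest`, or use `absl.testing` for parameterized cases.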