New features:
- Added ResizeMethod.CUBIC_PYTORCH to jax.image.resize to match
  PyTorch's bicubic resize (#15768).
- We now support differentiation of jax.lax.linalg.qr for wide
  matrices and when full_matrices is True.
- LAPACK operations are now parallelized along the batch dimension on CPU.
- Added a perturb_singular argument to
  jax.lax.linalg.tridiagonal_solve to handle singular matrices by
  perturbing near-zero pivots in the LU decomposition. This is useful for
  solving numerically singular systems when computing eigenvectors by inverse
  iteration.
- jax.scipy.linalg.eigh_tridiagonal now supports computing
  eigenvectors on CPU and GPU.
- Added the jax.numpy.ndarray.byteswap method.
Breaking changes:
- PartitionSpec objects no longer report themselves to be equal to tuples.
  Convert tuples to PartitionSpec objects before testing equality.
- The .vma property has been removed from jax.core.ShapedArray. Use
  .manual_axis_type.varying instead.
- JAX CPU devices now report their names as cpu:0, cpu:1, etc. instead of
  TFRT_CPU_0, TFRT_CPU_1.
- The config state jax_pmap_shmap_merge has been removed. jax.pmap
  will now always use the new implementation that wraps
  jax.jit(jax.shard_map). Please see
  https://docs.jax.dev/en/latest/migrate_pmap.html for more information.
- jax.device_put_sharded and jax.device_put_replicated have been removed
  from the public API and now raise an AttributeError when accessed.
  Please see
  https://docs.jax.dev/en/latest/migrate_pmap.html#drop-in-replacements for
  drop-in replacements.
- The C++ pmap infrastructure has been removed. The following public APIs
  are no longer available:
  - jax.sharding.PmapSharding
  - From jaxlib.xla_extension: PmapFunction, pmap, NoSharding, Chunked,
    Unstacked, ShardedAxis, Replicated, ShardingSpec.
  - From jax.interpreters.pxla: MapTracer, PmapExecutable,
    parallel_callable, shard_args, xla_pmap_p, Chunked, NoSharding,
    Replicated, ShardedAxis, ShardingSpec, Unstacked, spec_to_indices.
- The deprecated keyword arguments a, a_min, and a_max to
  jax.numpy.clip have been removed.
- The functions jax.numpy.hstack, jax.numpy.vstack, jax.numpy.dstack,
  jax.numpy.column_stack, jax.numpy.atleast_1d, jax.numpy.atleast_2d,
  and jax.numpy.atleast_3d no longer accept non-ArrayLike inputs.
  Doing so previously issued a DeprecationWarning.
- jax.scipy.stats.rankdata now returns floating-point values in
  all cases, following a similar change in the SciPy 1.18 release.
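The sketch below shows one way code might adapt to several of these breaking changes. It assumes a recent JAX release. The helper names (spec_matches, clip_unit, hstack_lists, min_ranks_as_int, replicate_over_devices) are hypothetical. replicate_over_devices is only an illustrative stand-in for jax.device_put_replicated and is not necessarily equivalent to the drop-in replacements in the migration guide.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.stats import rankdata
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def spec_matches(spec, axes):
    # PartitionSpec no longer compares equal to a plain tuple, so wrap the
    # tuple in a PartitionSpec before testing equality.
    return spec == P(*axes)


def clip_unit(x):
    # Use the min=/max= keywords; the removed a, a_min, and a_max
    # keywords are no longer accepted.
    return jnp.clip(x, min=0.0, max=1.0)


def hstack_lists(rows):
    # hstack (like vstack, dstack, ...) rejects non-ArrayLike inputs such
    # as Python lists, so convert each piece to an array first.
    return jnp.hstack([jnp.asarray(r) for r in rows])


def min_ranks_as_int(values):
    # rankdata now always returns floating-point ranks; cast explicitly if
    # integer ranks are needed downstream.
    return rankdata(jnp.asarray(values), method="min").astype(jnp.int32)


def replicate_over_devices(x, devices):
    # Hypothetical stand-in for the removed jax.device_put_replicated: give x
    # a leading axis with one copy per device, then shard that axis over a
    # 1D mesh so each device holds one copy.
    mesh = Mesh(np.array(devices), ("d",))
    stacked = jnp.broadcast_to(x, (len(devices),) + jnp.shape(x))
    return jax.device_put(stacked, NamedSharding(mesh, P("d")))
```

For example, replicate_over_devices(jnp.arange(3.0), jax.devices()) returns an array of shape (len(jax.devices()), 3).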
Deprecations:
- A number of internal APIs in jax.core have been newly deprecated, and
  some have been moved to jax.extend.core. These include CallPrimitive,
  DebugInfo, DropVar, Effect, Effects, InconclusiveDimensionOperation,
  JaxprTypeError, check_jaxpr, concrete_or_error, find_top_trace,
  gensym, get_opaque_trace_state, jaxprs_in_params, new_jaxpr_eqn,
  no_effects, nonempty_axis_env_DO_NOT_USE, primal_dtype_to_tangent_dtype,
  unsafe_am_i_under_a_jit_DO_NOT_USE, unsafe_am_i_under_a_vmap_DO_NOT_USE,
  unsafe_get_axis_names_DO_NOT_USE, valid_jaxtype, JaxprPpContext,
  JaxprPpSettings, OutputType, abstract_token, aval_mapping_handlers,
  call, concretization_function_error, custom_typechecks, is_concrete,
  is_constant_dim, is_constant_shape, literalable_types, no_axis_name,
  pytype_aval_mappings, and trace_ctx.
Changes:
- The minimum supported SciPy version is now 1.14.
- The vma parameter of jax.ShapeDtypeStruct has been replaced with
  manual_axis_type: jax.sharding.ManualAxisType. The .vma property has
  been replaced with .manual_axis_type.varying.
- Removed the experimental jax.experimental.custom_dce.custom_dce.
- jax.scipy.linalg.cho_solve, jax.scipy.linalg.lu_solve, and
  jax.scipy.linalg.solve_triangular now show a deprecation warning for
  batched 1D solves with b.ndim > 1. In the future these will be treated as
  batched 2D solves.
- Added a new version 10 of the jax.export serialization format. This
  version optimizes serialization when the same abstract value, abstract
  mesh, or sharding occurs multiple times.
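One way to avoid the new deprecation warning, and the planned change in meaning, is to make each right-hand side explicitly 2D by adding a trailing axis and then removing it from the result. The sketch below does this with jax.scipy.linalg.solve_triangular for a batch of lower-triangular systems. The helper solve_batched_vectors and the random test data are hypothetical.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def solve_batched_vectors(a, b):
    # a: (batch, n, n) lower-triangular matrices; b: (batch, n) vectors.
    # The trailing axis turns each right-hand side into an explicit n x 1
    # matrix, so this is an unambiguous batched 2D solve.
    x = solve_triangular(a, b[..., None], lower=True)
    return x[..., 0]


rng = np.random.default_rng(0)
batch, n = 3, 4
# Well-conditioned lower-triangular matrices: random entries plus n on the
# diagonal (np.tril applies to the last two axes of a 3D array).
a = jnp.asarray(np.tril(rng.standard_normal((batch, n, n))) + n * np.eye(n),
                dtype=jnp.float32)
b = jnp.asarray(rng.standard_normal((batch, n)), dtype=jnp.float32)
x = solve_batched_vectors(a, b)
```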
Bug fixes:
- Fixed a bug that led to differing output between CPU and GPU for
  non-symmetric multidimensional IRFFTs (#29325).
- Fixed an error when tiny matrices were passed to
  jax.lax.linalg.tridiagonal_solve on GPU (#32487).
- Fixed a bug in jax.scipy.fft.dctn and idctn where axes=None
  incorrectly defaulted to all axes when s was specified, instead of the
  last len(s) axes as in SciPy (#29426).
- Fixed a bug where calling jax.distributed.initialize() on a GCE TPU
  Managed Instance Group raised an IndexError (#36593). When
  jax.distributed.initialize() is called on a GCE VM, it uses the GCE
  metadata server to learn the addresses of all participating tasks. On
  Managed Instance Groups this metadata uses a format JAX did not expect,
  which led to the exception. We now parse this format correctly.
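To get SciPy's behavior whether or not this fix is present, you can pass axes explicitly whenever s is given. The sketch below compares jax.scipy.fft.dctn with scipy.fft.dctn on a small 2D array, truncating only the last axis. It assumes SciPy is installed; SciPy is a JAX dependency.

```python
import numpy as np
import scipy.fft
import jax.numpy as jnp
import jax.scipy.fft

x = np.arange(18, dtype=np.float32).reshape(3, 6)

# SciPy: with s given and axes=None, only the last len(s) axes are
# transformed, so this is a length-4 DCT-II along axis 1 (truncated from 6).
expected = scipy.fft.dctn(x, s=(4,))

# JAX: naming the axis explicitly requests the same transform.
result = jax.scipy.fft.dctn(jnp.asarray(x), s=(4,), axes=(1,))
```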