github jax-ml/jax jax-v0.10.0
JAX v0.10.0

  • New features:

    • Added ResizeMethod.CUBIC_PYTORCH to jax.image.resize to match
      PyTorch's bicubic resize (#15768).
    • We now support differentiation of jax.lax.linalg.qr for wide
      matrices and when full_matrices is True.
    • LAPACK operations are now parallelized along the batch dimension on CPU.
    • Added perturb_singular argument to
      jax.lax.linalg.tridiagonal_solve to handle singular matrices by
      perturbing near-zero pivots in the LU decomposition. This is useful for
      solving numerically singular systems when computing eigenvectors by inverse
      iteration.
    • jax.scipy.linalg.eigh_tridiagonal now supports computing
      eigenvectors on CPU and GPU.
    • Added the jax.numpy.ndarray.byteswap method.
  • Breaking changes:

    • PartitionSpec objects no longer report themselves to be equal to tuples.
      Convert tuples to PartitionSpec objects before testing equality.
    • The .vma property has been removed from jax.core.ShapedArray. Use
      .manual_axis_type.varying instead.
    • JAX CPU devices now report their names as cpu:0, cpu:1, etc. instead of
      TFRT_CPU_0, TFRT_CPU_1.
    • The config option jax_pmap_shmap_merge has been removed. jax.pmap
      now always uses the new implementation that wraps
      jax.jit(jax.shard_map). Please see
      https://docs.jax.dev/en/latest/migrate_pmap.html for more information.
    • jax.device_put_sharded and jax.device_put_replicated have been removed
      from the public API and now raise an AttributeError when accessed.
      Please see
      https://docs.jax.dev/en/latest/migrate_pmap.html#drop-in-replacements for
      drop-in replacements.
    • The C++ pmap infrastructure has been removed. The following public APIs
      are no longer available:
      • jax.sharding.PmapSharding
      • From jaxlib.xla_extension: PmapFunction, pmap,
        NoSharding, Chunked, Unstacked, ShardedAxis, Replicated,
        ShardingSpec.
      • From jax.interpreters.pxla: MapTracer, PmapExecutable,
        parallel_callable, shard_args, xla_pmap_p, Chunked,
        NoSharding, Replicated, ShardedAxis, ShardingSpec,
        Unstacked, spec_to_indices.
    • The deprecated keyword arguments a, a_min, and a_max to
      jax.numpy.clip have been removed.
    • The functions jax.numpy.hstack, jax.numpy.vstack, jax.numpy.dstack,
      jax.numpy.column_stack, jax.numpy.atleast_1d, jax.numpy.atleast_2d,
      and jax.numpy.atleast_3d no longer accept non-ArrayLike inputs.
      Passing such inputs previously issued a DeprecationWarning.
    • jax.scipy.stats.rankdata now returns floating point values in
      all cases, following a similar change in the SciPy 1.18 release.
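For the PartitionSpec equality change above, a minimal sketch of the suggested fix: build a PartitionSpec from the tuple before comparing. The axis name "x" and the variable names are arbitrary choices for illustration.

```python
from jax.sharding import PartitionSpec as P

spec = P("x", None)
legacy = ("x", None)  # a plain tuple that older code compared against directly

# spec == legacy is no longer True, so convert the tuple first.
matches = spec == P(*legacy)
print(matches)  # True
```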
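For the CPU device naming change above, one way to avoid depending on the string form of a device (TFRT_CPU_<n> before, cpu:<n> now) is to select devices by platform and integer id. This is a sketch of that idea, not a documented migration recipe.

```python
import jax

# Index CPU devices by their integer id instead of parsing their names.
cpu_devices = jax.devices("cpu")
by_id = {d.id: d for d in cpu_devices}
first_cpu = by_id[min(by_id)]
print(first_cpu.platform, first_cpu.id)
```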
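The jax.numpy and jax.scipy.stats changes above usually need only small call-site updates. The sketch below shows one possible update for each, using made-up inputs. Nested Python lists are used as the example of non-ArrayLike input.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats

x = jnp.array([-2.0, 0.5, 3.0])

# jnp.clip: pass the array positionally and use min=/max= instead of the
# removed a=, a_min=, a_max= keywords.
clipped = jnp.clip(x, min=0.0, max=1.0)

# hstack/vstack/atleast_*d: convert non-ArrayLike inputs (e.g. Python lists)
# with jnp.asarray before passing them in.
rows = [[1, 2], [3, 4]]
stacked = jnp.vstack([jnp.asarray(r) for r in rows])

# rankdata now returns floating point ranks; cast explicitly if downstream
# code expects integers.
ranks = jax.scipy.stats.rankdata(jnp.array([30, 10, 20]), method="ordinal")
int_ranks = ranks.astype(jnp.int32)
```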
  • Deprecations:

    • A number of internal APIs in jax.core have been newly deprecated and
      some have been moved to jax.extend.core. These include CallPrimitive,
      DebugInfo, DropVar, Effect, Effects, InconclusiveDimensionOperation,
      JaxprTypeError, check_jaxpr, concrete_or_error, find_top_trace,
      gensym, get_opaque_trace_state, jaxprs_in_params, new_jaxpr_eqn,
      no_effects, nonempty_axis_env_DO_NOT_USE, primal_dtype_to_tangent_dtype,
      unsafe_am_i_under_a_jit_DO_NOT_USE, unsafe_am_i_under_a_vmap_DO_NOT_USE,
      unsafe_get_axis_names_DO_NOT_USE, valid_jaxtype, JaxprPpContext,
      JaxprPpSettings, OutputType, abstract_token, aval_mapping_handlers,
      call, concretization_function_error, custom_typechecks, is_concrete,
      is_constant_dim, is_constant_shape, literalable_types, no_axis_name,
      pytype_aval_mappings, and trace_ctx.
  • Changes:

    • The minimum supported SciPy version is now 1.14.
    • The vma parameter of jax.ShapeDtypeStruct has been replaced with
      manual_axis_type, which takes a jax.sharding.ManualAxisType. The .vma
      property has been replaced with .manual_axis_type.varying.
    • Removed the experimental jax.experimental.custom_dce.custom_dce.
    • jax.scipy.linalg.cho_solve, jax.scipy.linalg.lu_solve, and
      jax.scipy.linalg.solve_triangular now show a deprecation warning for
      batched 1D solves with b.ndim > 1. In the future these will be treated as
      batched 2D solves.
    • Added version 10 of the jax.export serialization format. It optimizes
      serialization when the same abstract value, abstract mesh, or sharding
      occurs multiple times.
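To avoid the new deprecation warning for batched 1D solves, the call can be made unambiguous. The sketch below shows two options for jax.scipy.linalg.solve_triangular, assuming a batch of lower-triangular matrices with one right-hand-side vector each. The first option adds an explicit trailing column axis so b is a batch of (n, 1) matrices. The second vmaps the unbatched 1D solve. The random inputs are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
batch, n = 3, 4
# Lower-triangular matrices whose diagonal entries are at least 1.
a = jnp.tril(jax.random.uniform(key_a, (batch, n, n))) + jnp.eye(n)
b = jax.random.normal(key_b, (batch, n))  # one vector per matrix

# Option 1: treat each vector as an (n, 1) matrix, then drop that axis.
x_matrix = solve_triangular(a, b[..., None], lower=True)[..., 0]

# Option 2: map the unbatched 1D solve over the leading batch axis.
x_vmap = jax.vmap(lambda ai, bi: solve_triangular(ai, bi, lower=True))(a, b)
```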
  • Bug fixes:

    • Fixed a bug that led to differing output between CPU and GPU for
      non-symmetric multidimensional IRFFTs (#29325).
    • Fixed an error when tiny matrices were passed to
      jax.lax.linalg.tridiagonal_solve on GPU (#32487).
    • Fixed a bug in jax.scipy.fft.dctn and jax.scipy.fft.idctn where
      axes=None incorrectly defaulted to all axes when s was specified,
      instead of to the last len(s) axes as in SciPy (#29426).
    • Fixed a bug where calling jax.distributed.initialize() on a GCE TPU
      Managed Instance Group raised an IndexError (#36593). When
      jax.distributed.initialize() is called on a GCE VM, it uses the GCE
      metadata server to learn the addresses of all participating tasks. On
      Managed Instance Groups this metadata uses a format that JAX did not
      expect, which led to the exception. JAX now parses this format
      correctly.
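Related to the dctn/idctn fix above: passing axes explicitly makes the transformed axes unambiguous in any JAX version. In this sketch, with s=(4,) and axes=(-1,), only the last axis is transformed (truncated to length 4), which matches a 1D jax.scipy.fft.dct with n=4 along that axis. The input array is arbitrary.

```python
import jax.numpy as jnp
from jax.scipy.fft import dct, dctn

x = jnp.arange(24.0).reshape(3, 8)

# Transform only the last axis, truncated to length 4.
y = dctn(x, s=(4,), axes=(-1,))
expected = dct(x, n=4, axis=-1)
```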
