github jax-ml/jax jax-v0.5.1
JAX v0.5.1

  • New Features

    • Added an experimental jax.experimental.custom_dce.custom_dce
      decorator to support customizing the behavior of opaque functions under
      JAX-level dead code elimination (DCE). See #25956 for more
      details.
    • Added low-level reduction APIs in jax.lax: jax.lax.reduce_sum,
      jax.lax.reduce_prod, jax.lax.reduce_max, jax.lax.reduce_min,
      jax.lax.reduce_and, jax.lax.reduce_or, and jax.lax.reduce_xor.
    • jax.lax.linalg.qr and jax.scipy.linalg.qr now support
      column pivoting on CPU and GPU. See #20282 and
      #25955 for more details.
  • Changes

    • JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES can now be set as
      environment variables. Previously they could only be specified via jax.config or flags.
    • JAX_CPU_COLLECTIVES_IMPLEMENTATION now defaults to 'gloo', meaning
      multi-process CPU communication works out of the box.
    • The jax[tpu] TPU extra no longer depends on the libtpu-nightly package.
      This package may safely be removed if it is present on your machine; JAX now
      uses libtpu instead.
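To check whether a stale libtpu-nightly package is still present after upgrading jax[tpu], a small helper based on the standard library's importlib.metadata is one option. The helper name is_installed is hypothetical; it only reports whether a distribution is installed and does not modify the environment.

```python
from importlib import metadata


def is_installed(dist_name: str) -> bool:
    # True if a distribution with this name is installed in the current environment.
    try:
        metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return False
    return True


if is_installed("libtpu-nightly"):
    print("libtpu-nightly is no longer needed by jax[tpu] and may be removed:")
    print("  pip uninstall libtpu-nightly")
```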
  • Deprecations

    • The internal function linear_util.wrap_init and the constructor
      core.Jaxpr must now take a non-empty core.DebugInfo kwarg. For
      a limited time, a DeprecationWarning is printed if
      jax.extend.linear_util.wrap_init is used without debugging info.
      As a downstream effect, several other internal functions now also need
      debug info. This change does not affect public APIs.
      See #26480 for more detail.
  • Bug fixes

    • TPU runtime startup and shutdown time should be significantly improved on
      TPU v5e and newer (from around 17s to around 8s). If not already set, you may
      need to enable transparent hugepages in your VM image
      (sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled').
      We hope to improve this further in future releases.
    • The persistent compilation cache no longer writes access-time files if
      JAX_COMPILATION_CACHE_MAX_SIZE is unset or set to -1, i.e. if the LRU
      eviction policy isn't enabled. This should improve performance when using
      the cache with large-scale network storage.
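Whether transparent hugepages are already enabled, as the TPU startup note above suggests, can be checked by reading the sysfs file, where the kernel marks the active mode with brackets (for example "always [madvise] never"). The helper names below are hypothetical, and the function returns None on systems where the file does not exist.

```python
from pathlib import Path
from typing import Optional

THP_PATH = Path("/sys/kernel/mm/transparent_hugepage/enabled")


def parse_thp_mode(text: str) -> str:
    # The active mode is the bracketed token, e.g. "always [madvise] never" -> "madvise".
    for token in text.split():
        if token.startswith("[") and token.endswith("]"):
            return token[1:-1]
    raise ValueError(f"no active mode found in {text!r}")


def current_thp_mode(path: Path = THP_PATH) -> Optional[str]:
    # None if the kernel does not expose the setting (e.g. non-Linux hosts).
    if not path.exists():
        return None
    return parse_thp_mode(path.read_text())


mode = current_thp_mode()
if mode is not None and mode != "always":
    print(f"transparent hugepages mode is {mode!r}; the release notes suggest 'always'")
```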
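A minimal sketch of configuring the persistent compilation cache so that LRU eviction stays disabled, the case the fix above targets. The temporary cache directory and the zero minimum compile time are illustrative choices for the demo; jax_compilation_cache_max_size is assumed to be the jax.config counterpart of JAX_COMPILATION_CACHE_MAX_SIZE.

```python
import tempfile

import jax
import jax.numpy as jnp

# Illustrative cache location; in practice this might be shared network storage.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")
jax.config.update("jax_compilation_cache_dir", cache_dir)
# Persist every compiled program, even ones that compile quickly (demo only).
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)
# -1 (the default) leaves LRU eviction disabled, so per this release no
# access-time files are written alongside cache entries.
jax.config.update("jax_compilation_cache_max_size", -1)


@jax.jit
def scaled_sin(x):
    return 2.0 * jnp.sin(x)


y = scaled_sin(jnp.arange(4.0))
```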
