- GitHub commits.
- Bugs
  - Fix corner case issue in gradient of `lax.pow` with an exponent of zero (#12041).
- Breaking changes
  - `jax.checkpoint`, also known as `jax.remat`, no longer supports the `concrete` option, following the previous version's deprecation; see JEP 11830.
- Changes
  - Added `jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`).
- Deprecations:
  - The deprecated `DeviceArray.tile()` method has been removed. Use `jax.numpy.tile` (#11944).
  - `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead.
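The sketch below illustrates the API changes listed above. It assumes a JAX version that includes `jax.pure_callback` and the new `jax.checkpoint` implementation (0.3.17 or later). The helpers `host_sin`, `f` and `g` are hypothetical names chosen for illustration. Passing `static_argnums` as the replacement for the removed `concrete` option reflects our reading of JEP 11830 and is not stated in these notes.

```python
import jax
import jax.numpy as jnp
import numpy as np

# jax.pure_callback: call a pure host-side NumPy function from jitted code.
def host_sin(x):
    # Hypothetical helper; runs on the host in plain Python/NumPy and must be side-effect free.
    return np.sin(x).astype(x.dtype)

@jax.jit
def f(x):
    # The second argument declares the shape and dtype the callback returns.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_spec, x)

y = f(jnp.arange(3.0))

# jax.checkpoint without `concrete`: mark the Python-int argument as static instead.
def g(x, n):
    # Python control flow over n requires n to be a concrete Python value, not a tracer.
    for _ in range(n):
        x = jnp.sin(x)
    return x

g_remat = jax.checkpoint(g, static_argnums=(1,))
dg = jax.grad(g_remat)(1.0, 2)  # derivative of sin(sin(x)) at x = 1.0

# DeviceArray.tile() was removed: use jax.numpy.tile instead.
a = jnp.array([1, 2])
tiled = jnp.tile(a, 3)  # [1 2 1 2 1 2]

# DeviceArray.to_py() is deprecated: use np.asarray(x) instead.
host = np.asarray(tiled)  # a numpy.ndarray on the host
```

Because the callback runs outside XLA, JAX cannot infer its output shape and dtype. That is why `jax.pure_callback` requires them up front as a `jax.ShapeDtypeStruct`.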