Bug fixes
- Fixed a bug where XLA sharded some concatenation operations incorrectly, which manifested as an incorrect output for cumulative reductions (#21403).
- Fixed a bug where XLA:CPU miscompiled certain matmul fusions (openxla/xla#13301).
- Fixed a compiler crash on GPU (#21396).
Deprecations
- jax.tree.map(f, None, non-None) now emits a DeprecationWarning, and will raise an error in a future version of jax. This is because None is only a tree-prefix of itself. To preserve the current behavior, you can ask jax.tree.map to treat None as a leaf value by writing:

    jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
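A minimal sketch of the is_leaf workaround described in the deprecation note. The binary function f (plain addition) and the two dict pytrees a and b are hypothetical, chosen only for illustration. The deprecated call is left commented out because, depending on the installed JAX version, it either warns or raises.

```python
import jax

def f(x, y):
    # Hypothetical binary function applied to matching leaves.
    return x + y

# a has None at "b" where the second tree has a real value.
a = {"w": 1.0, "b": None}
b = {"w": 2.0, "b": 3.0}

# Deprecated pattern: None in `a` lines up with a non-None value in `b`.
# jax.tree.map(f, a, b)  # DeprecationWarning now, an error in future versions

# Workaround: treat None as a leaf so the structures match,
# and pass None through instead of calling f on it.
result = jax.tree.map(
    lambda x, y: None if x is None else f(x, y),
    a,
    b,
    is_leaf=lambda x: x is None,
)

# result["w"] is f(1.0, 2.0) == 3.0; result["b"] stays None.
print(result["w"], result["b"])
```

Passing is_leaf=lambda x: x is None makes None occupy a leaf slot in the first tree, so the second tree is flattened up to the same structure and the lambda decides what to do with that slot explicitly.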