-
Bug fixes:
-
Out-of-bounds indices to jax.ops.segment_sum will now be handled with FILL_OR_DROP semantics, as documented. This primarily afects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0. (#8634).
-
jax2tf will force the converted code to use XLA for the code fragments under jax.jit, e.g., most jax.numpy functions (#7839).