- Deprecations
  - jax.lax.pvary has been deprecated.
    Please use jax.lax.pcast(..., to='varying') as the replacement.
  - Complex arguments passed to jax.numpy.arange now result in a
    deprecation warning, because the output is poorly defined.
  - From jax.core, a number of symbols are newly deprecated, including:
    call_impl, get_aval, mapped_aval, subjaxprs, set_current_trace,
    take_current_trace, traverse_jaxpr_params, unmapped_aval,
    AbstractToken, and TraceTag.
  - All symbols in jax.interpreters.pxla are deprecated. These are
    primarily JAX internal APIs, and users should not rely on them.
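Code that built complex sequences with jax.numpy.arange will now trigger the deprecation warning. The sketch below is a minimal alternative. It assumes the intent was num evenly spaced points start, start + step, start + 2*step, and so on. It passes only a real integer count to jnp.arange and applies the complex start and step afterwards. The helper name complex_range is hypothetical and is not a JAX API.

```python
import jax
import jax.numpy as jnp


def complex_range(start: complex, step: complex, num: int) -> jax.Array:
    """Hypothetical helper (not part of JAX): return `num` complex points
    start, start + step, start + 2*step, ... without passing complex
    arguments to jnp.arange."""
    # Only a real integer count reaches jnp.arange, so no complex-arange warning.
    idx = jnp.arange(num)
    # Promote the integer index to complex, then scale by step and shift by start.
    return start + step * idx.astype(jnp.complex64)


z = complex_range(1 + 1j, 0.5j, 3)
print(z)  # expected values: 1+1j, 1+1.5j, 1+2j
```

Spelling out the count explicitly avoids the ambiguity in how many elements a complex start, stop, and step should produce.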
- Changes:
  - JAX's Tracer no longer inherits from jax.Array at runtime. However,
    jax.Array now uses a custom metaclass such that isinstance(x, Array) is
    true if an object x represents a traced Array. Only some Tracers represent
    Arrays, so it is not correct for Tracer to inherit from Array.

    For the moment, during Python type checking, we continue to declare Tracer
    as a subclass of Array; however, we expect to remove this in a future
    release.
  - jax.experimental.si_vjp has been deleted.
    jax.vjp subsumes its functionality.
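The custom metaclass makes isinstance(x, jax.Array) succeed for tracers that represent arrays. That check is therefore the portable way to ask "is this an array?", and it does not depend on Tracer's class hierarchy. The minimal sketch below records the check both on a concrete array and inside a jax.jit trace. The function name double and the seen list are illustrative only.

```python
import jax
import jax.numpy as jnp

seen = []  # records the isinstance result observed while tracing


def double(x):
    # Under jax.jit, x is a tracer standing in for an array; the metaclass
    # makes isinstance(x, jax.Array) hold for it.
    seen.append(isinstance(x, jax.Array))
    return x * 2


x = jnp.arange(3.0)
print(isinstance(x, jax.Array))  # True: concrete arrays are jax.Array instances
y = jax.jit(double)(x)
print(seen)  # [True]: the check made once, during tracing
```

Prefer this check over isinstance(x, Tracer) or issubclass tests against Tracer. The relationship between Tracer and Array is changing at runtime now, and it is expected to change for type checkers later.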
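With jax.experimental.si_vjp deleted, code that used it needs to go through jax.vjp. The sketch below shows only a basic jax.vjp call on a hypothetical two-argument function f. It does not try to reproduce si_vjp's own calling convention, which this entry does not describe.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical example function of two array arguments.
    return jnp.sin(x) * y


x, y = jnp.array(0.0), jnp.array(3.0)
# jax.vjp returns the primal output and a function mapping an output
# cotangent to a tuple with one cotangent per positional input.
out, f_vjp = jax.vjp(f, x, y)
gx, gy = f_vjp(jnp.ones_like(out))
# expected: out = sin(0)*3 = 0, gx = cos(0)*3 = 3, gy = sin(0) = 0
print(out, gx, gy)
```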