Features
- Added `jaxtyping.print_bindings` to manually inspect the values of each axis whilst inside a function.
- Added support for `jaxtyping.{Int4, UInt4}`. (#174, thanks @jianlijianli!)
Bugfixes
- Importing jaxtyping no longer imports JAX, even if it is installed. This ensures compatibility when using jaxtyping+PyTorch alongside an old JAX installation. (All JAX re-exports, like `jaxtyping.Array = jax.Array`, are now looked up dynamically when first accessed rather than at import time.) (#178)
- We no longer raise false positives when `@jaxtyped`-ing generators (functions with `yield` statements). (#91, #171, thanks @knyazer!)
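Deferring an import until a re-exported name is first used is commonly done with a module-level `__getattr__` hook (PEP 562). The sketch below illustrates that general technique only and is not jaxtyping's source. It builds a throwaway module object called `lazy_reexports` (hypothetical) and uses `decimal.Decimal` as a stand-in for `jax.Array`, so it runs without JAX installed.

```python
import importlib
import types

# Hypothetical stand-in for a package such as jaxtyping.
lazy_reexports = types.ModuleType("lazy_reexports")

# Exported name -> (module to import, attribute to fetch from it).
# "decimal.Decimal" plays the role of "jax.Array" in this sketch.
_LAZY_EXPORTS = {"Array": ("decimal", "Decimal")}


def _lazy_getattr(name):
    """PEP 562 hook: Python calls this only when normal attribute lookup fails."""
    try:
        module_name, attr = _LAZY_EXPORTS[name]
    except KeyError:
        raise AttributeError(
            f"module 'lazy_reexports' has no attribute {name!r}"
        ) from None
    # The import happens here, on first access, not when the module is created.
    return getattr(importlib.import_module(module_name), attr)


# In a real package file this would just be `def __getattr__(name): ...`
# written at module level.
lazy_reexports.__getattr__ = _lazy_getattr
```

With this in place, `lazy_reexports.Array` triggers the import on first access, while unknown names still raise `AttributeError` as usual.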
Internals
- Added support for beartype's pseudostandard `__instancecheck_str__` method. Instead of `isinstance(x, Float[Array, "foo"])`, one can now call `Float[Array, "foo"].__instancecheck_str__(x)`, which returns either an empty string (success) or an error message describing why the check failed (wrong shape, wrong dtype, ...). In practice this feature probably isn't very usable yet: we first need to do a better job of ensuring compatibility between the jaxtyping import hooks and the beartype import hooks.
Docs
- Fixes by @jeertmans (#154) and @afrozenator (#170) -- thank you!
New Contributors
- @jeertmans made their first contribution in #154
- @afrozenator made their first contribution in #170
Full Changelog: v0.2.25...v0.2.26