jaxtyping v0.2.26 (GitHub: patrick-kidger/jaxtyping)


Features

  • Added jaxtyping.print_bindings to manually inspect the values of each axis, whilst inside a function.
  • Added support for jaxtyping.{Int4, UInt4}. (#174, thanks @jianlijianli!)

Bugfixes

  • We no longer import JAX at all, even if it is installed. This ensures compatibility when using jaxtyping with PyTorch alongside an old JAX installation. (All JAX re-exports, such as jaxtyping.Array = jax.Array, are now looked up dynamically rather than at import time.) (#178)
  • We no longer raise false positives when applying @jaxtyped to generator functions (those containing yield statements). (#91, #171, thanks @knyazer!)
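Looking up re-exports dynamically is commonly done in Python with a module-level `__getattr__` (PEP 562). The sketch below shows that general pattern only; it is not jaxtyping's source. The module `mylib` and its lookup table are hypothetical, and `fractions.Fraction` merely stands in for an optional dependency such as `jax.Array`.

```python
import importlib
import types

# Hypothetical table of lazy re-exports: public name -> (module, attribute).
# Nothing listed here is imported until the attribute is first requested.
_LAZY_REEXPORTS = {"Fraction": ("fractions", "Fraction")}


def _lazy_getattr(name: str):
    """Module-level __getattr__ (PEP 562): only called for missing attributes."""
    try:
        module_name, attr_name = _LAZY_REEXPORTS[name]
    except KeyError:
        raise AttributeError(f"module 'mylib' has no attribute {name!r}") from None
    module = importlib.import_module(module_name)  # deferred import
    return getattr(module, attr_name)


# In a real package this would simply be `def __getattr__(name): ...` at the
# top level of mylib/__init__.py; here a throwaway module object is built instead.
mylib = types.ModuleType("mylib")
mylib.__getattr__ = _lazy_getattr

print(mylib.Fraction(1, 3))
```

Because the import happens inside `__getattr__`, code that never touches the re-exported name never imports (or breaks on) the optional dependency.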

Internals

  • Added support for beartype's pseudostandard __instancecheck_str__ method. As an alternative to isinstance(x, Float[Array, "foo"]), one can now call Float[Array, "foo"].__instancecheck_str__(x), which returns either an empty string (success) or an error message describing why the check failed (wrong shape, wrong dtype, ...). In practice this feature probably isn't very usable yet: that will have to wait until we've done a better job of ensuring compatibility between the jaxtyping import hooks and the beartype import hooks.
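To illustrate the convention described above (empty string on success, an explanation on failure), here is a stdlib-only sketch. It assumes nothing about jaxtyping's internals: the metaclass, the `ListOfThree` hint and the `check` helper are all hypothetical, and `check` is only one way a type-checker might consume the hook, not beartype's actual logic.

```python
class _LengthCheckMeta(type):
    """Hypothetical metaclass giving a class both checking hooks."""

    def __instancecheck_str__(cls, obj) -> str:
        # "" means the check passed; otherwise return a human-readable reason.
        if not isinstance(obj, list):
            return f"expected a list, got {type(obj).__name__}"
        if len(obj) != cls.length:
            return f"expected length {cls.length}, got length {len(obj)}"
        return ""

    def __instancecheck__(cls, obj) -> bool:
        # isinstance() reduces the detailed check to a plain bool.
        return cls.__instancecheck_str__(obj) == ""


class ListOfThree(metaclass=_LengthCheckMeta):
    length = 3


def check(hint, obj) -> None:
    """Hypothetical consumer: use the detailed message when the hint offers one."""
    get_message = getattr(hint, "__instancecheck_str__", None)
    if get_message is None:
        if not isinstance(obj, hint):
            raise TypeError(f"{obj!r} is not an instance of {hint!r}")
        return
    message = get_message(obj)
    if message:
        raise TypeError(f"{obj!r} failed {hint.__name__}: {message}")
```

The benefit over a bare isinstance is the error text: a checker can report "expected length 3, got length 2" instead of only "not an instance".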

Docs

New Contributors

Full Changelog: v0.2.25...v0.2.26
