New features
- Error messages will now include useful shape information for debugging. (!!!) This closes the venerable #6, which is one of the oldest feature requests for jaxtyping. This is enabled by using the following syntax, instead of the old double-decorator syntax:

  ```python
  from beartype import beartype as typechecker  # or: from typeguard import typechecked as typechecker
  from jaxtyping import Array, Float, jaxtyped

  @jaxtyped(typechecker=typechecker)  # passing as keyword argument is important
  def foo(x: Float[Array, "dim"]): ...  # any annotated function
  ```

  This is also what `install_import_hook` now does.

  As an example, consider this buggy code:

  ```python
  import jax.numpy as jnp
  from beartype import beartype
  from jaxtyping import Array, Float, jaxtyped

  @jaxtyped(typechecker=beartype)
  def f(x: Float[Array, "foo bar"], y: Float[Array, "foo"]):
      ...

  f(jnp.zeros((3, 4)), jnp.zeros(5))
  ```

  which will now produce the error message

  ```
  jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of f.
  The problem arose whilst typechecking argument 'y'.
  Called with arguments: {'x': f32[3,4], 'y': f32[5]}
  Parameter annotations: (x: Float[Array, 'foo bar'], y: Float[Array, 'foo']).
  The current values for each jaxtyping axis annotation are as follows.
  foo=3
  bar=4
  ```

  Hurrah! I'm really glad to have this important quality-of-life improvement in. (#6, #138)
- Added support for the following:

  ```python
  def make_zeros(size: int) -> Float[Array, "{size}"]:
      return jnp.zeros(size)
  ```

  in which axis names enclosed in `{...}` are evaluated as f-strings using the values of the function's arguments. This closes the long-standing feature request #93. (#93, #140) (Heads-up @MilesCranmer!)
- Added support for declaring PyTree structures, which, like array shapes, must match across all arguments. For example,

  ```python
  def f(x: PyTree[int, "T"], y: PyTree[float, "T"]): ...
  ```

  demands that `x` and `y` be PyTrees with the same `jax.tree_util.tree_structure` as each other. (#135)
- Added support for treepath-dependent sizes using `?`. This makes it possible for the value of a dimension to vary with its position within a pytree, while still requiring it to be consistent with its value in other pytrees of the same structure. Such annotations look like `PyTree[Float[Array, "?foo"], "T"]`. Together with the previous point, this means that you can now declare that two pytrees must have exactly the same structure and array shapes as each other: use `PyTree[Float[Array, "?*shape"], "T"]` as the annotation for both. (#136)
- Added `jaxtyping.Real`, which admits any float, signed integer, or unsigned integer dtype (but not bools or complexes). (#128)
- If JAX is installed, then `jaxtyping.DTypeLike` is now available (it is just a forwarding of `jax.typing.DTypeLike`). (#129)
Bugfixes
- Fixed an error not being raised for mismatched variadic+broadcast and variadic+nonbroadcast dimensions; see #134 for details. (#134)
- Fixed `jaxtyping.Key` not being compatible with the new-style `jax.random.key`. (As opposed to the old-style `jax.random.PRNGKey`.) (#142, #143)
- Fixed `install_import_hook(..., None)` crashing. (#145, #146)
- Variadic shapes combined with `bool`/`int`/`float`/`complex` now work correctly, e.g. `Float[float, "..."]` is now valid (and equivalent to just `float`). This is useful in particular for `Float[ArrayLike, "..."]` to work correctly (as `ArrayLike` includes `float`). (#133)
Better error messages
- The error message due to a nonexistent symbolic dimension (e.g. when `def f(x: Float[Array, "dim*2"])` leaves `dim` unspecified) has now been fixed. (#131)
- The error message due to the wrong dataclass attribute type, e.g.

  ```python
  from dataclasses import dataclass

  @dataclass
  class Foo:
      attribute_name: int

  Foo("strings are not integers")
  ```

  will now correctly include the `attribute_name`. (#132)
Note that this release may result in new errors being raised, due to the inclusion of #134. If so, then the appropriate thing to do is to fix your code: this is a correct error that jaxtyping was previously failing to raise.
Full Changelog: v0.2.23...v0.2.24