github patrick-kidger/jaxtyping v0.2.24


New features

  • Error messages will now include useful shape information for debugging. (!!!) This closes the venerable #6, which is one of the oldest feature requests for jaxtyping. This is enabled by using the following syntax, instead of the old double-decorator syntax:
    from jaxtyping import Array, Float, jaxtyped
    from beartype import beartype as typechecker
    # or, to use typeguard instead: from typeguard import typechecked as typechecker
    
    @jaxtyped(typechecker=typechecker)  # passing as keyword argument is important
    def foo(x: Float[Array, "batch"]):
        ...
    This is also what install_import_hook now does.
    As an example, consider this buggy code:
    import jax.numpy as jnp
    from jaxtyping import Array, Float, jaxtyped
    from beartype import beartype
    
    @jaxtyped(typechecker=beartype)
    def f(x: Float[Array, "foo bar"], y: Float[Array, "foo"]):
        ...
    
    f(jnp.zeros((3, 4)), jnp.zeros(5))
    will now produce the error message
    jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of f.
    The problem arose whilst typechecking argument 'y'.
    Called with arguments: {'x': f32[3,4], 'y': f32[5]}
    Parameter annotations: (x: Float[Array, 'foo bar'], y: Float[Array, 'foo']).
    The current values for each jaxtyping axis annotation are as follows.
    foo=3
    bar=4
    
    Hurrah! I'm really glad to have this important quality-of-life improvement in. (#6, #138)
  • Added support for the following:
    def make_zeros(size: int) -> Float[Array, "{size}"]:
        return jnp.zeros(size)
    in which axis names enclosed in {...} are evaluated as f-strings using the values of the function's arguments. This closes the long-standing feature request #93. (#93, #140) (Heads-up @MilesCranmer!)
  • Added support for declaring PyTree structures, which, like array shapes, must match across all arguments. For example,
    def f(x: PyTree[int, "T"], y: PyTree[float, "T"]): ...
    demands that x and y be PyTrees with the same jax.tree_util.tree_structure as each other. (#135)
  • Added support for treepath-dependent sizes using ?. This allows the size of a dimension to vary with its position within a pytree, while still requiring it to be consistent with the size at the same position in other pytrees of the same structure. Such annotations look like PyTree[Float[Array, "?foo"], "T"]. Together with the previous point, this means that you can now declare that two pytrees must have exactly the same structure and array shapes as each other: use PyTree[Float[Array, "?*shape"], "T"] as the annotation for both. (#136)
  • Added jaxtyping.Real, which admits any float, signed integer, or unsigned integer. (But not bools or complexes.) (#128)
  • If JAX is installed, then jaxtyping.DTypeLike is now available (it simply re-exports jax.typing.DTypeLike). (#129)
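
To illustrate the idea behind the new shape-aware error message and the {size} syntax, here is a minimal pure-Python sketch of how named axes can be bound to sizes and checked consistently across a single call. It is not jaxtyping's implementation: bind_axes, checked_call and fake_array are hypothetical names, arrays are modelled as any object with a .shape, and the annotations are passed as a plain dict rather than read from Float[Array, ...] hints.

```python
import inspect
from types import SimpleNamespace


def bind_axes(spec, shape, memo, arguments):
    """Match one shape against a space-separated axis spec such as "foo bar".

    Named axes are recorded in `memo`, so a name must have the same size
    everywhere it appears during one call. Axis names written as {...} are
    formatted like f-strings against the call's arguments, e.g. "{size}".
    """
    names = spec.split()
    if len(names) != len(shape):
        raise TypeError(f"expected {len(names)} axes for {spec!r}, got shape {shape}")
    for name, size in zip(names, shape):
        if "{" in name:
            expected = int(name.format(**arguments))  # "{size}" -> "5" -> 5
            if size != expected:
                raise TypeError(f"axis {name}={expected}, but got size {size}")
        elif name in memo and memo[name] != size:
            # Report every binding made so far, like "foo=3, bar=4".
            current = ", ".join(f"{k}={v}" for k, v in memo.items())
            raise TypeError(
                f"axis {name!r} has size {size}; current axis values: {current}"
            )
        else:
            memo[name] = size


def checked_call(fn, specs, *args, **kwargs):
    """Check fn's arguments and return value against `specs`, then return the result.

    `specs` maps parameter names (plus the special key "return") to axis
    strings. A single memo is shared across the whole call.
    """
    arguments = dict(inspect.signature(fn).bind(*args, **kwargs).arguments)
    memo = {}
    for param, spec in specs.items():
        if param != "return":
            bind_axes(spec, arguments[param].shape, memo, arguments)
    result = fn(*args, **kwargs)
    if "return" in specs:
        bind_axes(specs["return"], result.shape, memo, arguments)
    return result


def fake_array(*shape):
    # Stand-in for a JAX array: only the .shape attribute is used here.
    return SimpleNamespace(shape=shape)


def f(x, y):
    return x


def make_zeros(size):
    return fake_array(size)
```

With these definitions, checked_call(make_zeros, {"return": "{size}"}, 5) passes, while checked_call(f, {"x": "foo bar", "y": "foo"}, fake_array(3, 4), fake_array(5)) raises a TypeError that reports foo=3, bar=4, mirroring the buggy example in the first bullet. Sharing one memo per call is what makes it possible to print the current axis values when a mismatch is found.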
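
As a rough picture of what PyTree[int, "T"] and PyTree[float, "T"] demand together, here is a pure-Python sketch that models a pytree as nested lists, tuples and dicts. It is an illustration only: tree_leaves, tree_structure and check_pytree are hypothetical helpers standing in for jax.tree_util (which additionally handles None, registered custom nodes, and more), not jaxtyping's actual code.

```python
def tree_leaves(tree):
    """Flatten nested lists/tuples/dicts into a list of leaves (dict keys sorted)."""
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in tree_leaves(child)]
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in tree_leaves(tree[key])]
    return [tree]


def tree_structure(tree):
    """Return a hashable description of the container layout, ignoring leaf values."""
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(tree_structure(child) for child in tree))
    if isinstance(tree, dict):
        return ("dict", tuple((key, tree_structure(tree[key])) for key in sorted(tree)))
    return "*"  # every leaf looks the same


def check_pytree(tree, leaf_type, name, memo):
    """Check a value against a hypothetical PyTree[leaf_type, name] annotation.

    The first tree checked under `name` fixes the structure; later trees
    checked with the same memo must match it exactly.
    """
    for leaf in tree_leaves(tree):
        if not isinstance(leaf, leaf_type):
            raise TypeError(f"leaf {leaf!r} is not a {leaf_type.__name__}")
    structure = tree_structure(tree)
    if memo.setdefault(name, structure) != structure:
        raise TypeError(f"PyTree structure {name!r} does not match earlier arguments")
```

For example, with memo = {}, checking {"a": 1, "b": (2, 3)} as int and then {"a": 1.0, "b": (2.0, 3.0)} as float under the same name "T" succeeds, whereas a subsequent [1.0, 2.0] under "T" raises a TypeError because its structure differs.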

Bugfixes

  • Fixed mismatched variadic+broadcast and variadic+nonbroadcast dimensions not raising an error; see #134 for details. (#134)
  • Fixed jaxtyping.Key not being compatible with the new-style jax.random.key. (As opposed to the old-style jax.random.PRNGKey.) (#142, #143)
  • Fixed install_import_hook(..., None) crashing (#145, #146).
  • Variadic shapes combined with bool/int/float/complex now work correctly, e.g. Float[float, "..."] is now valid (and equivalent to just float). In particular, this makes Float[ArrayLike, "..."] work correctly (as ArrayLike includes float). (#133)
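
The reason Float[float, "..."] can be equivalent to plain float is that a Python scalar behaves like a zero-dimensional array, and a variadic axis is allowed to match zero axes. The following pure-Python sketch shows that rule in a simplified spec language. match_shape, dim_matches and shape_of are hypothetical helpers; named axes here match any size rather than being bound consistently, and broadcasting is ignored.

```python
def dim_matches(token, size):
    # In this simplified spec, digits are fixed sizes; any other name matches any size.
    return not token.isdigit() or int(token) == size


def match_shape(spec, shape):
    """Return True if `shape` fits `spec`, where "..." matches zero or more axes."""
    tokens = spec.split()
    if tokens.count("...") > 1:
        raise ValueError("at most one '...' is allowed")
    if "..." not in tokens:
        return len(tokens) == len(shape) and all(map(dim_matches, tokens, shape))
    i = tokens.index("...")
    before, after = tokens[:i], tokens[i + 1:]
    if len(shape) < len(before) + len(after):
        return False
    head = shape[:len(before)]
    tail = shape[len(shape) - len(after):]  # empty when nothing follows "..."
    return all(map(dim_matches, before, head)) and all(map(dim_matches, after, tail))


def shape_of(value):
    # Python scalars behave like zero-dimensional arrays, i.e. shape ().
    if isinstance(value, (bool, int, float, complex)):
        return ()
    return tuple(value.shape)
```

Under this rule match_shape("...", shape_of(1.5)) is True, because "..." absorbs all zero axes of a scalar, while match_shape("batch", shape_of(1.5)) is False.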

Better error messages

  • The error message for a nonexistent symbolic dimension (e.g. def f(x: Float[Array, "dim*2"]) leaves dim unspecified) is now fixed. (#131)
  • The error message raised when a dataclass attribute has the wrong type, e.g.
    from dataclasses import dataclass
    @dataclass
    class Foo:
        attribute_name: int
    Foo("strings are not integers")
    will now correctly include the attribute name (here, attribute_name). (#132)

Note that this release may result in new errors being raised, due to the inclusion of #134. If so, then the appropriate thing to do is to fix your code: this is a correct error that jaxtyping was previously failing to raise.

Full Changelog: v0.2.23...v0.2.24
