Releases: patrick-kidger/jaxtyping

jaxtyping v0.2.33

12 Jul 12:02
  • Compatibility with Python 3.10 when using Any as the array type.
  • Compatibility with generic array types.
  • Array typevars now respect __constraints__

Full Changelog: v0.2.32...v0.2.33

jaxtyping v0.2.32

12 Jul 09:44
  • The array type can now be either Any or a TypeVar. In both cases this means that anything is allowed at runtime. As usual, static type checkers will only look at the array part of an annotation, so that an annotation of the form Float[T, "foo bar"] (where T = TypeVar("T")) will be treated as just T by static type checkers. This allows for expressing array-type-polymorphism with static typechecking. Here's an example:

    import numpy as np
    import torch
    from typing import TypeVar
    from jaxtyping import Float
    
    TensorLike = TypeVar("TensorLike", np.ndarray, torch.Tensor)
    
    def stack_scalars(x: Float[TensorLike, ""], y: Float[TensorLike, ""]) -> Float[TensorLike, "2"]:
        if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
            return np.stack([x, y])
        elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            return torch.stack([x, y])
        else:
            raise ValueError("Invalid array types!")
  • Fixed a bug in which the very first argument to a function was erroneously reported as the one at fault for a typechecking error. This bug occurred when using default arguments.

Full Changelog: v0.2.31...v0.2.32

jaxtyping v0.2.31

25 Jun 18:34
  • jaxtyping now duck-types on array shapes and dtypes, so you can use it with your own custom array-like objects:

    import jaxtyping
    
    class FooDtype(jaxtyping.AbstractDtype):
        dtypes = ["foo"]
    
    class MyArray:
        @property
        def dtype(self):
            return "foo"
    
        @property
        def shape(self):
            return (3, 1, 4)
    
    def f(x: FooDtype[MyArray, "3 1 4"]): ...
  • Improved compatibility when typeguard warns that you're typechecking a function without annotations: it will no longer mention the jaxtyping-internal check_params function and will instead mention the name of the function that is missing annotations.

  • Improved the error message when typechecking fails, to state the full some_module.SomeClass.some_method rather than just some_method.

  • Fixed a JAX deprecation warning for jax.tree_map. (Thanks @groszewn!)

Full Changelog: v0.2.30...v0.2.31

jaxtyping v0.2.30

13 Jun 17:25
  • Now reporting the correct source code line numbers when using the import hook. Makes debuggers useful again! #214
  • Now supports numpy structured dtypes (Thanks @alexfanqi! #211)
  • Now respecting typing.no_type_check. #216

Full Changelog: v0.2.29...v0.2.30

jaxtyping v0.2.29

27 May 14:29
  • Crash fix for when jax is available but jaxlib is not. (Thanks @ar0ck! #191)
  • Crash fix when used alongside old TensorFlow versions that don't support tensor.ndim (Thanks @dziulek! #193)
  • Crash fix when using a default argument as a symbolic dimension size. (Thanks @jaraujo98! #208)
  • Improved import times by defining the IPython magic lazily. (Thanks @superbobry! #201)
  • The import hook will now typecheck functions that do not have any annotations in the arguments or return value. This is useful for those who do manual isinstance checks in the body of the function. (Thanks @nimashoghi! #205)
  • Dropped the dependency on numpy. This makes it possible to just use jaxtyping+typeguard as the one-stop-shop for all runtime typechecking, even when you're not using arrays. Obviously that's a little unusual -- not really the main focus of jaxtyping -- but helps when wanting a single choice of runtime type checker across an entire codebase, only parts of which may use arrays. (#212)

Full Changelog: v0.2.28...v0.2.29

jaxtyping v0.2.28

07 Mar 17:31

Autogenerated release notes as follows:

What's Changed

Full Changelog: v0.2.27...v0.2.28

jaxtyping v0.2.27

06 Mar 19:38

Quick bugfix release:

  • Fixed some isinstance checks against variadics crashing (although this was when it was about to return False anyway). (Thanks @asford! #186)
  • Fixed docs for downstream libraries (Equinox, ...) not generating correctly (#182)

Full Changelog: v0.2.26...v0.2.27

jaxtyping v0.2.26

25 Feb 12:10

Features

  • Added jaxtyping.print_bindings to manually inspect the values of each axis, whilst inside a function.
  • Added support for jaxtyping.{Int4, UInt4}. (#174, thanks @jianlijianli!)

Bugfixes

  • We no longer import JAX at all, even if it is present. This ensures compatibility when using jaxtyping+PyTorch alongside an old JAX installation. (All JAX re-exports, like jaxtyping.Array = jax.Array, are looked up dynamically rather than at import time.) (#178)
  • We no longer raise false positives when @jaxtyped-ing generators (with yield statements). (#91, #171, thanks @knyazer!)

Internals

  • Added support for beartype's pseudostandard __instancecheck_str__ method. Instead of calling isinstance(x, Float[Array, "foo"]), one can now call Float[Array, "foo"].__instancecheck_str__(x), which will return either an empty string (success) or an error message describing why the check failed (wrong shape, wrong dtype, ...). In practice this feature probably isn't super usable right now; we'll need to wait until we've done a better job of ensuring compatibility between the jaxtyping import hooks and the beartype import hooks.

Full Changelog: v0.2.25...v0.2.26

jaxtyping v0.2.25

15 Dec 18:38

This release is primarily a usability release, designed to help ensure the library is being used correctly.

  • The error messages from a failed typecheck have been improved, to explicitly highlight more information about which argument was wrong. :)
  • If the jaxtyping.jaxtyped(typechecker=...) argument is not passed, then a warning will be displayed. In practice, this will trigger:
    • If using the old double-decorator syntax (@jaxtyped @beartype def foo(...): ...) -- upgrade to the new @jaxtyped(typechecker=beartype) def foo(...): ... syntax and get better error messages! :)
    • If making the easy mistake of writing @jaxtyped(beartype) def foo(...): ... -- in this case it's actually the beartype call that is jaxtype'd, not foo.
  • Incorrect use of jaxtyping annotations will now raise a jaxtyping.AnnotationError rather than a mix of RuntimeErrors, NameErrors, etc. For example, isinstance(x, Float) is not correct (you should write something like Float[Array, "..."] instead), and this will raise such an AnnotationError.
  • Introduced two config flags:
    • JAXTYPING_DISABLE=1 / jaxtyping.config.update("jaxtyping_disable", True): if enabled then all runtime type checking will be skipped.
    • JAXTYPING_REMOVE_TYPECHECKER_STACK=1 / jaxtyping.config.update("jaxtyping_remove_typechecker_stack", True): if enabled then type-checking errors will only show the jaxtyping.TypeCheckError, and won't include any extra stack trace from the underlying type-checker (beartype/typeguard). Some users have found that they preferred the conciseness over the extra information.

Full Changelog: v0.2.24...v0.2.25

jaxtyping v0.2.24

27 Nov 17:53

New features

  • Error messages will now include useful shape information for debugging. (!!!) This closes the venerable #6, which is one of the oldest feature requests for jaxtyping. This is enabled by using the following syntax, instead of the old double-decorator syntax:
    from jaxtyping import jaxtyped
    from beartype import beartype as typechecker  # or: from typeguard import typechecked as typechecker
    
    @jaxtyped(typechecker=typechecker)  # passing as keyword argument is important
    def foo(x):
        ...
    Moreover, this is what install_import_hook now does.
    As an example, consider this buggy code:
    import jax.numpy as jnp
    from jaxtyping import Array, Float, jaxtyped
    from beartype import beartype
    
    @jaxtyped(typechecker=beartype)
    def f(x: Float[Array, "foo bar"], y: Float[Array, "foo"]):
        ...
    
    f(jnp.zeros((3, 4)), jnp.zeros(5))
    will now produce the error message
    jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of f.
    The problem arose whilst typechecking argument 'y'.
    Called with arguments: {'x': f32[3,4], 'y': f32[5]}
    Parameter annotations: (x: Float[Array, 'foo bar'], y: Float[Array, 'foo']).
    The current values for each jaxtyping axis annotation are as follows.
    foo=3
    bar=4
    
    Hurrah! I'm really glad to have this important quality-of-life improvement in. (#6, #138)
  • Added support for the following:
    import jax.numpy as jnp
    from jaxtyping import Array, Float
    
    def make_zeros(size: int) -> Float[Array, "{size}"]:
        return jnp.zeros(size)
    in which axis names enclosed in {...} are evaluated as f-strings using the values of the function's arguments. This closes the long-standing feature request #93. (#93, #140) (Heads-up @MilesCranmer!)
  • Added support for declaring PyTree structures, which like array shapes must match across all arguments. For example:
    def f(x: PyTree[int, "T"], y: PyTree[float, "T"]): ...
    demands that x and y be PyTrees with the same jax.tree_util.tree_structure as each other. (#135)
  • Added support for treepath-dependent sizes using ?. This makes it possible for the size of a dimension to vary with its position within a pytree, while still requiring it to be consistent with the size at the same position in other pytrees of the same structure. Such annotations look like PyTree[Float[Array, "?foo"], "T"]. Together with the previous point, this means that you can now declare that two pytrees must have the exact same structure and array shapes as each other: use PyTree[Float[Array, "?*shape"], "T"] as the annotation for both. (#136)
  • Added jaxtyping.Real, which admits any float, signed integer, or unsigned integer. (But not bools or complexes.) (#128)
  • If JAX is installed, then jaxtyping.DTypeLike is now available (it is just a forwarding on of jax.typing.DTypeLike). (#129)

Bugfixes

  • Fixed no error being raised when having mismatched variadic+broadcast and variadic+nonbroadcast dimensions; see #134 for details. (#134)
  • Fixed jaxtyping.Key not being compatible with the new-style jax.random.key. (As opposed to the old-style jax.random.PRNGKey.) (#142, #143)
  • Fixed install_import_hook(..., None) crashing (#145, #146).
  • Variadic shapes combined with bool/int/float/complex now work correctly, e.g. Float[float, "..."] is now valid (and equivalent to just float). This is useful in particular for Float[ArrayLike, "..."] to work correctly (as ArrayLike includes float). (#133)

Better error messages

  • The error message due to a nonexistent symbolic dimension -- e.g. def f(x: Float[Array, "dim*2"]) leaves dim unspecified -- is now fixed. (#131)
  • The error message due to the wrong dataclass attribute type -- e.g.
    from dataclasses import dataclass
    
    @dataclass
    class Foo:
        attribute_name: int
    Foo("strings are not integers")
    will now correctly include the attribute_name. (#132)

Note that this release may result in new errors being raised, due to the inclusion of #134. If so, then the appropriate thing to do is to fix your code -- this is a correct error that jaxtyping was previously failing to raise.

Full Changelog: v0.2.23...v0.2.24