Releases: patrick-kidger/jaxtyping
jaxtyping v0.2.33
- Compatibility with Python 3.10 when using `Any` as the array type.
- Compatibility with generic array types.
- Array typevars now respect `__constraints__`.
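For context, `__constraints__` is the standard attribute on a `typing.TypeVar` that records the types it was constrained to. The sketch below uses only the standard library; the stand-in classes `NumpyLike` and `TorchLike` and the helper `satisfies` are hypothetical names introduced for illustration. It shows the kind of runtime check that respecting `__constraints__` implies, and is not jaxtyping's actual implementation.

```python
from typing import TypeVar


class NumpyLike:  # hypothetical stand-in for np.ndarray
    pass


class TorchLike:  # hypothetical stand-in for torch.Tensor
    pass


# A constrained TypeVar records its allowed types in `__constraints__`.
TensorLike = TypeVar("TensorLike", NumpyLike, TorchLike)
# An unconstrained TypeVar has an empty tuple there.
Anything = TypeVar("Anything")


def satisfies(value, typevar) -> bool:
    """Hypothetical helper: accept anything for an unconstrained TypeVar,
    otherwise require an instance of one of the constraints."""
    constraints = typevar.__constraints__
    return not constraints or isinstance(value, constraints)
```

With this helper, `satisfies(NumpyLike(), TensorLike)` passes while `satisfies("a string", TensorLike)` does not, and an unconstrained typevar accepts any value.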
Full Changelog: v0.2.32...v0.2.33
jaxtyping v0.2.32
- The array type can now be either `Any` or a `TypeVar`. In both cases this means that anything is allowed at runtime. As usual, static type checkers will only look at the array part of an annotation, so that an annotation of the form `Float[T, "foo bar"]` (where `T = TypeVar("T")`) will be treated as just `T` by static type checkers. This allows for expressing array-type-polymorphism with static typechecking. Here's an example:

  ```python
  import numpy as np
  import torch
  from jaxtyping import Float
  from typing import TypeVar

  TensorLike = TypeVar("TensorLike", np.ndarray, torch.Tensor)

  def stack_scalars(x: Float[TensorLike, ""], y: Float[TensorLike, ""]) -> Float[TensorLike, "2"]:
      if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
          return np.stack([x, y])
      elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
          return torch.stack([x, y])
      else:
          raise ValueError("Invalid array types!")
  ```
- Fixed a bug in which the very first argument to a function was erroneously reported as the one at fault for a typechecking error. This bug occurred when using default arguments.
Full Changelog: v0.2.31...v0.2.32
jaxtyping v0.2.31
- Now duck-type on array shapes and dtypes, so you can use `jaxtyping` for your custom arraylike objects:

  ```python
  import jaxtyping

  class FooDtype(jaxtyping.AbstractDtype):
      dtypes = ["foo"]

  class MyArray:
      @property
      def dtype(self):
          return "foo"

      @property
      def shape(self):
          return (3, 1, 4)

  def f(x: FooDtype[MyArray, "3 1 4"]): ...
  ```
- Improved compatibility when typeguard warns that you're typechecking a function without annotations: it will no longer mention the jaxtyping-internal `check_params` function and will instead mention the name of the function that is missing annotations.
- Improved the error message when typechecking fails, to state the full `some_module.SomeClass.some_method` rather than just `some_method`.
- Fixed a JAX deprecation warning for `jax.tree_map`. (Thanks @groszewn!)
Full Changelog: v0.2.30...v0.2.31
jaxtyping v0.2.30
- Now reporting the correct source code line numbers when using the import hook. Makes debuggers useful again! #214
- Now supports numpy structured dtypes (Thanks @alexfanqi! #211)
- Now respecting `typing.no_type_check`. #216
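As background on the last item: `typing.no_type_check` marks a function by setting `__no_type_check__ = True` on it, which a runtime checker can then look for. The following is a minimal standard-library sketch of that idea. The decorator `maybe_check` and the functions `unchecked` and `checked` are hypothetical, and this is not how jaxtyping itself is implemented.

```python
import functools
import typing


def maybe_check(fn):
    """Hypothetical runtime checker that skips functions marked with
    `typing.no_type_check`."""
    if getattr(fn, "__no_type_check__", False):
        return fn  # opted out: leave the function untouched

    hints = typing.get_type_hints(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        expected = hints.get("return")
        # Only the return value is checked, and only against plain classes.
        if isinstance(expected, type) and not isinstance(out, expected):
            raise TypeError(f"{fn.__name__} returned {type(out).__name__}")
        return out

    return wrapper


@maybe_check
@typing.no_type_check
def unchecked() -> int:
    return "not an int"  # no error: checking is skipped


@maybe_check
def checked() -> int:
    return "not an int"  # raises TypeError when called
```

Note the decorator order: `typing.no_type_check` is applied first (innermost), so the marker is already present when the checker inspects the function.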
New Contributors
- @alexfanqi made their first contribution in #211
Full Changelog: v0.2.29...v0.2.30
jaxtyping v0.2.29
- Crash fix for when `jax` is available but `jaxlib` is not. (Thanks @ar0ck! #191)
- Crash fix when used alongside old TensorFlow versions that don't support `tensor.ndim`. (Thanks @dziulek! #193)
- Crash fix when using a default argument as a symbolic dimension size. (Thanks @jaraujo98! #208)
- Improved import times by defining the IPython magic lazily. (Thanks @superbobry! #201)
- The import hook will now typecheck functions that do not have any annotations in the arguments or return value. This is useful for those that do manual `isinstance` checks in the body of the function. (Thanks @nimashoghi! #205)
- Dropped the dependency on numpy. This makes it possible to just use jaxtyping+typeguard as the one-stop-shop for all runtime typechecking, even when you're not using arrays. Obviously that's a little unusual -- not really the main focus of jaxtyping -- but helps when wanting a single choice of runtime type checker across an entire codebase, only parts of which may use arrays. (#212)
New Contributors
- @ar0ck made their first contribution in #191
- @dziulek made their first contribution in #193
- @superbobry made their first contribution in #201
- @nimashoghi made their first contribution in #205
- @jaraujo98 made their first contribution in #208
Full Changelog: v0.2.28...v0.2.29
jaxtyping v0.2.28
Autogenerated release notes as follows:
What's Changed
- Fixes #188. by @patrick-kidger in #190
Full Changelog: v0.2.27...v0.2.28
jaxtyping v0.2.27
Quick bugfix release:
- Fixed some `isinstance` checks against variadics crashing (although this was when it was about to return `False` anyway). (Thanks @asford! #186)
- Fixed docs for downstream libraries (Equinox, ...) not generating correctly. (#182)
Full Changelog: v0.2.26...v0.2.27
jaxtyping v0.2.26
Features
- Added `jaxtyping.print_bindings` to manually inspect the values of each axis, whilst inside a function.
- Added support for `jaxtyping.{Int4, UInt4}`. (#174, thanks @jianlijianli!)
Bugfixes
- We no longer import JAX at all, even if it is present. This ensures compatibility when using jaxtyping+PyTorch alongside an old JAX installation. (All JAX re-exports, like `jaxtyping.Array = jax.Array`, are looked up dynamically rather than at import time.) (#178)
- We no longer raise false positives when `@jaxtyped`-ing generators (with `yield` statements). (#91, #171, thanks @knyazer!)
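One standard way to look up a re-export dynamically rather than at import time is a module-level `__getattr__` (PEP 562). Whether jaxtyping uses exactly this mechanism is an assumption here. The sketch below builds a throwaway in-memory module, with the hypothetical name `lazy_demo` and a hypothetical `sqrt` re-export, so that it runs with only the standard library.

```python
import importlib
import sys
import types

# Build a throwaway module in memory so the example is self-contained.
# `lazy_demo` and its `sqrt` re-export are hypothetical names.
lazy_demo = types.ModuleType("lazy_demo")


def _module_getattr(name):
    # Called only when `name` is not found on the module normally (PEP 562).
    if name == "sqrt":
        # The dependency is imported on first access, not when
        # `lazy_demo` itself is imported.
        return importlib.import_module("math").sqrt
    raise AttributeError(f"module 'lazy_demo' has no attribute {name!r}")


lazy_demo.__getattr__ = _module_getattr
sys.modules["lazy_demo"] = lazy_demo

# The attribute is resolved at access time.
root = lazy_demo.sqrt(9.0)
```

Because nothing is cached on the module, each access goes through `__getattr__` again; a real library might store the result after the first lookup.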
Internals
- Added support for beartype's pseudostandard `__instancecheck_str__` method. Instead of `isinstance(x, Float[Array, "foo"])`, one can now call `Float[Array, "foo"].__instancecheck_str__(x)`, which will return either an empty string (success) or an error message describing why the check failed (wrong shape, wrong dtype, ...). In practice this feature probably isn't super usable right now; we'll need to wait until we've later done a better job ensuring compatibility between the jaxtyping import hooks and the beartype import hooks.
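To make the protocol described above concrete, here is a minimal standard-library sketch of a type that implements `__instancecheck_str__` alongside the usual `__instancecheck__`. The class `PositiveInt` and its metaclass are hypothetical and exist only to illustrate the "empty string on success, message on failure" contract; they are not part of jaxtyping or beartype.

```python
class _PositiveIntMeta(type):
    def __instancecheck__(cls, obj) -> bool:
        # The usual boolean check succeeds iff the message is empty.
        return cls.__instancecheck_str__(obj) == ""

    def __instancecheck_str__(cls, obj) -> str:
        # Return "" on success, or a human-readable reason on failure.
        if not isinstance(obj, int) or isinstance(obj, bool):
            return f"expected int, got {type(obj).__name__}"
        if obj <= 0:
            return f"expected a positive value, got {obj}"
        return ""


class PositiveInt(metaclass=_PositiveIntMeta):
    """Hypothetical type used only to illustrate the protocol."""


ok_message = PositiveInt.__instancecheck_str__(3)
bad_message = PositiveInt.__instancecheck_str__(-1)
```

Defining both methods on the metaclass keeps `isinstance(x, PositiveInt)` and the string-returning variant consistent with each other.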
Docs
- Fixes by @jeertmans (#154) and @afrozenator (#170) -- thank you!
New Contributors
- @jeertmans made their first contribution in #154
- @afrozenator made their first contribution in #170
Full Changelog: v0.2.25...v0.2.26
jaxtyping v0.2.25
This release is primarily a usability release, designed to help ensure the library is being used correctly.
- The error messages from a failed typecheck have been improved, to explicitly highlight more information about which argument was wrong. :)
- If the `jaxtyping.jaxtyped(typechecker=...)` argument is not passed, then a warning will be displayed. In practice, this will trigger:
  - If using the old double-decorator syntax (`@jaxtyped @beartype def foo(...): ...`) -- upgrade to the new `@jaxtyped(typechecker=beartype) def foo(...): ...` syntax and get better error messages! :)
  - If making the easy mistake of writing `@jaxtyped(beartype) def foo(...): ...` -- in this case it's actually the `beartype` call that is jaxtype'd, not `foo`.
- Incorrect use of jaxtyping annotations will now raise a `jaxtyping.AnnotationError` rather than a mix of `RuntimeError`s, `NameError`s etc. For example `isinstance(x, Float)` is not correct (you should write something like `Float[Array, "..."]` instead), and this will raise such an `AnnotationError`.
- Introduced two config flags:
  - `JAXTYPING_DISABLE=1` / `jaxtyping.config.update("jaxtyping_disable", True)`: if enabled then all runtime type checking will be skipped.
  - `JAXTYPING_REMOVE_TYPECHECKER_STACK=1` / `jaxtyping.config.update("jaxtyping_remove_typechecker_stack", True)`: if enabled then type-checking errors will only show the `jaxtyping.TypeCheckError`, and won't include any extra stack trace from the underlying type-checker (beartype/typeguard). Some users have found that they preferred the conciseness over the extra information.
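The `@jaxtyped(beartype)` mistake mentioned above is easier to see with a toy decorator. The sketch below is a hypothetical standard-library stand-in (`with_checker` and `my_typechecker` are invented names, not jaxtyping's code). It shows why passing the typechecker positionally ends up wrapping the typechecker itself, and how a missing keyword argument can be detected and warned about.

```python
import functools
import warnings


def with_checker(fn=None, *, typechecker=None):
    """Hypothetical decorator with a keyword-only `typechecker` option."""
    if fn is None:
        # Used as `@with_checker(typechecker=...)`: return the real decorator.
        return functools.partial(with_checker, typechecker=typechecker)
    if typechecker is None:
        warnings.warn(f"no typechecker passed; decorating {fn.__name__!r} directly")
    else:
        fn = typechecker(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


def my_typechecker(fn):
    """Stand-in for beartype/typeguard: returns the function unchanged."""
    return fn


@with_checker(typechecker=my_typechecker)  # correct: keyword argument
def good(x):
    return x


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Mistake: the positional argument is taken to be the function to
    # decorate, so it is `my_typechecker` that gets wrapped.
    wrapped_checker = with_checker(my_typechecker)
```

After the mistaken call, `wrapped_checker` is a wrapper around the typechecker, so writing `@with_checker(my_typechecker)` above a function would apply only that wrapped checker to it. The recorded warning names `my_typechecker`, which is the hint that something was passed in the wrong position.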
Full Changelog: v0.2.24...v0.2.25
jaxtyping v0.2.24
New features
- Error messages will now include useful shape information for debugging. (!!!) This closes the venerable #6, which is one of the oldest feature requests for jaxtyping. This is enabled by using the following syntax, instead of the old double-decorator syntax:

  ```python
  from jaxtyping import jaxtyped
  from beartype import beartype as typechecker  # or: from typeguard import typechecked as typechecker

  @jaxtyped(typechecker=typechecker)  # passing as keyword argument is important
  def foo(...): ...
  ```

  and moreover this is what `install_import_hook` now does.

  As an example of this in action, consider this buggy code:

  ```python
  import jax.numpy as jnp
  from jaxtyping import Array, Float, jaxtyped
  from beartype import beartype

  @jaxtyped(typechecker=beartype)
  def f(x: Float[Array, "foo bar"], y: Float[Array, "foo"]):
      ...

  f(jnp.zeros((3, 4)), jnp.zeros(5))
  ```

  which will now produce the error message

  ```
  jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of f.
  The problem arose whilst typechecking argument 'y'.
  Called with arguments: {'x': f32[3,4], 'y': f32[5]}
  Parameter annotations: (x: Float[Array, 'foo bar'], y: Float[Array, 'foo']).
  The current values for each jaxtyping axis annotation are as follows.
  foo=3
  bar=4
  ```

  Hurrah! I'm really glad to have this important quality-of-life improvement in. (#6, #138)
- Added support for the following:

  ```python
  def make_zeros(size: int) -> Float[Array, "{size}"]:
      return jnp.zeros(size)
  ```

  in which axis names enclosed in `{...}` are evaluated as f-strings using the value of the argument of the function. This closes the long-standing feature request #93. (#93, #140) (Heads-up @MilesCranmer!)
- Added support for declaring PyTree structures, which like array shapes must match across all arguments. For example

  ```python
  def f(x: PyTree[int, "T"], y: PyTree[float, "T"]): ...
  ```

  demands that `x` and `y` be PyTrees with the same `jax.tree_util.tree_structure` as each other. (#135)
- Added support for treepath-dependent sizes using `?`. This makes it possible for the value of a dimension to vary across its position within a pytree, but must still be consistent with its value in other pytrees of the same structure. Such annotations look like `PyTree[Float[Array, "?foo"], "T"]`. Together with the previous point, this means that you can now declare that two pytrees must have the exact same structure and array shapes as each other: use `PyTree[Float[Array, "?*shape"], "T"]` as the annotation for both. (#136)
- Added `jaxtyping.Real`, which admits any float, signed integer, or unsigned integer. (But not bools or complexes.) (#128)
- If JAX is installed, then `jaxtyping.DTypeLike` is now available (it is just a forwarding on of `jax.typing.DTypeLike`). (#129)
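The `{...}` axis syntax above can be pictured as filling an f-string from the arguments of the call. The following standard-library sketch illustrates that idea only. The helper `resolve_axis` is hypothetical, the extra `scale` parameter is added purely for the demonstration, and none of this is jaxtyping's actual implementation.

```python
import inspect


def resolve_axis(template: str, fn, *args, **kwargs) -> str:
    """Hypothetical helper: fill `{...}` placeholders in an axis string
    using the arguments that `fn` would be called with."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    # Build an f-string literal from the template and evaluate it with only
    # the bound arguments in scope. `eval` is tolerable here solely because
    # the template comes from the function's own annotation.
    return eval("f" + repr(template), {"__builtins__": {}}, dict(bound.arguments))


def make_zeros(size: int, scale: int = 2):
    ...


simple = resolve_axis("{size}", make_zeros, 4)
computed = resolve_axis("{size} {size * scale}", make_zeros, 3)
```

Because the placeholders are evaluated as expressions, an axis such as `"{size * scale}"` works as well as a plain `"{size}"`, and default argument values participate once `apply_defaults()` has run.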
Bugfixes
- Fixed no error being raised when having mismatched variadic+broadcast and variadic+nonbroadcast dimensions; see #134 for details. (#134)
- Fixed `jaxtyping.Key` not being compatible with the new-style `jax.random.key`. (As opposed to the old-style `jax.random.PRNGKey`.) (#142, #143)
- Fixed `install_import_hook(..., None)` crashing. (#145, #146)
- Variadic shapes combined with `bool`/`int`/`float`/`complex` now work correctly, e.g. `Float[float, "..."]` is now valid (and equivalent to just `float`). This is useful in particular for `Float[ArrayLike, "..."]` to work correctly (as `ArrayLike` includes `float`). (#133)
Better error messages
- The error message due to a nonexistent symbolic dimension -- e.g. `def f(x: Float[Array, "dim*2"])` leaves `dim` unspecified -- is now fixed. (#131)
- The error message due to the wrong dataclass attribute type -- e.g.

  ```python
  @dataclass
  class Foo:
      attribute_name: int

  Foo("strings are not integers")
  ```

  will now correctly include the `attribute_name`. (#132)
Note that this release may result in new errors being raised, due to the inclusion of #134. If so, then the appropriate thing to do is to fix your code -- this is a correct error that jaxtyping was previously failing to raise.
Full Changelog: v0.2.23...v0.2.24