Type checking on closures #93
Comments
e.g.,

```python
import torch
from torch import nn, Tensor
import numpy as np
from jaxtyping import Float
from typeguard import typechecked


class Model(nn.Module):
    def __init__(self, latent):
        super().__init__()
        self.latent = latent
        self.net = nn.Linear(10, latent)

    @typechecked
    def forward(
        self, x: Float[Tensor, "batch 10"]
    ) -> Float[Tensor, f"batch {self.latent}"]:
        return self.net(x)
```

`Model(5)(torch.randn(10, 10))` gives the error:
Hmm. So the closure example you've written looks correct. The class example errors because the annotation is evaluated when the function is created, not when the function is run. You could probably make this work by doing:

```python
@jaxtyped
@typechecked
def forward(self, x: Float[Tensor, "batch 10"]) -> Float[Tensor, "batch latent"]:
    assert isinstance(torch.zeros(self.latent), Float[Tensor, "latent"])
    return self.net(x)
```

to bind `latent`. It's obviously kind of hackish to have to create a tensor just for that purpose, though. I'd be happy to take a PR offering a nicer syntax.
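For reference, a minimal runnable sketch of that workaround folded back into the class from the first comment (a sketch assuming the bare `@jaxtyped` + `@typechecked` stacking used in this thread; newer jaxtyping versions spell this `jaxtyped(typechecker=typechecked)`):

```python
import torch
from torch import nn, Tensor
from jaxtyping import Float, jaxtyped
from typeguard import typechecked


class Model(nn.Module):
    def __init__(self, latent: int):
        super().__init__()
        self.latent = latent
        self.net = nn.Linear(10, latent)

    @jaxtyped
    @typechecked
    def forward(
        self, x: Float[Tensor, "batch 10"]
    ) -> Float[Tensor, "batch latent"]:
        # The assert binds "latent" to self.latent in the active jaxtyping memo,
        # so the return annotation can refer to it.
        assert isinstance(torch.zeros(self.latent), Float[Tensor, "latent"])
        return self.net(x)


Model(5)(torch.randn(10, 10))  # passes; an output of the wrong width would now fail
```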
I have a similar piece of code that has the same issue but doesn't require a class:

```python
import torch
from jaxtyping import Float


def square(i: int) -> Float[torch.Tensor, "i"]:
    return torch.arange(i) ** 2
```

Maybe adding a function with the following type signature:

```python
@overload
def assign_dims(**kwargs: int) -> None: ...
@overload
def assign_dims(names: Tuple[str], shape: Union[Tuple[int], int]) -> None: ...
@overload
def assign_dims(names: str, shape: Union[Tuple[int], int]) -> None: ...
def assign_dims(names, shape, /, **kwargs):
    ...


def square(i: int) -> Float[torch.Tensor, "i"]:
    assign_dims("i", i)
    return torch.arange(i) ** 2
```

Implementing this function is simple and straightforward. It requires gluing these two code snippets together:

Or a type annotation similar to the other ones in this repo would be useful. I have no idea how to get static type checkers to be OK with this, but it does look more understandable to me:

```python
def square(i: Dims["i"]) -> Float[torch.Tensor, "i"]:
    return torch.arange(i) ** 2
```

@patrick-kidger Are we forced to use the function version, or is there a clever way to get something similar to the type annotation version working?
I think we can probably convince static type checkers to at least ignore the annotation. (We might not be able to make them treat it as an `int`, though.)

The real question is what kind of annotation we might dream up to support the use case with classes, like in the original question. What syntax would be neatest?
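One possible direction (my own sketch, not an existing jaxtyping API): stash the dimension name in `typing.Annotated` metadata. Static type checkers treat `Annotated[int, ...]` as a plain `int` and ignore the metadata, while a runtime decorator could recover the name and bind it. The `DimName` marker and `bind_int_dims` decorator below are hypothetical:

```python
from typing import Annotated, get_type_hints

import torch
from jaxtyping import Float


class DimName:
    """Hypothetical marker carried as Annotated metadata."""
    def __init__(self, name: str):
        self.name = name


def bind_int_dims(fn):
    """Toy decorator: find int parameters tagged with DimName metadata.

    A real implementation would push these name/value pairs into jaxtyping's
    dimension memo before delegating; here we only show how they are recovered.
    """
    hints = get_type_hints(fn, include_extras=True)
    dim_names = {
        param: meta.name
        for param, hint in hints.items()
        for meta in getattr(hint, "__metadata__", ())
        if isinstance(meta, DimName)
    }

    def wrapper(*args, **kwargs):
        # ... bind dim_names against the actual argument values here ...
        return fn(*args, **kwargs)

    return wrapper


@bind_int_dims
def square(i: Annotated[int, DimName("i")]) -> Float[torch.Tensor, "i"]:
    return torch.arange(i, dtype=torch.float32) ** 2
```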
Regardless of how the type annotation will be done, I am starting to think that a function for asserting the shape of an array (or a similar mechanism) is useful in its own right for introducing/asserting shape variables in the middle of a function, in addition to providing a way to supply shape variables for class attributes.

```python
def bincount_wrapper(x: Int[torch.Tensor, "elems"]) -> Int[torch.Tensor, "max_elem"]:
    print(x)
    assign_dims("max_elem", torch.amax(x))
    return torch.bincount(x)
```

Additional type annotations for class attributes could probably be implemented using the
That sounds reasonable to me!
I am not entirely sure about the name of this function. Maybe changing its name to something along the lines of `assert_shapes`, as in https://www.tensorflow.org/api_docs/python/tf/debugging/assert_shapes ? I am open to other name suggestions.
I like
Hi,
we could maybe replace
I think this would be easier both to read and to implement. Otherwise, if we manually bind
so it would be greatly appreciated if accessing class members for type checking could be integrated.
I didn't get around to writing test cases for this since I was busy recently. I can fish out the code I wrote for

I would hope that the

The problem I see with your proposal is that the variables in the shape annotations are in a different scope than the variables in the function body. We could add

The problem with

Thank you for looking into this. I am not 100% pleased with the
Hi @anivegesana, thanks for the fast reply. In my test, use

In this case, at the second

But I would like to directly add dim info to

Of course these two functions only make sense with

And back to parsing "self.xxx": yes, it doesn't solve the problem with variables inside the function body or in class methods. In the latter two cases it could only be implemented by assign_name. The motivation is that, in my experience with machine learning, most shape checks look like "batch data_dim": dimensions like "batch" are dynamic and flexible, while dimensions like "data_dim" are determined once the class is initialized. So I think this could be a very helpful feature. The question is still there: what is the best way to annotate, so that it is as simple, robust and readable as possible? I believe parsing "self.xxx" and adding the functions assign_dim() and check_dim() would be very useful.
I think we are thinking the same thing. If there is no

Treating

```python
from itertools import islice
from typing import SupportsInt, LiteralString

from jaxtyping._array_types import storage


def _dims(dims: LiteralString, /, *lengths: SupportsInt, assign: bool):
    no_temp_memo = hasattr(storage, "memo_stack") and len(storage.memo_stack) != 0

    if no_temp_memo:
        single_memo: dict[str, int]
        variadic_memo: dict[str, tuple[int, ...]]
        variadic_broadcast_memo: dict[str, list[tuple[int, ...]]]
        single_memo, variadic_memo, variadic_broadcast_memo = storage.memo_stack[-1]

        dim_list = dims.split()
        iter_lengths = iter(lengths)
        seen_variadic = False
        for dim, length in zip(dim_list, iter_lengths, strict=True):
            length = int(length)
            if dim == "...":
                dim = "*_"
            if dim.startswith("*"):
                # We are dealing with a variadic
                if seen_variadic:
                    raise ValueError(f"Multiple variadic shape arguments in {dims!r}")
                seen_variadic = True
                amount_in_variadic = len(lengths) - len(dim_list)
                variadic = (length,) + tuple(int(i) for i in islice(iter_lengths, amount_in_variadic))
                if dim == "*_":
                    continue
                elif (old_variadic := single_memo.get(dim)) is not None:
                    if old_variadic != variadic:
                        raise ValueError(f"Conflict in dim {dim!r}. First found {old_variadic}. Now found {variadic}.")
                elif assign:
                    variadic_memo[dim] = variadic
            else:
                assert dim.isidentifier()
                if dim == "_":
                    continue
                elif (old_length := single_memo.get(dim)) is not None:
                    if old_length != length:
                        raise ValueError(f"Conflict in dim {dim!r}. First found {old_length}. Now found {length}.")
                elif assign:
                    single_memo[dim] = length


def assign_dims(dims: LiteralString, /, *lengths: SupportsInt):
    _dims(dims, *lengths, assign=True)


def check_dims(dims: LiteralString, /, *lengths: SupportsInt):
    _dims(dims, *lengths, assign=False)
```

It seems that I didn't add a case for variadic broadcasts. I am planning on looking into that later, but you can add support if you have the time.
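For illustration, a sketch of how these helpers might be used in the class-attribute case that motivated this thread, assuming the internal memo layout the snippet above relies on and the bare `@jaxtyped` + `@typechecked` stacking from earlier comments (the `Encoder` class here is made up for the example):

```python
import torch
from torch import nn, Tensor
from jaxtyping import Float, jaxtyped
from typeguard import typechecked


class Encoder(nn.Module):
    def __init__(self, latent: int):
        super().__init__()
        self.latent = latent
        self.net = nn.Linear(10, latent)

    @jaxtyped
    @typechecked
    def forward(self, x: Float[Tensor, "batch 10"]) -> Float[Tensor, "batch latent"]:
        # "batch" is bound from x's shape; "latent" lives on the class,
        # so bind it explicitly before the return annotation is checked.
        assign_dims("latent", self.latent)
        return self.net(x)


Encoder(5)(torch.randn(3, 10))  # ok; an output of the wrong width would now raise
```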
In terms of making

In terms of the

```python
def assign_dim(dim: str, value: int):
    isinstance(DummyArray(value), Shaped[DummyArray, dim])
```

This avoids the need to meddle with the internal machinery.
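A self-contained sketch of that idea — `DummyArray` isn't defined anywhere in this thread, so this version (my assumption, not the exact API being proposed) just allocates a throwaway NumPy array of the requested length:

```python
import numpy as np
from jaxtyping import Shaped, jaxtyped
from typeguard import typechecked


def assign_dim(dim: str, value: int) -> None:
    # Inside a @jaxtyped call, an isinstance check against a jaxtyping
    # annotation records the dimension name in the active memo.
    assert isinstance(np.zeros(int(value)), Shaped[np.ndarray, dim])


@jaxtyped
@typechecked
def square(i: int) -> Shaped[np.ndarray, "i"]:
    assign_dim("i", i)
    return np.arange(i) ** 2


square(4)  # passes: the result has shape (4,), matching the bound "i"
```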
Yes. I was unaware that
A small thing: for the dummy array, I think using a function like
Hey @patrick-kidger,
great package as always. I was wondering what the best way is to type check closure functions? For example:
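(The example snippet from the original issue is not preserved in this copy of the thread. A hypothetical closure of the kind being discussed — where the f-string annotation is evaluated when the inner function is created, so the captured value is already available — might look like this:)

```python
import torch
from torch import nn, Tensor
from jaxtyping import Float
from typeguard import typechecked


def make_forward(latent: int):
    net = nn.Linear(10, latent)

    @typechecked
    def forward(x: Float[Tensor, "batch 10"]) -> Float[Tensor, f"batch {latent}"]:
        # `latent` is bound in the enclosing scope, so the f-string resolves
        # to a concrete dimension at definition time.
        return net(x)

    return forward


make_forward(5)(torch.randn(10, 10))
```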
This seems to work, but I'm wondering if this is the proper way? Or maybe something with `typing.TypeVar` is better?

I should also mention I'm not necessarily interested in closures, but more in class methods, where I want the return type to depend on a class parameter (e.g., `-> Float[Array, f"{self.latent}"]`).