Type checking on closures #93

Closed
MilesCranmer opened this issue Jun 20, 2023 · 16 comments
Labels: feature (New feature), next (Higher-priority items)

Comments

@MilesCranmer

Hey @patrick-kidger,
great package as always. I was wondering: what is the best way to type-check closure functions? For example:

import numpy as np
from jaxtyping import Float
from typeguard import typechecked


def f(n):
    @typechecked
    def g(x: Float[np.ndarray, f"dim1 {n}"]) -> Float[np.ndarray, f"{n}"]:
        return x.sum(0)

    return g


print(f(10)(np.random.randn(10, 9)))

This seems to work, but I'm wondering if this is the proper way? Or maybe something with typing.TypeVar is better?

I should also mention I'm not necessarily interested in closures, but more in class methods, where I want the return type to depend on a class parameter (e.g., -> Float[Array, f"{self.latent}"]).

@MilesCranmer
Author

e.g.,

import torch
from torch import nn, Tensor
import numpy as np
from jaxtyping import Float
from typeguard import typechecked

class Model(nn.Module):
    def __init__(self, latent):
        super().__init__()
        self.latent = latent
        self.net = nn.Linear(10, latent)

    @typechecked
    def forward(
        self, x: Float[Tensor, "batch 10"]
    ) -> Float[Tensor, f"batch {self.latent}"]:
        return self.net(x)

Model(5)(torch.randn(10, 10))

gives the error:

Traceback (most recent call last):
  File "/Users/mcranmer/Documents/perceiver4data/test.py", line 16, in <module>
    class Model(nn.Module):
  File "/Users/mcranmer/Documents/perceiver4data/test.py", line 24, in Model
    ) -> Float[Tensor, f"batch {self.latent}"]:
NameError: name 'self' is not defined

@patrick-kidger
Owner

Hmm. So the closure example you've written looks correct. The class example errors because the annotation is evaluated when the function is created, not when the function is run.

You could probably make this work by doing:

@jaxtyped
@typechecked
def forward(self, x: Float[Tensor, "batch 10"]) -> Float[Tensor, "batch latent"]:
    assert isinstance(torch.zeros(self.latent), Float[Tensor, "latent"])
    return self.net(x)

to bind latent to the value of self.latent. (Each collection of named axes is valid for the dynamic scope of the jaxtyped decorator, and this includes the interior of the body function.)

It's obviously kind of hackish to have to create a tensor just for that purpose though. I'd be happy to take a PR offering a nicer syntax.
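
For reference, the earlier Model example with this workaround applied would look roughly as follows (a sketch on my part, assuming the @jaxtyped/@typechecked stacking shown above):

import torch
from torch import nn, Tensor
from jaxtyping import Float, jaxtyped
from typeguard import typechecked

class Model(nn.Module):
    def __init__(self, latent):
        super().__init__()
        self.latent = latent
        self.net = nn.Linear(10, latent)

    @jaxtyped
    @typechecked
    def forward(
        self, x: Float[Tensor, "batch 10"]
    ) -> Float[Tensor, "batch latent"]:
        # The isinstance check binds "latent" to self.latent for the rest of
        # this jaxtyped scope, so the return annotation can refer to it.
        assert isinstance(torch.zeros(self.latent), Float[Tensor, "latent"])
        return self.net(x)

Model(5)(torch.randn(10, 10))  # passes; a wrong output width would now raise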

@anivegesana

anivegesana commented Jul 5, 2023

I have a similar piece of code that has the same issue but doesn't require a class:

import torch
from jaxtyping import Float

def square(i: int) -> Float[torch.Tensor, "i"]:
    return torch.arange(i) ** 2

Maybe adding a function with the following type signature would help:

from typing import Tuple, Union, overload

@overload
def assign_dims(**kwargs: int) -> None: ...

@overload
def assign_dims(names: Tuple[str], shape: Union[Tuple[int], int]) -> None: ...

@overload
def assign_dims(names: str, shape: Union[Tuple[int], int]) -> None: ...

def assign_dims(names, shape, /, **kwargs):
    ...

def square(i: int) -> Float[torch.Tensor, "i"]:
    assign_dims("i", i)
    return torch.arange(i) ** 2

Implementing this function is simple and straightforward. It requires gluing these two code snippets together:
https://github.com/google/jaxtyping/blob/a6ab6c0d28a5209e08c31791a1a20dc6cace6697/jaxtyping/_array_types.py#L177-L184
https://github.com/google/jaxtyping/blob/a6ab6c0d28a5209e08c31791a1a20dc6cace6697/jaxtyping/_array_types.py#L130-L137
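
For instance, a rough sketch of that gluing (this pokes at jaxtyping's private storage object from the lines linked above, so treat the exact layout as an assumption rather than a public API):

from jaxtyping._array_types import storage  # private module; may change

def assign_dims(names: str, *lengths: int) -> None:
    # Only meaningful inside a @jaxtyped dynamic scope, where a memo exists.
    if getattr(storage, "memo_stack", None):
        single_memo, variadic_memo, variadic_broadcast_memo = storage.memo_stack[-1]
        for name, length in zip(names.split(), lengths):
            single_memo[name] = int(length)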

Alternatively, a type annotation similar to the other ones in this repo would be useful. I have no idea how to get static type checkers to be OK with this, but it does look more understandable to me:

def square(i: Dims["i"]) -> Float[torch.Tensor, "i"]:
    return torch.arange(i) ** 2

@patrick-kidger Are we forced to use the function version or is there a clever way to get something similar to the type annotation version working?

@patrick-kidger
Owner

I think we can probably convince static type checkers to at least ignore the annotation. (We might not be able to make them treat it as an int though.)

The real question is what kind of annotation we might dream up to support the use-case with classes, like in the original question. What syntax would be neatest?

@anivegesana

anivegesana commented Jul 6, 2023

Regardless of how the type annotation will be done, I am starting to think that a function for asserting the shape of an array (or a similar mechanism) is useful in its own right for introducing/asserting shape variables in the middle of a function, in addition to giving a way to provide shape variables for class attributes.

def bincount_wrapper(x: Int[torch.Tensor, "elems"]) -> Int[torch.Tensor, "max_elem"]:
    print(x)
    assign_dims("max_elem", torch.amax(x) + 1)  # torch.bincount output has length max(x) + 1
    return torch.bincount(x)

Additional type annotations for class attributes could probably be implemented using the assign_dims function.
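
For the class-attribute case, the usage would presumably look something like this inside the earlier forward method (hypothetical: assign_dims is only a proposal at this point, and the body assumes the Model class from above):

@jaxtyped
@typechecked
def forward(self, x: Float[Tensor, "batch 10"]) -> Float[Tensor, "batch latent"]:
    assign_dims("latent", self.latent)  # hypothetical helper binding "latent" to the attribute
    return self.net(x)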

@patrick-kidger
Owner

That sounds reasonable to me!

@anivegesana

anivegesana commented Jul 6, 2023

I am not entirely sure about the name of this function. Maybe changing its name to assert_shape or assert_shapes to mirror this function in TensorFlow would be a good idea? But it may also be confusing since the two functions are a little bit different.

https://www.tensorflow.org/api_docs/python/tf/debugging/assert_shapes

I am open to other name suggestions.

@patrick-kidger
Owner

bind_name maybe? I don't have a strong opinion.

@anivegesana

I like bind_name. I think implementing bind_name should be trivial. I can open a PR over the weekend.

@im-Kitsch

Hi,
I want to come back and discuss this solution:

Hmm. So the closure example you've written looks correct. The class example errors because the annotation is evaluated when the function is created, not when the function is run.

You could probably make this work by doing:

@jaxtyped
@typechecked
def forward(self, x: Float[Tensor, "batch 10"]) -> Float[Tensor, "batch latent"]:
    assert isinstance(torch.zeros(self.latent), Float[Tensor, "latent"])
    return self.net(x)

to bind latent to the value of self.latent. (Each collection of named axes is valid for the dynamic scope of the jaxtyped decorator, and this includes the interior of the body function.)

It's obviously kind of hackish to have to create a tensor just for that purpose though. I'd be happy to take a PR offering a nicer syntax.

We could maybe replace isinstance with a function like assign_dims, but I wonder whether it's possible to integrate this feature into @jaxtyped directly, so that it would again look like this:

@jaxtyped
@typechecked
def forward(self, x: Float[Tensor, "batch 10"]) -> Float[Tensor, "batch self.latent"]:
    return self.net(x)

I think this would be easier to read and to implement. Otherwise, if we manually bind the axis name (i.e. latent here) to the class member (i.e. self.latent) inside the function, adding one statement to assign the dim is much the same as simply doing the following:

assert output.shape[1] == self.latent

So it would be very much appreciated if accessing class members could be integrated into the type checking.

@anivegesana

anivegesana commented Aug 20, 2023

I didn't get around to writing test cases for this since I was busy recently. I can fish out the code I wrote for assign_dims sometime tomorrow.

I would hope that the isinstance method would work, but I would need to verify again. The issue I find with it is that it unnecessarily creates a tensor object that isn't used.

The problem I see with your proposal is that the variables in the shape annotations are in a different scope than the variables in the function body. We could add self to the shape annotation scope, but it doesn't answer how to deal with variables from the function body or class variables in classmethods.

The problem with assert something.shape == myshape (along with the different scopes) is that I think it would require expensive reparsing and inspection of the source code of the function to make it work, but if it is possible to perform that operation cheaply, I would much prefer it.

Thank you for looking into this. I am not 100% pleased with the bind_name function myself, but I think it is the simplest and most general solution to this problem. I would like to be proven wrong. :)

@im-Kitsch

im-Kitsch commented Aug 20, 2023

Hi, @anivegesana

Thanks for the fast reply. In my test, using isinstance works, but I don't think it's a good choice. For example:

    @jaxtyped
    @typechecked
    def forward(self, x: Float[Tensor, "batch dim_in"], ) -> Float[Tensor, "batch latent"]:
        assert isinstance(torch.zeros(5), Float[Tensor, "latent"])
        assert isinstance(torch.zeros(4), Float[Tensor, "latent"])

        self.latent = 5
        return torch.zeros(8, 5)

In this case, the second assert will throw an error; if we swap the two assert statements, it is still the second one that throws. What I want to say is that isinstance here does not work like its name suggests (checking whether something is an instance of a given class); it just serves to bind names to values, and the first assert will never actually fail. This could be confusing for new users, and I think it is not safe to use isinstance() directly. If we only annotate inputs and outputs, the relationships and constraints between symbols are easy to check and understand. But if checks like assert isinstance() are sparsely distributed inside the function body, it quickly becomes confusing and difficult to debug: people have to carefully check where each name is initialized and where the shape conflict comes from.

Instead, I would like to add dim info directly to the memo. As #93 (comment) also mentioned, _array_types uses a single_memo and a variadic_memo. In my understanding, the @jaxtyped decorator creates the memo and, as annotations are checked, puts the collected info into it (@patrick-kidger, do I understand that right, or is the mechanism more complex?). It would be easier and clearer to add two functions, check_dim and assign_dim: for parameters like the inputs, all shape info is already determined, so check_dim just compares a symbol against a given value, while assign_dim adds a new symbol to the memo, e.g.

DIM_0 = 10
def repeat_i_times(i: int, x: Float[Tensor, "dim0"]) -> Float[Tensor, "dim1 dim0"]:
    check_dim("dim0", DIM_0)  # assume x's shape matches the global variable DIM_0
    assign_dim("dim1", i)
    return x.repeat(i, 1)   # output has shape (i, dim0)

Of course these two functions only make sense with @jaxtyped.

And back to parsing "self.xxx": yes, it doesn't solve the problem of variables inside the function body or class variables in classmethods; in those two cases it could only be handled by assign_dims. The motivation is that, in my experience with machine learning, most shape checks look like "batch data_dim", where dimensions like "batch" are dynamic and flexible, while dimensions like "data_dim" are determined once the class is initialized. So I think this could be a very helpful feature.

The question is still there: what is the best way to annotate so that it's simple, robust, and readable? I believe parsing "self.xxx" and adding the functions assign_dim() and check_dim() would be very useful.

@anivegesana

anivegesana commented Aug 20, 2023

I think we are thinking the same thing. If there is no @jaxtyped, I chose to make assign_dims a no-op. What you described is literally how my assign_dims works. I attached the code and made a minor modification to add a check_dims.

Treating self in the annotations is a little more involved and I will leave that for later. Since assign_dims("xxx", self.xxx) accomplishes this, I won't worry about it, but you can look into it if you like. I, personally, don't mind doing assign_dims("batch num_channels height width", *image.shape) in my code.

from itertools import islice
from typing import SupportsInt, LiteralString

from jaxtyping._array_types import storage

def _dims(dims: LiteralString, /, *lengths: SupportsInt, assign: bool):
    have_memo = hasattr(storage, "memo_stack") and len(storage.memo_stack) != 0
    if have_memo:
        single_memo: dict[str, int]
        variadic_memo: dict[str, tuple[int, ...]]
        variadic_broadcast_memo: dict[str, list[tuple[int, ...]]]
        single_memo, variadic_memo, variadic_broadcast_memo = storage.memo_stack[-1]

        dim_list = dims.split()
        iter_lengths = iter(lengths)
        seen_variadic = False

        for dim, length in zip(dim_list, iter_lengths, strict=True):
            length = int(length)

            if dim == "...":
                dim = "*_"

            if dim.startswith("*"):
                # We are dealing with a variadic
                if seen_variadic:
                    raise ValueError(f"Muliple variadic shape arguments in {dims!r}")
                seen_variadic = True

                amount_in_variadic = len(lengths) - len(dim_list)
                variadic = (length,) + tuple(int(i) for i in islice(iter_lengths, amount_in_variadic))

                if dim == "*_":
                    continue
                elif (old_variadic := variadic_memo.get(dim)) is not None:
                    if old_variadic != variadic:
                        raise ValueError(f"Conflict in dim {dim!r}. First found {old_variadic}. Now found {variadic}.")
                elif assign:
                    variadic_memo[dim] = variadic
            else:
                assert dim.isidentifier()

                if dim == "_":
                    continue
                elif (old_length := single_memo.get(dim)) is not None:
                    if old_length != length:
                        raise ValueError(f"Conflict in dim {dim!r}. First found {old_length}. Now found {length}.")
                elif assign:
                    single_memo[dim] = length

def assign_dims(dims: LiteralString, /, *lengths: SupportsInt):
    _dims(dims, *lengths, assign=True)

def check_dims(dims: LiteralString, /, *lengths: SupportsInt):
    _dims(dims, *lengths, assign=False)

It seems that I didn't add a case for variadic broadcasts. I am planning on looking into that later, but you can add support if you have the time.
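
To illustrate the intended usage with the square example from earlier (assuming the assign_dims above is in scope, and making the output explicitly float so the dtype check passes):

import torch
from jaxtyping import Float, jaxtyped
from typeguard import typechecked

@jaxtyped
@typechecked
def square(i: int) -> Float[torch.Tensor, "i"]:
    assign_dims("i", i)  # bind "i" before the return annotation is checked
    return torch.arange(i, dtype=torch.float32) ** 2

square(4)  # OK: output shape (4,) matches the bound value of "i"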

@patrick-kidger
Owner

In terms of making self work: I'm guessing this can probably be done by overriding __get__, which is part of Python's descriptor protocol.

In terms of the assign_dims functionality: I'd suggest that this actually make an isinstance call under the hood, using a dummy array type internal to jaxtyping:

def assign_dim(dim: str, value: int):
    isinstance(DummyArray(value), Shaped[DummyArray, dim])

This avoids the need to meddle with the internal machinery.

@anivegesana

Yes. I was unaware that isinstance changes the state of the annotations. I had, at first, thought that it only checked the shape. I think we should go for that as well.

@patrick-kidger added the feature (New feature) and next (Higher-priority items) labels on Aug 21, 2023
@im-Kitsch

im-Kitsch commented Sep 5, 2023

A small thing: for the dummy array, I think using something like np.empty((0, batch_size)) is better than np.zeros(batch_size), since then we don't need to worry about the memory cost of creating an unused array.
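
Putting the two suggestions together, a rough sketch of assign_dim built on a zero-size NumPy array (an assumption on my part, since the internal DummyArray type mentioned above does not exist yet):

import numpy as np
from jaxtyping import Shaped

def assign_dim(dim: str, value: int) -> None:
    # Only binds inside an enclosing @jaxtyped scope. The leading zero-length
    # axis means no memory is allocated; "_" matches it anonymously, while the
    # named axis binds `dim` to `value`.
    assert isinstance(np.empty((0, value)), Shaped[np.ndarray, f"_ {dim}"])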
