torch.compile(fullgraph=True): can't pass lambdas to hooks? #116220

Closed
zou3519 opened this issue Dec 20, 2023 · 1 comment
Labels: dynamo-triage-june2024, feature, module: dynamo, oncall: pt2, triaged

Comments

zou3519 (Contributor) commented on Dec 20, 2023

```python
import torch
v = torch.tensor([0., 0., 0.], requires_grad=True)

@torch.compile(backend="aot_eager", fullgraph=True)
def f(v):
    lr = 0.01
    # simulate a simple SGD update
    h = v.register_hook(lambda p: p.add_(p.grad, alpha=-lr))

    return v.clone(), h

print(v)
k, h = f(v)
v.backward(torch.tensor([1., 2., 3.]))
print(v)
```

This raises `Unsupported: Unexpected callable type passed to register_hook`.

cc @ezyang @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @msaroufim @wconstab @bdhirsh @aakhundov
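For context, in eager mode a hook registered with `Tensor.register_hook` is called with the gradient flowing into the tensor, not with the tensor itself. Inside the repro's lambda, `p` is therefore a gradient tensor and `p.grad` is `None`. The sketch below checks this with a named hook instead of a lambda. `inspect_hook`, `captured`, and the weighting tensor are illustrative names only, and the sketch runs without `torch.compile`.

```python
import torch

v = torch.tensor([0., 0., 0.], requires_grad=True)
captured = {}

def inspect_hook(grad):
    # Tensor.register_hook passes the gradient w.r.t. v, not v itself,
    # so the `.grad` attribute of this argument is None.
    captured["grad"] = grad.clone()
    captured["grad_attr"] = grad.grad
    # Returning None leaves the gradient unchanged.
    return None

handle = v.register_hook(inspect_hook)
# Weighting by [1, 2, 3] makes dL/dv equal to [1., 2., 3.].
(v * torch.tensor([1., 2., 3.])).sum().backward()
handle.remove()

print(captured["grad"])       # tensor([1., 2., 3.])
print(captured["grad_attr"])  # None
```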

yf225 added the feature, triaged, and module: dynamo labels on Dec 21, 2023
masnesral (Contributor) commented:

I believe this is fixed. I get the same (new) error with and without torch.compile:
`TypeError: add_(): argument 'other' (position 1) must be Tensor, not NoneType`

But if I change to:
```python
h = v.register_hook(lambda p: p.add_(p, alpha=-lr))
```

the script runs fine, without the `Unexpected callable type passed to register_hook` error.
