
torch.compile incorrect when imperative autograd APIs are used #91468

Open
zou3519 opened this issue Dec 28, 2022 · 6 comments
Labels
module: aotdispatch (umbrella label for AOTAutograd issues)
module: autograd (related to torch.autograd and the autograd engine in general)
module: correctness (silent) (issue that returns an incorrect result silently)
module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op)
months
oncall: pt2
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

zou3519 (Contributor) commented Dec 28, 2022

Issue description

torch.compile may be silently incorrect when Tensor.retain_grad or Tensor.register_hook are involved.

Code example

Example with retain_grad:

import torch

# Test 1:
# Tensor.retain_grad inside function
# on intermediate, and on output
# See if the grad can be used :P

def f(x):
    y = x.clone()
    y.retain_grad()
    z = y.clone()
    z.retain_grad()
    return z, y

of = torch.compile(backend='aot_eager')(f)

x = torch.randn([], requires_grad=True)
z, y = of(x)
z.clone().backward()
# Inspect z.grad and y.grad
# Bug: they do not exist. May lead to silent correctness problems.
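For comparison, here is a minimal eager-mode sketch (the same f, nothing compiled); run eagerly, both retained grads are populated after backward:

import torch

def f(x):
    y = x.clone()
    y.retain_grad()
    z = y.clone()
    z.retain_grad()
    return z, y

x = torch.randn([], requires_grad=True)
z, y = f(x)
z.clone().backward()
# Eager mode: both grads exist (and are 1.0 for this scalar example)
assert z.grad is not None
assert y.grad is not None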

Example with register_hook:

import torch

# If you register a hook on an intermediate, it won't work.
def f(x):
    y = x * x
    z = y * x * x * x * 1 * 1 * 1. * 1. * 1.
    print("graph_break")  # forces a graph break
    y.register_hook(lambda x: x if x is None else 3.14 * x)
    result, = torch.autograd.grad(z, x)
    return result

of = torch.compile(backend='aot_eager')(f)
x = torch.tensor(1., requires_grad=True)

expected = f(x)
result = of(x)
# Bug: this assertion fails
assert torch.allclose(result, expected)
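For reference, the expected value can be worked out by hand (a sketch of the eager chain rule, assuming the hook scales the gradient flowing into y by 3.14):

# z = y * x**3 with y = x**2
# dz/dx = (hooked dz/dy) * dy/dx + y * d(x**3)/dx
#       = (3.14 * x**3) * (2 * x) + y * (3 * x**2)
# at x = 1: 6.28 + 3 = 9.28   (eager, hook fires)
# If the hook is silently dropped, the result is 2 + 3 = 5.0 instead.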

cc @ezyang @gchanan @kadeng @albanD @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @msaroufim @bdhirsh @anijain2305 @chauhang @wconstab @soumith @ngimel

samdow added the oncall: pt2 and module: autograd labels on Jan 3, 2023
ngimel added the module: aotdispatch and high priority labels on Jan 3, 2023
ngimel (Collaborator) commented Jan 3, 2023

cc @Chillee, high priority for silently incorrect results.

albanD (Collaborator) commented Jan 3, 2023

Hook handling is a general issue in Dynamo, yes...

ngimel (Collaborator) commented Jan 22, 2023

Related #91665

ngimel added the triaged label and removed the triage review label on Jan 22, 2023
voznesenskym (Collaborator) commented:

Still a bug

voznesenskym assigned bdhirsh and voznesenskym, then unassigned bdhirsh, on Nov 14, 2023
penguinwu added the module: pt2-dispatcher label on Nov 29, 2023
zou3519 added the months label on Mar 18, 2024
zou3519 (Contributor, Author) commented Mar 18, 2024

graph-breaking isn't sufficient; the problem is that the autograd.Function we generate for AOTAutograd changes autograd semantics. This isn't an easy fix.
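To illustrate the semantics problem (a hand-written sketch only, not AOTAutograd's actual codegen): once a traced region is wrapped in a single autograd.Function, a hook registered on a returned intermediate can no longer influence the backward that runs inside the wrapper.

import torch

class Wrapped(torch.autograd.Function):
    # Stand-in for the one autograd.Function generated for a traced region:
    # both z and y are returned, and the backward for the whole region is
    # computed internally.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = x * x
        z = y * x * x * x
        return z, y

    @staticmethod
    def backward(ctx, gz, gy):
        (x,) = ctx.saved_tensors
        # dz/dx and dy/dx are computed in here; a hook the user later
        # registers on the returned y never runs inside this backward.
        return gz * 5 * x ** 4 + gy * 2 * x

x = torch.tensor(1., requires_grad=True)
z, y = Wrapped.apply(x)
y.register_hook(lambda g: 3.14 * g)  # silently has no effect on dz/dx
z.backward()
print(x.grad)  # 5.0, not the 9.28 that eager mode with the hook would give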

bdhirsh (Contributor) commented Jun 20, 2024

Talking with @zou3519:

Not sure how we can fix this in compile at all, but if we decide not to support it, we might want to error on this behavior rather than be silently incorrect.

One option (using the register_hook repro from the first post): AOTAutograd creates a graph that has y and z both as graph outputs. AOTAutograd could detect that y is a "special" output, because the other output z was computed from y. We could then set a bit on y saying that no hooks may be registered on it, since they will not fire properly.

This seems difficult to do though.
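For concreteness, a purely hypothetical user-level sketch of what that bit could behave like (nothing here is an existing PyTorch API, and poison_hooks is an invented helper; the real flag would have to live on the tensor / inside AOTAutograd): hook registration on such an output would raise instead of being silently ignored.

import torch

def poison_hooks(t: torch.Tensor, reason: str) -> torch.Tensor:
    # Hypothetical helper: shadow the imperative autograd APIs on this one
    # tensor instance so that using them is loud instead of a silent no-op.
    def _raise(*args, **kwargs):
        raise RuntimeError(f"hooks/retain_grad are not supported on this output: {reason}")
    t.register_hook = _raise  # instance attribute shadows the bound method
    t.retain_grad = _raise
    return t

# Hypothetical usage on the compiled outputs from the first post:
#   z, y = of(x)
#   poison_hooks(y, "y is an intermediate of z inside the compiled region")
#   y.register_hook(lambda g: 3.14 * g)  # would raise instead of silently not firing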

bdhirsh added the module: correctness (silent) label and removed the high priority label on Jun 20, 2024