
[autograd.Function] freevar lifting is too aggressive? #106894

Closed
zou3519 opened this issue Aug 9, 2023 · 1 comment
Labels
dynamo-must-fix (These bugs affect TorchDynamo reliability.) · module: dynamo · module: higher order operators (torch.cond and similar) · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


zou3519 commented Aug 9, 2023

🐛 Describe the bug

Seems like the freevar lifting thinks it needs to lift constants and then graph-breaks if they are used in autograd.Function's backward:

import torch
from typing import *

def halve(x):
    return x * 0.5


class ScaleGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad):
        return halve(grad)

x = torch.randn(3, requires_grad=True)
def f(x):
    return ScaleGradient.apply(x)
output = torch.compile(f, backend='eager', fullgraph=True)(x)

gives:

    f.call_function(tx, sub_args, {})
  File "/raid/rzou/pt/debug-cpu2/torch/_dynamo/variables/higher_order_ops.py", line 950, in call_funct
ion
    unimplemented("NYI - freevars in autograd function.")
  File "/raid/rzou/pt/debug-cpu2/torch/_dynamo/exc.py", line 143, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: NYI - freevars in autograd function.

from user code:
   File "/raid/rzou/pt/debug-cpu2/foo.py", line 20, in f
    return ScaleGradient.apply(x)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
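
As a stopgap, the error message itself suggests an eager fallback. A minimal sketch of how one might apply it to the repro above (my reading of that suggestion, not something stated in the issue); note that fullgraph=True has to be dropped, since it disallows graph breaks entirely:

import torch
import torch._dynamo

# Fallback suggested by the error message: suppress the Unsupported exception
# and run the frame that hits the graph break in eager mode instead.
torch._dynamo.config.suppress_errors = True

# `f` and `x` are the function and input from the repro above.
# fullgraph=True is dropped here because it forbids any graph break.
output = torch.compile(f, backend="eager")(x)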

cc @ezyang @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @ydwu4 @msaroufim @bdhirsh @aakhundov @wconstab @Xia-Weiwen

Error logs

No response

Minified repro

No response

Versions

main

@zou3519 changed the title from "[autograd.Function] freevar lifting is too conservative?" to "[autograd.Function] freevar lifting is too aggressive?" Aug 9, 2023
@bdhirsh added the triaged and module: dynamo labels Aug 29, 2023
@anijain2305 added the dynamo-must-fix label Jan 31, 2024
@anijain2305 added the module: higher order operators label Jun 18, 2024

zou3519 commented Jun 24, 2024

Succeeds now

@zou3519 closed this as completed Jun 24, 2024