Support Exceptions in Pytorch Export #123499

Open
FabianSchuetze opened this issue Apr 6, 2024 · 3 comments
Labels
export-triage-review, module: dynamo, oncall: export

Comments

@FabianSchuetze
Contributor

FabianSchuetze commented Apr 6, 2024

🚀 The feature, motivation and pitch

Exceptions are currently not supported in PyTorch export. The following example demonstrates this:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def validate(self, x):
        # True only if every element is strictly positive
        return torch.all(x > 0)

    def compute(self, x):
        return x * x

    def forward(self, x):
        # Data-dependent checks that bail out early with an exception
        if not self.validate(x):
            raise RuntimeError("Exception")
        y = self.compute(x)
        if not self.validate(y):
            raise RuntimeError("Another exception")
        return self.compute(y)


with torch.no_grad():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Model().to(device=device)
    example_inputs = (torch.randn(8, 10, device=device),)
    # Export fails here because of the raise statements in forward()
    torch.export.export(model, example_inputs)

However, exceptions are an idiomatic way to report errors and bail out early. For example, I count 278 uses of exceptions in the models folder of diffusers. Furthermore, exceptions are supported in torch.jit.

Is there any way exceptions could be supported in torch.compile? What would it take to implement this feature?

Alternatives

One alternative is to split the function into subfunctions, branch with torch.cond, and return empty tensors as a status message when a predicate is not satisfied. That produces deeply nested code, although the pattern is not uncommon in embedded C programs that forbid exceptions.
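The status-flag idea can be sketched as follows. This is a hypothetical rewrite (the class name `StatusModel` is made up for illustration), and it swaps in `torch.where` for `torch.cond`: both checks are computed as boolean tensors, so `forward` never branches in Python on tensor values, and the caller decides whether to raise. Unlike the original model, it does not bail out early; both checks always run.

```python
import torch


class StatusModel(torch.nn.Module):
    """Hypothetical variant of Model that reports failed checks via a status
    tensor instead of raising inside forward()."""

    def validate(self, x: torch.Tensor) -> torch.Tensor:
        # 0-dim bool tensor; never converted to a Python bool inside forward()
        return torch.all(x > 0)

    def compute(self, x: torch.Tensor) -> torch.Tensor:
        return x * x

    def forward(self, x: torch.Tensor):
        ok = self.validate(x)
        y = self.compute(x)
        ok = ok & self.validate(y)  # combine both checks into one flag
        out = self.compute(y)
        # Return zeros when any check failed; callers must inspect `ok`
        out = torch.where(ok, out, torch.zeros_like(out))
        return out, ok


# Usage: raise (if desired) outside the model, after checking the status flag
model = StatusModel()
x = torch.rand(8, 10) + 0.1  # strictly positive input
out, ok = model(x)
if not ok:
    raise RuntimeError("validation failed")
```

Because `forward` never converts a tensor to a Python `bool`, a model written this way should be traceable by torch.export, with the error reporting moved outside the exported graph; it is worth verifying this against the PyTorch version in use.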

Additional context

No response

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

@tugsbayasgalan
Contributor

We actually support Python asserts today. As a workaround, you could try changing the raise statements into asserts.
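Applied to the model from the original report, the suggested workaround might look like the sketch below (the class name `AssertModel` is hypothetical). Each `raise RuntimeError(...)` becomes an `assert` with the same message. The sketch only demonstrates the rewrite and its eager-mode behavior; how a particular PyTorch version traces these data-dependent asserts during export is not shown here and should be checked.

```python
import torch


class AssertModel(torch.nn.Module):
    """Hypothetical version of Model with its raise statements turned into asserts."""

    def validate(self, x: torch.Tensor) -> torch.Tensor:
        return torch.all(x > 0)

    def compute(self, x: torch.Tensor) -> torch.Tensor:
        return x * x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Was: if not self.validate(x): raise RuntimeError("Exception")
        assert self.validate(x), "Exception"
        y = self.compute(x)
        # Was: if not self.validate(y): raise RuntimeError("Another exception")
        assert self.validate(y), "Another exception"
        return self.compute(y)


model = AssertModel()
x = torch.rand(8, 10) + 0.1  # passes both checks
result = model(x)

try:
    model(-x)  # fails the first check
except AssertionError as e:
    print("Caught:", e)  # eager mode raises AssertionError("Exception")
```

After export, the check is carried out by the exported graph rather than by Python's `assert` statement, so the exception type seen by the caller may differ from `AssertionError`.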

@tugsbayasgalan added the export-triaged, export-triage-review, and module: dynamo labels and removed the export-triaged label Apr 8, 2024
@FabianSchuetze
Contributor Author

FabianSchuetze commented Apr 9, 2024

Thanks a lot, @tugsbayasgalan, for your reply! I wasn't aware that an assert statement could be used.

With asserts, I can catch the errors raised by an exported program like so:

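exported_program = torch.export.export(model, example_inputs)  # model uses asserts instead of raise
try:
    # Call the exported module; a failed check surfaces as an exception here
    res = exported_program.module()(*example_inputs)
except Exception as e:
    print("Caught exception:", e)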

However, try/except blocks can't be used inside a module. That said, most exceptions in library code (e.g. diffusers) are not caught but propagated to the user anyway.

The ticket can be closed if you like.

@tugsbayasgalan self-assigned this Apr 9, 2024
@anijain2305
Contributor

@FabianSchuetze TorchDynamo (and thus export) supports try ... except handling, but it does not support raising exceptions that propagate out of the traced code.

If you have code like this:

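try:
    # An exception raised and caught inside the same traced function
    raise AssertionError("validation failed")
except AssertionError:
    # handle_exception() is a placeholder for the recovery logic
    handle_exception()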

This should work today. In your example, however, the exception is raised without being caught inside the model, which is hard to support without graph breaks.
