Dynamo-based ONNX Export: Failed to produce a graph during tracing as no tensor operations were found. #123973

Open
asfiyab-nvidia opened this issue Apr 12, 2024 · 6 comments
Labels
oncall: export oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@asfiyab-nvidia

asfiyab-nvidia commented Apr 12, 2024

🐛 Describe the bug

I'm using Dynamo export on the FullyConnectedNet model by MONAI.

cc @ezyang @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @msaroufim @bdhirsh @anijain2305 @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @moraxu

I'm using the following script for export:

'''
Install:
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
pip3 install monai
'''

import torch
from monai.networks.nets import FullyConnectedNet

model = FullyConnectedNet(
    10,
    3,
    [8, 16],
    0.15
).to('cuda')
data = torch.randn(4, 10).cuda()

torch._dynamo.export(model, args=(data,))

The export error is below

Traceback (most recent call last):
  File "/opt/pytorch/repro_123973.py", line 18, in <module>
    torch._dynamo.export(model, args=(data,))
  File "/opt/pytorch/torch/_dynamo/eval_frame.py", line 1367, in export
    return inner(*extra_args, **extra_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/eval_frame.py", line 1237, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 978, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 411, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/opt/pytorch/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 700, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/utils.py", line 266, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 568, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/bytecode_transformation.py", line 1116, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 173, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/convert_frame.py", line 515, in transform
    tracer.run()
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 2237, in run
    super().run()
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 875, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 790, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 492, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 1301, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/opt/pytorch/torch/_dynamo/symbolic_convert.py", line 730, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/_dynamo/variables/nn_module.py", line 275, in call_function
    assert not kwargs
AssertionError: 

from user code:
   File "/opt/pytorch/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
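
The two variables named in the hint above can also be set from inside the repro script. This is a minimal sketch, not part of the original report; setting them before `import torch` is the safe choice, since the settings may be read when the library is first loaded.

```python
import os

# Enable verbose Dynamo logging, as suggested by the error message.
# Set these before importing torch so they are in place when the library loads.
os.environ["TORCH_LOGS"] = "+dynamo"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

# import torch  # import torch only after the variables are set
```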

Versions

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] optree==0.10.0
[pip3] pytorch-quantization==2.1.2
[pip3] pytorch-triton==3.0.0+989adb9a29
[pip3] torch==2.4.0.dev20240412+cu121
[pip3] torch-tensorrt==2.3.0a0
[pip3] torchaudio==2.2.0.dev20240412+cu121
[pip3] torchdata==0.7.1a0
[pip3] torchtext==0.17.0a0
[pip3] torchvision==0.19.0.dev20240412+cu121
[pip3] triton==2.2.0+e28a256

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

albanD added the module: onnx and oncall: pt2 labels on Apr 12, 2024
@thiagocrepaldi
Collaborator

thiagocrepaldi commented Apr 15, 2024

@asfiyab-nvidia try this

'''
Install:
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
pip3 install onnxscript==0.1.0.dev20240405
pip3 install monai
'''

import torch
from monai.networks.nets import FullyConnectedNet

model = FullyConnectedNet(
    10,
    3,
    [8, 16],
    0.15
).to('cuda')
data = torch.randn(4, 10).cuda()

model = torch.export.export(model, args=(data,))
export_output = torch.onnx.dynamo_export(model, data)
export_output.save('fullyconnectednet.onnx')

It should work.

For some reason, when we wrap this model to flatten its output before export, the model is no longer exportable.

This is the code that does the wrapping and leads to this issue:

    # torch/onnx/_internal/fx/dynamo_graph_extractor.py
    def generate_fx(
        self,
        options: exporter.ResolvedExportOptions,
        model: Union[torch.nn.Module, Callable],
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ) -> torch.fx.GraphModule:
        # `dynamo.export` does not recognize custom user defined classes as output type.
        # Apply wrapper to adapt the outputs back to `dynamo.export` compatible types,
        # i.e. :class:`torch.Tensor`.
        dynamo_flatten_output_step = DynamoFlattenOutputStep()
        wrapped_model = _wrap_model_with_output_adapter(
            model, dynamo_flatten_output_step
        )
        (...)
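
To illustrate what "wrapping the model to flatten its output" means, here is a simplified, torch-free sketch. It is not the exporter's actual implementation: `flatten_output`, `wrap_with_output_flattener`, and `ModelOutput` are hypothetical names, and plain numbers stand in for tensors.

```python
import functools
from dataclasses import dataclass, fields, is_dataclass


def flatten_output(out):
    """Recursively flatten dataclasses, dicts, lists and tuples into a flat list of leaves."""
    if is_dataclass(out) and not isinstance(out, type):
        return [leaf for f in fields(out) for leaf in flatten_output(getattr(out, f.name))]
    if isinstance(out, dict):
        return [leaf for value in out.values() for leaf in flatten_output(value)]
    if isinstance(out, (list, tuple)):
        return [leaf for item in out for leaf in flatten_output(item)]
    return [out]


def wrap_with_output_flattener(fn):
    """Hypothetical stand-in for an output adapter: returns fn's output as a flat tuple."""
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        # The wrapper forwards all arguments unchanged and only adapts the output.
        return tuple(flatten_output(fn(*args, **kwargs)))
    return wrapped


@dataclass
class ModelOutput:  # hypothetical custom output type
    logits: object
    hidden: tuple


def forward(x):
    return ModelOutput(logits=x + 1, hidden=(x * 2, x * 3))


flat = wrap_with_output_flattener(forward)(1)
print(flat)  # (2, 2, 3)
```

The point of the sketch is that what gets traced is the wrapper function rather than the original module, which is the difference from exporting the module directly.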

soulitzer added the triaged label on Apr 15, 2024
@asfiyab-nvidia
Author

Thanks for the comment, @thiagocrepaldi. Your suggestion to call torch.export.export(...) before the ONNX export worked, and the model now exports successfully.

Since this is a workaround (WAR) to enable export, shall I keep the ticket open until the export no longer requires calling torch.export.export(...) first?

@thiagocrepaldi
Collaborator

Use the workaround, but let's keep this ticket open for triage on dynamo side. I have updated the ticket accordingly

thiagocrepaldi removed the module: onnx and triaged labels on Apr 17, 2024
@thiagocrepaldi
Collaborator

thiagocrepaldi commented Apr 17, 2024 via email

jbschlosser added the module: dynamo and triaged labels on Apr 23, 2024
@ssslakter

I also just ran into the same problem while following the introductory guide https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html

What's interesting is that it fails if you use built-in modules instead of a user-defined class:

import torch, torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)

    def forward(self, x):
        return self.conv1(x)

model1 = nn.Sequential(nn.Conv2d(1, 6, 5))
model2 = MyModel()
model2.conv1 = model1[0]
torch_input = torch.randn(1, 1, 32, 32)
assert torch.isclose(model1(torch_input), model2(torch_input)).all()

# works
onnx_program = torch.onnx.dynamo_export(model2, torch_input)

# fails
onnx_program = torch.onnx.dynamo_export(model1, torch_input)

@anijain2305
Contributor

cc @avikchaudhuri @angelayi on the export side

From what I understand, this is expected, though undesirable, behavior. torch.export via Dynamo does not inline through built-in nn modules (such as Sequential and Conv2d).

torch.compile has recently started tracing inside built-in nn modules, but that feature is disabled for export because of export's input-matching requirements.

I will let the export folks advise on non-strict mode if necessary.
