
support setattr of arbitrary user provided types in tracing #93511

Open

shunting314 opened this issue Jan 17, 2023 · 4 comments
Assignees
anijain2305
Labels
bug dynamo-nn-modules dynamo-symbolic-analysis dynamo-triage-june2024 module: dynamo oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@shunting314
Contributor

shunting314 commented Jan 17, 2023

🐛 Describe the bug

Dynamo already supports patching an nn.Module attribute outside of the forward call (e.g., during model initialization): #91018. But some use cases (e.g., detectron2's RCNN model) need to patch an nn.Module attribute inside the forward method (fb internal link: https://fburl.com/code/vvekrxl6). Dynamo does not support this right now.
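
For contrast, a minimal sketch of the already-supported case from #91018: an instance attribute is patched outside of forward, before the module is compiled. The names here (PatchedOutside, scale) are made up for illustration and are not taken from #91018 or detectron2.

import torch
import torch._dynamo as dynamo
from torch import nn

class PatchedOutside(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = 1.0

    def forward(self, x):
        return x * self.scale

m = PatchedOutside()
m.scale = 2.0  # attribute patched outside of forward: already handled by Dynamo (#91018)
opt_m = dynamo.optimize("eager")(m)
print(opt_m(torch.rand(5)))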

Error logs

Traceback (most recent call last):
File "", line 1, in
File "/home/shunting/cpython/build/install/lib/python3.9/importlib/init.py", line 169, in reload
_bootstrap._exec(spec, module)
File "", line 613, in _exec
File "", line 790, in exec_module
File "", line 228, in _call_with_frames_removed
File "/home/shunting/learn/misc.py", line 20, in
gm, guards = dynamo.export(MyModule(), *inputs, aten_graph=True, tracing_mode="symbolic")
File "/home/shunting/pytorch/torch/_dynamo/eval_frame.py", line 616, in export
result_traced = opt_f(*args, **kwargs)
File "/home/shunting/pytorch/torch/nn/modules/module.py", line 1482, in _call_impl
return forward_call(*args, **kwargs)
File "/home/shunting/pytorch/torch/_dynamo/eval_frame.py", line 82, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/home/shunting/pytorch/torch/_dynamo/eval_frame.py", line 211, in _fn
return fn(*args, **kwargs)
File "/home/shunting/pytorch/torch/_dynamo/eval_frame.py", line 332, in catch_errors
return callback(frame, cache_size, hooks)
File "/home/shunting/pytorch/torch/_dynamo/convert_frame.py", line 103, in _fn
return fn(*args, **kwargs)
File "/home/shunting/pytorch/torch/_dynamo/utils.py", line 90, in time_wrapper
r = func(*args, **kwargs)
File "/home/shunting/pytorch/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
return _compile(
File "/home/shunting/pytorch/torch/_dynamo/convert_frame.py", line 398, in _compile
out_code = transform_code_object(code, transform)
File "/home/shunting/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/home/shunting/pytorch/torch/_dynamo/convert_frame.py", line 385, in transform
tracer.run()
File "/home/shunting/pytorch/torch/_dynamo/symbolic_convert.py", line 1686, in run
super().run()
File "/home/shunting/pytorch/torch/_dynamo/symbolic_convert.py", line 537, in run
and self.step()
File "/home/shunting/pytorch/torch/_dynamo/symbolic_convert.py", line 500, in step
getattr(self, inst.opname)(inst)
File "/home/shunting/pytorch/torch/_dynamo/symbolic_convert.py", line 1048, in STORE_ATTR
BuiltinVariable(setattr)
File "/home/shunting/pytorch/torch/_dynamo/variables/builtin.py", line 375, in call_function
return super().call_function(tx, args, kwargs)
File "/home/shunting/pytorch/torch/_dynamo/variables/base.py", line 230, in call_function
unimplemented(f"call_function {self} {args} {kwargs}")
File "/home/shunting/pytorch/torch/_dynamo/exc.py", line 67, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function BuiltinVariable(setattr) [UserDefinedClassVariable(), ConstantVariable(str), GetAttrVariable(UserDefinedClassVariable(), run_cos)] {}

from user code:
File "/home/shunting/learn/misc.py", line 10, in forward
MyModule.run = MyModule.run_cos

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True

Minified repro

import torch
from torch import nn
import torch._dynamo as dynamo

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # class-level setattr during tracing: this is the line Dynamo rejects
        MyModule.run = MyModule.run_cos
        return self.run(x)

    def run(self, x):
        return torch.sin(x)

    def run_cos(self, x):
        return torch.cos(x)

inputs = [torch.rand(5)]
gm, guards = dynamo.export(MyModule(), *inputs, aten_graph=True, tracing_mode="symbolic")
print(f"Graph is {gm.graph}")

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @soumith @wconstab @ngimel

@shunting314
Contributor Author

cc @voznesenskym @suo @ezyang

@voznesenskym
Collaborator

voznesenskym commented Jan 17, 2023

Hmm, the problem is not that it's a patch, but that it's patched with the UserDefinedClassVariable's run_cos. Dynamo has no idea whether this is module code or not; it just sees setattr being called with some arbitrary class as input. I think this is a weird use case, fwiw, and the link (https://www.internalfb.com/code/fbsource/[5fdf103084ca30e4afb25738dc2bbf9a8a95be05]/fbcode/vision/fair/detectron2/detectron2/export/caffe2_modeling.py?lines=274) looks like a context manager, not a patch?

Also, the issue title is wrong; it should be "support setattr of arbitrary user provided types in tracing" or something like that.

@shunting314
Contributor Author

@voznesenskym the context manager does the patch in its implementation :)
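
For readers without access to the internal link, a hypothetical sketch of the kind of context manager being described: it does a class-level setattr on enter and restores the original attribute on exit. patch_class_attr and the surrounding names are illustrative, not detectron2's actual code.

import contextlib

@contextlib.contextmanager
def patch_class_attr(cls, name, new_value):
    # Hypothetical helper: swap a class attribute on enter, restore it on exit.
    old_value = getattr(cls, name)
    setattr(cls, name, new_value)
    try:
        yield
    finally:
        setattr(cls, name, old_value)

# Used inside forward(), this reduces to the same class-level setattr that
# Dynamo currently rejects:
#   with patch_class_attr(MyModule, "run", MyModule.run_cos):
#       out = self.run(x)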

shunting314 changed the title from "Support patch nn.Module attribute in forward call" to "support setattr of arbitrary user provided types in tracing" on Jan 17, 2023
malfet transferred this issue from pytorch/torchdynamo on Feb 1, 2023
albanD added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Feb 7, 2023
anijain2305 self-assigned this on Sep 4, 2024