[Dynamo] TB hf_Reformer graph breaks #101154

Open
yanboliang opened this issue May 11, 2023 · 5 comments
Labels
dynamo-triage-june2024 · module: dynamo · module: graph breaks · oncall: pt2 · triaged

Comments

@yanboliang
Contributor

yanboliang commented May 11, 2023

🐛 Describe the bug

Repro:

import torch
import logging
import sys
import torch._dynamo

# torch._logging.set_logs(dynamo=logging.DEBUG, bytecode=True)
torch._dynamo.config.print_graph_breaks = True

import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):

    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(5, 5)
        self.dropout = torch.nn.Dropout()

    def _init_attention_seed(self):
        """
        This function sets a new seed for the attention layer to make dropout deterministic for both forward calls: 1
        normal forward call and 1 forward call in backward to recalculate activations.
        """

        # randomize seeds
        # use cuda generator if available
        if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:
            # GPU
            device_idx = torch.cuda.current_device()
            self.attention_seed = torch.cuda.default_generators[device_idx].seed()
        else:
            # CPU
            self.attention_seed = int(torch.seed() % sys.maxsize)

        torch.manual_seed(self.attention_seed)

    def forward(self, x):
        self._init_attention_seed()
        return self.dropout(self.linear(x))

x = torch.randn(5, 5)

m = MyModel()
print(m(x))

opt_m = torch.compile(backend="eager")(m)
print(opt_m(x))

There are several graph breaks:

[2023-05-11 04:12:58,513] torch._dynamo.symbolic_convert: [WARNING] Graph break: hasattr: TorchVariable(<module 'torch.cuda' from '/scratch/ybliang/work/repos/pytorch/torch/cuda/__init__.py'>) from user code at   File "/scratch/ybliang/work/repos/debug/debug3.py", line 39, in forward
    self._init_attention_seed()
  File "/scratch/ybliang/work/repos/debug/debug3.py", line 28, in _init_attention_seed
    if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:

[2023-05-11 04:12:58,748] torch._dynamo.symbolic_convert: [WARNING] Graph break: inlining disallowed: <function current_device at 0x7f2ec26d8430> from user code at   File "/scratch/ybliang/work/repos/debug/debug3.py", line 30, in <resume in _init_attention_seed>
    device_idx = torch.cuda.current_device()

[2023-05-11 04:12:58,754] torch._dynamo.symbolic_convert: [WARNING] Graph break: call_method UserDefinedObjectVariable(seed) __call__ [] {} from user code at   File "/scratch/ybliang/work/repos/debug/debug3.py", line 31, in <resume in _init_attention_seed>
    self.attention_seed = torch.cuda.default_generators[device_idx].seed()
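
For reference, the break reasons can also be collected programmatically instead of scraping the warnings, via torch._dynamo.explain. A minimal sketch, assuming the repro above (reusing its m and x); note that the explain() calling convention has changed across PyTorch releases, so the exact invocation below is an assumption for recent versions:

# Appended to the repro above; `m` and `x` are defined there.
# Recent releases call explain as torch._dynamo.explain(fn)(*args).
import torch._dynamo

explanation = torch._dynamo.explain(m)(x)
print(explanation)  # summarizes graph count, graph break count, and break reasons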

Versions

N/A

cc @ezyang @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @soumith @msaroufim @wconstab @ngimel @bdhirsh @Xia-Weiwen @desertfire

@yanboliang yanboliang added the triaged, oncall: pt2, and module: dynamo labels on May 11, 2023
@anijain2305
Contributor

anijain2305 commented Jun 2, 2023

Repro for the 1st error

import torch

@torch.compile(fullgraph=True)
def fn(x):
    if hasattr(torch.cuda, "default_generators"):
        return x
    return x + 1

fn(torch.randn(1))

Repro for the 3rd error

import torch

@torch.compile(fullgraph=True)
def fn(x):
    device_idx = 0
    n_seed = torch.cuda.default_generators[device_idx].seed()
    return x + 1

fn(torch.randn(1))

@yanboliang
Contributor Author

yanboliang commented Jun 2, 2023

1/ Need to add call_hasattr to TorchVariable.
2/ Intended graph break; doesn't need a fix.
3/ AOT Autograd can't handle Generator.seed well; needs more discussion, deferred to after hackday.

@gmagogsfm
Contributor

Issue 1 is addressed; issue 3 is still there. According to tests, there are still 44 graph breaks remaining in this model.
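
One hedged way to reproduce such a count locally is Dynamo's internal counters (torch._dynamo.utils.counters is an undocumented utility, so the "graph_break" key below is an assumption that may change between releases); a sketch, reusing m and x from the original repro:

import torch
import torch._dynamo.utils

# Clear the internal counters, run the compiled model once,
# then tally the recorded break reasons.
torch._dynamo.utils.counters.clear()
opt_m = torch.compile(m, backend="eager")
opt_m(x)
breaks = torch._dynamo.utils.counters["graph_break"]
print(sum(breaks.values()), dict(breaks))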

@anijain2305
Contributor

3/ It seems we should also graph break in this case. If a model calls torch.cuda.default_generators[device_idx].seed(), it probably wants to reset the random state on every invocation, and torch.compile should respect that.
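
For models that want a clean, explicit break around the seed reset, one workaround (a sketch, not part of this issue's fix) is to opt the method out of tracing with torch._dynamo.disable so it always runs eagerly; the restructured MyModel below is only an illustration of that pattern on the CPU path of the repro:

import sys
import torch
import torch._dynamo

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
        self.dropout = torch.nn.Dropout()

    # Opt this method out of tracing: Dynamo breaks the graph around the
    # call and runs it eagerly, so the RNG state really is reset on every
    # invocation, matching the model's intent.
    @torch._dynamo.disable
    def _init_attention_seed(self):
        self.attention_seed = int(torch.seed() % sys.maxsize)
        torch.manual_seed(self.attention_seed)

    def forward(self, x):
        self._init_attention_seed()
        return self.dropout(self.linear(x))

opt_m = torch.compile(MyModel(), backend="eager")
print(opt_m(torch.randn(5, 5)))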

@tugsbayasgalan
Contributor

@yanboliang, @anijain2305 any update on this issue?
