
CompileProfiler reports graph breaks while dynamo.explain reports no graph breaks #113443

Closed
akihironitta opened this issue Nov 10, 2023 · 4 comments
Assignees: anijain2305
Labels: bug, dynamo-logging, dynamo-triage-june2024, internal ramp-up task, module: dynamo, module: logging, oncall: pt2, triaged

Comments

@akihironitta
Contributor

akihironitta commented Nov 10, 2023

🐛 Describe the bug

I'm seeing an inconsistent number of graph breaks between (1) CompileProfiler and (2) dynamo.explain, as reproduced with the script below.

Error logs

(1) CompileProfiler

Torchdynamo Profiler Report
===========================

Graph Breaks
------------
Graph breaks happen when torchdynamo encounters code it can't safely trace.
If you want to find out why breaks are happening, check below for each break reason
You may gain additional insight by passing `fullgraph=True` to torch.compile,
to stop at the first break.

Graph Break Reason                     Count
-----------------------------------  -------
hasattr: UserDefinedClassVariable()        1
hasattr no source                          2

Recompilation
-------------
These subgraphs were recompiled more than once due to guard failures
Guard failures indicate some condition assumed to be static by the tracer changed,
making it unsafe to reuse the compiled program.

No recompilation detected

(2) dynamo.explain

Graph Count: 1
Graph Break Count: 0
Op Count: 0
Break Reasons:
Ops per Graph:
  Ops 1:
Out Guards:
  Guard 1:
    Name: "G['edge_type']"
    Source: global
    Create Function: LIST_LENGTH
    Guard Types: ['LIST_LENGTH']
    Code List: ["___check_type_id(G['edge_type'], 94630539513280)", "len(G['edge_type']) == 3"]
    Object Weakref: None
    Guarded Class Weakref: <weakref at 0x7f33bc8b7790; to 'type' at 0x5610e3b14dc0 (tuple)>
  Guard 2:
    Name: ''
    Source: global
    Create Function: DETERMINISTIC_ALGORITHMS
    Guard Types: ['DETERMINISTIC_ALGORITHMS']
    Code List: ['not ___are_deterministic_algorithms_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 3:
    Name: "L['x']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['x'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f32cadd3790; dead>
    Guarded Class Weakref: <weakref at 0x7f330243d760; to 'torch._C._TensorMeta' at 0x5610e753afd0 (Tensor)>
  Guard 4:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 5:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: ['GRAD_MODE']
    Code List: ['___is_grad_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 6:
    Name: "G['edge_type'][0]"
    Source: global
    Create Function: CONSTANT_MATCH
    Guard Types: ['EQUALS_MATCH']
    Code List: ["___check_type_id(G['edge_type'][0], 94630539501152)", "G['edge_type'][0] == 'a'"]
    Object Weakref: None
    Guarded Class Weakref: <weakref at 0x7f33bc8c1030; to 'type' at 0x5610e3b11e60 (str)>
  Guard 7:
    Name: "G['edge_type'][1]"
    Source: global
    Create Function: CONSTANT_MATCH
    Guard Types: ['EQUALS_MATCH']
    Code List: ["___check_type_id(G['edge_type'][1], 94630539501152)", "G['edge_type'][1] == 'to'"]
    Object Weakref: None
    Guarded Class Weakref: <weakref at 0x7f33bc8c1030; to 'type' at 0x5610e3b11e60 (str)>
  Guard 8:
    Name: "G['edge_type'][2]"
    Source: global
    Create Function: CONSTANT_MATCH
    Guard Types: ['EQUALS_MATCH']
    Code List: ["___check_type_id(G['edge_type'][2], 94630539501152)", "G['edge_type'][2] == 'b'"]
    Object Weakref: None
    Guarded Class Weakref: <weakref at 0x7f33bc8c1030; to 'type' at 0x5610e3b11e60 (str)>
  Guard 9:
    Name: ''
    Source: global
    Create Function: TORCH_FUNCTION_STATE
    Guard Types: ['TORCH_FUNCTION_STATE']
    Code List: ['___is_torch_function_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 10:
    Name: ''
    Source: shape_env
    Create Function: SHAPE_ENV
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 11:
    Name: "G['ModuleDict']"
    Source: global
    Create Function: FUNCTION_MATCH
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 12:
    Name: "L['x']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['x'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f32cadd3790; dead>
    Guarded Class Weakref: <weakref at 0x7f330243d760; to 'torch._C._TensorMeta' at 0x5610e753afd0 (Tensor)>
  Guard 13:
    Name: "L['self'].module_dict"
    Source: local_nn_module
    Create Function: NN_MODULE
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 14:
    Name: ''
    Source: global
    Create Function: DETERMINISTIC_ALGORITHMS
    Guard Types: ['DETERMINISTIC_ALGORITHMS']
    Code List: ['not ___are_deterministic_algorithms_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 15:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 16:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: ['GRAD_MODE']
    Code List: ['___is_grad_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 17:
    Name: "L['___stack0']"
    Source: local
    Create Function: CONSTANT_MATCH
    Guard Types: ['EQUALS_MATCH']
    Code List: ["___check_type_id(L['___stack0'], 94630539501152)", "L['___stack0'] == '<a___to___b>'"]
    Object Weakref: None
    Guarded Class Weakref: <weakref at 0x7f33bc8c1030; to 'type' at 0x5610e3b11e60 (str)>
  Guard 18:
    Name: ''
    Source: global
    Create Function: TORCH_FUNCTION_STATE
    Guard Types: ['TORCH_FUNCTION_STATE']
    Code List: ['___is_torch_function_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 19:
    Name: "L['self']"
    Source: local
    Create Function: NN_MODULE
    Guard Types: ['ID_MATCH']
    Code List: ["___check_obj_id(L['self'], 139860182476800)"]
    Object Weakref: <weakref at 0x7f32c24753f0; to 'SomeModel' at 0x7f33bc7ea800>
    Guarded Class Weakref: <weakref at 0x7f33bc8550d0; to 'type' at 0x5610ea6587d0 (SomeModel)>
  Guard 20:
    Name: ''
    Source: shape_env
    Create Function: SHAPE_ENV
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 21:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: ['GRAD_MODE']
    Code List: ['___is_grad_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 22:
    Name: ''
    Source: global
    Create Function: DETERMINISTIC_ALGORITHMS
    Guard Types: ['DETERMINISTIC_ALGORITHMS']
    Code List: ['not ___are_deterministic_algorithms_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 23:
    Name: "L['___stack0']"
    Source: local
    Create Function: NN_MODULE
    Guard Types: ['ID_MATCH']
    Code List: ["___check_obj_id(L['___stack0'], 139860182476848)"]
    Object Weakref: <weakref at 0x7f32c22e8720; to 'Linear' at 0x7f33bc7ea830>
    Guarded Class Weakref: <weakref at 0x7f3301f98ea0; to 'type' at 0x5610e77ae410 (Linear)>
  Guard 24:
    Name: "L['x']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['x'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f32cadd3790; dead>
    Guarded Class Weakref: <weakref at 0x7f330243d760; to 'torch._C._TensorMeta' at 0x5610e753afd0 (Tensor)>
  Guard 25:
    Name: ''
    Source: global
    Create Function: TORCH_FUNCTION_STATE
    Guard Types: ['TORCH_FUNCTION_STATE']
    Code List: ['___is_torch_function_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 26:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 27:
    Name: ''
    Source: shape_env
    Create Function: SHAPE_ENV
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
Compile Times: TorchDynamo compilation metrics:
Function                         Runtimes (s)
-------------------------------  --------------------------------------
_compile.<locals>.compile_inner  0.1674, 0.0107, 0.0078, 0.0017, 0.0185
OutputGraph.call_user_compiler   0.0042

Minified repro

import torch
from torch_geometric.nn.module_dict import ModuleDict

edge_type = ("a", "to", "b")

class SomeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.module_dict = ModuleDict({
            edge_type: torch.nn.Linear(1, 1),
        })

    def forward(self, x):
        key = ModuleDict.to_internal_key(edge_type)
        x = self.module_dict[key](x)
        return x

# (1) shows that the module has graph breaks
from torch._dynamo.utils import CompileProfiler
with CompileProfiler() as prof:
    model = torch.compile(SomeModel())
    model(torch.randn(1, 1))
    print(prof.report())

# (2) shows that the module has NO graph breaks
model = SomeModel()
explain = torch._dynamo.explain(model)(torch.randn(1, 1))
print(explain)
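
The CompileProfiler report above suggests passing fullgraph=True to stop at the first break. A minimal sketch of that variant (assuming it raises torch._dynamo.exc.Unsupported at the first break instead of falling back to eager):

# Variant of the repro: raise at the first graph break instead of falling back.
model = torch.compile(SomeModel(), fullgraph=True)
model(torch.randn(1, 1))  # expected to raise at the hasattr graph break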

Note: pyg-team/pytorch_geometric@40cc3b1 was used to reproduce this.

Versions

Collecting environment information...
PyTorch version: 2.1.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.22.5
Libc version: glibc-2.31

Python version: 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.13.0-1031-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.5.119
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 510.47.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   46 bits physical, 48 bits virtual
CPU(s):                          16
On-line CPU(s) list:             0-15
Thread(s) per core:              2
Core(s) per socket:              8
Socket(s):                       1
NUMA node(s):                    1
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           85
Model name:                      Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
Stepping:                        7
CPU MHz:                         2499.996
BogoMIPS:                        4999.99
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       256 KiB
L1i cache:                       256 KiB
L2 cache:                        8 MiB
L3 cache:                        35.8 MiB
NUMA node0 CPU(s):               0-15
Vulnerability Itlb multihit:     KVM: Mitigation: VMX unsupported
Vulnerability L1tf:              Mitigation; PTE Inversion
Vulnerability Mds:               Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Mmio stale data:   Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, STIBP disabled, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke avx512_vnni

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] numpy==1.24.1
[pip3] onnx==1.14.1
[pip3] onnxruntime==1.16.0
[pip3] pytorch_frame==0.1.0
[pip3] pytorch-lightning==2.0.9.post0
[pip3] pytorch-memlab==0.3.0
[pip3] torch==2.1.0+cu118
[pip3] torch_frame==0.1.0
[pip3] torch_geometric==2.4.0
[pip3] torchmetrics==1.2.0
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0
[conda] numpy                     1.24.1                   pypi_0    pypi
[conda] pytorch-frame             0.1.0                    pypi_0    pypi
[conda] pytorch-lightning         2.0.9.post0              pypi_0    pypi
[conda] pytorch-memlab            0.3.0                    pypi_0    pypi
[conda] torch                     2.1.0+cu118              pypi_0    pypi
[conda] torch-frame               0.1.0                    pypi_0    pypi
[conda] torch-geometric           2.4.0                    pypi_0    pypi
[conda] torchmetrics              1.2.0                    pypi_0    pypi
[conda] torchvision               0.16.0                   pypi_0    pypi
[conda] triton                    2.1.0                    pypi_0    pypi

cc @ezyang @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @msaroufim @wconstab @bdhirsh @zou3519 @aakhundov

@wconstab
Contributor

wconstab commented Nov 10, 2023

I was able to repro something almost matching the above (using a torch build from 11/8, 84d64d7, and the provided torch_geometric version).

In my case, CompileProfiler reports 2 graph breaks while explain reports 1 graph and no breaks.

But I also see a crash at exit due to something going wrong with CleanupManager/CleanupHook:

Exception ignored in: <function ExactWeakKeyDictionary.__setitem__.<locals>.<lambda> at 0x7f2b0f138040>                                                                                                 
Traceback (most recent call last):
  File "/data/users/whc/pytorch/torch/_dynamo/utils.py", line 448, in <lambda>                      
  File "/data/users/whc/pytorch/torch/_dynamo/utils.py", line 610, in _remove_id                    
  File "/data/users/whc/pytorch/torch/_dynamo/utils.py", line 592, in __call__                                                                                                                          
ImportError: sys.meta_path is None, Python is likely shutting down            

This might be a totally separate issue; I'll investigate a bit more and possibly file a new PR.

cc @voznesenskym

@wconstab
Contributor

TORCH_LOGS=+dynamo corroborates that we are breaking the graph due to hasattr:

[2023-11-10 13:19:04,197] [0/0_1] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='hasattr: UserDefinedClassVariable()', user_stack=[<FrameSummary file /data/users
/whc/pytorch/explain.py, line 14 in forward>, <FrameSummary file /home/whc/local/miniconda3/envs/pytorch/lib/python3.11/site-packages/torch_geometric/nn/module_dict.py, line 34 in to_internal_key>], graph_break=True)                   
[2023-11-10 13:19:04,220] [1/0_1] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='hasattr: UserDefinedClassVariable()', user_stack=[<FrameSummary file /home/whc/l
ocal/miniconda3/envs/pytorch/lib/python3.11/site-packages/torch_geometric/nn/module_dict.py, line 34 in to_internal_key>], graph_break=True)
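
For reference, the logs above were produced by running the repro with TORCH_LOGS=+dynamo. A sketch of the equivalent in-process setup (assuming torch._logging.set_logs, available in torch >= 2.1):

import logging
import torch

# Same effect as TORCH_LOGS=+dynamo: DEBUG-level dynamo logs, which include the
# "COMPILING GRAPH due to GraphCompileReason(...)" messages shown above.
torch._logging.set_logs(dynamo=logging.DEBUG)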

@wconstab added the bug, module: dynamo, and triaged labels on Nov 10, 2023
@penguinwu added the module: logging and internal ramp-up task labels on Jan 3, 2024
@anijain2305
Contributor

We should probably retire CompileProfiler. We have TORCH_LOGS, which supersedes everything (even explain).

Keeping this open, but with the focus on sunsetting CompileProfiler and updating the docs.
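
A sketch of the TORCH_LOGS workflow that would replace CompileProfiler (assuming the graph_breaks and recompiles artifact names available in recent releases):

import torch

# Log every graph break (with its reason) and every recompile,
# roughly the same information CompileProfiler reports.
# Equivalent to running with TORCH_LOGS="graph_breaks,recompiles".
torch._logging.set_logs(graph_breaks=True, recompiles=True)

model = torch.compile(SomeModel())  # SomeModel from the minified repro above
model(torch.randn(1, 1))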

@anijain2305 anijain2305 self-assigned this Sep 4, 2024
anijain2305 added a commit that referenced this issue Sep 4, 2024
Fixes confusion in #113443

We have TORCH_LOGS that supersedes CompileProfiler

[ghstack-poisoned]
anijain2305 added a commit that referenced this issue Sep 4, 2024
Fixes confusion in #113443

We have TORCH_LOGS that supersedes CompileProfiler

ghstack-source-id: e5d387a780cd58b3e1da1050b5f0b49ae354867a
Pull Request resolved: #135133
pytorchmergebot pushed a commit that referenced this issue Sep 5, 2024
Fixes confusion in #113443

We have TORCH_LOGS that supersedes CompileProfiler

Pull Request resolved: #135133
Approved by: https://github.com/ezyang
ghstack dependencies: #135039, #135121, #135129, #135130
@akihironitta
Contributor Author

> We have TORCH_LOGS that supersedes everything (even explain).

@anijain2305 How would you suggest we test the number of graph breaks in a model? We are currently using torch._dynamo.explain and its output attribute graph_break_count, but I wonder if there's another recommended way.
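
For context, the check we rely on today looks roughly like this (a sketch based on the minified repro above):

import torch

model = SomeModel()  # from the minified repro
explanation = torch._dynamo.explain(model)(torch.randn(1, 1))

# Fail the test if dynamo introduced any graph breaks.
assert explanation.graph_break_count == 0, explanation.break_reasons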

facebook-github-bot pushed a commit to pytorch/benchmark that referenced this issue Sep 6, 2024
Summary:
Fixes confusion in pytorch/pytorch#113443

We have TORCH_LOGS that supersedes CompileProfiler

X-link: pytorch/pytorch#135133
Approved by: https://github.com/ezyang
ghstack dependencies: #135039, #135121, #135129, #135130

Reviewed By: kit1980

Differential Revision: D62277271

Pulled By: anijain2305

fbshipit-source-id: 95a1c844252a83a17b942d46664034af10edc4c2
tolleybot pushed a commit to tolleybot/pytorch that referenced this issue Sep 14, 2024
Fixes confusion in pytorch#113443

We have TORCH_LOGS that supersedes CompileProfiler

Pull Request resolved: pytorch#135133
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#135039, pytorch#135121, pytorch#135129, pytorch#135130
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this issue Sep 20, 2024
Fixes confusion in pytorch#113443

We have TORCH_LOGS that supersedes CompileProfiler

Pull Request resolved: pytorch#135133
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#135039, pytorch#135121, pytorch#135129, pytorch#135130