
InternalTorchDynamoError on KL Divergences #120497

Closed
White-Link opened this issue Feb 23, 2024 · 3 comments
Labels
dynamo-must-fix These bugs affect TorchDynamo reliability. · module: dynamo · oncall: pt2 · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


White-Link commented Feb 23, 2024

🐛 Describe the bug

PyTorch throws an InternalTorchDynamoError exception when compiling the computation of a KL divergence, regardless of the device the inputs are on. I am working with VAEs, so support for compiling KL divergences would be a nice addition to PyTorch. Thanks!

Error logs

trace.txt

  File "/home/user/.pyenv/versions/env/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 651, in call_function
    result = handler(tx, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/.pyenv/versions/env/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 999, in call_getitem
    return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/.pyenv/versions/env/lib/python3.11/site-packages/torch/_dynamo/variables/dicts.py", line 91, in call_method
    return self.getitem_const(args[0])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/.pyenv/versions/env/lib/python3.11/site-packages/torch/_dynamo/variables/dicts.py", line 71, in getitem_const
    return self.items[ConstDictVariable.get_key(arg)]
           ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.InternalTorchDynamoError: (<class 'torch.distributions.normal.Normal'>, <class 'torch.distributions.normal.Normal'>)

from user code:
   File "/home/user/test.py", line 9, in f
    return D.kl_divergence(d1, d2)
  File "/home/user/.pyenv/versions/env/lib/python3.11/site-packages/torch/distributions/kl.py", line 183, in kl_divergence
    fun = _KL_MEMOIZE[type(p), type(q)]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Minified repro

import torch
import torch.distributions as D


@torch.compile
def f():
    d1 = D.Normal(0, 1)
    d2 = D.Normal(2, 1)
    return D.kl_divergence(d1, d2)


print(f())
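
As a side note (not part of the original report): until this is fixed, one possible workaround is to compute the Gaussian KL in closed form with plain tensor ops, which Dynamo can trace since it never touches the class-keyed registry. The helper below is a sketch; normal_kl is a made-up name.

import torch

@torch.compile
def normal_kl(mu1, sigma1, mu2, sigma2):
    # KL(N(mu1, sigma1) || N(mu2, sigma2)) in closed form; avoids the
    # (type(p), type(q)) registry lookup that triggers the error above.
    return (
        torch.log(sigma2 / sigma1)
        + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)
        - 0.5
    )

# Same distributions as the repro; the analytic value is (0 - 2)**2 / 2 = 2.
print(normal_kl(*map(torch.tensor, (0.0, 1.0, 2.0, 1.0))))  # tensor(2.)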

Versions

env.txt

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng

ezyang added the module: dynamo and dynamo-must-fix labels Feb 24, 2024
Chillee added the triaged label Feb 27, 2024
Chillee (Contributor) commented Feb 27, 2024

@yanboliang I think this is something that probably can be fixed by modifying the skipfiles?

yanboliang (Contributor) commented Feb 27, 2024

> @yanboliang I think this is something that probably can be fixed by modifying the skipfiles?

No, this is because we don't support arbitrary user-defined classes as dict keys. The complicated part is how to generate DICT_KEYS match guards for arbitrary classes. But this particular case is not that hard, because the keys are torch classes, and we already added torch to the closure vars used for guard evaluation:

"torch": torch,

I can send a fix for this.
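
For context, here is a minimal sketch (not from the thread; the names _memo and kl are illustrative) of the dispatch pattern Dynamo has to trace here: torch/distributions/kl.py memoizes KL implementations in a dict keyed by (type(p), type(q)) tuples, so guarding the lookup means guarding on class-valued dict keys.

import torch.distributions as D

# Illustrative stand-in for the memo that kl_divergence consults
# via _KL_MEMOIZE[type(p), type(q)] in torch/distributions/kl.py.
_memo = {}

def _kl_normal_normal(p, q):
    # Closed-form KL between two Normals, matching the registered rule.
    var_ratio = (p.scale / q.scale) ** 2
    t1 = ((p.loc - q.loc) / q.scale) ** 2
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())

_memo[D.Normal, D.Normal] = _kl_normal_normal

def kl(p, q):
    # The dict lookup below is keyed by a (class, class) tuple; this is
    # the access Dynamo fails to guard in the traceback above.
    return _memo[type(p), type(q)](p, q)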

anijain2305 (Contributor) commented

Seems to work on main

V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code] TRACED GRAPH
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]  ===== __compiled_fn_2 =====
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]  /data/users/anijain/pytorch2/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]     def forward(self):
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]          # File: /data/users/anijain/pytorch2/torch/distributions/utils.py:51 in <listcomp>, code: v if is_tensor_like(v) else torch.tensor(v, **options) for v in values
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]         tensor: "f32[][]cpu" = torch.tensor(0, dtype = torch.float32)
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]         tensor_1: "f32[][]cpu" = torch.tensor(1, dtype = torch.float32)
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]          # File: /data/users/anijain/pytorch2/torch/distributions/utils.py:53 in broadcast_all, code: return torch.broadcast_tensors(*new_values)
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]         broadcast_tensors = torch.functional.broadcast_tensors(tensor, tensor_1);  tensor = tensor_1 = None
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]         getitem: "f32[][]cpu" = broadcast_tensors[0]
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]         getitem_1: "f32[][]cpu" = broadcast_tensors[1];  broadcast_tensors = None
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]          # File: /data/users/anijain/pytorch2/torch/distributions/utils.py:51 in <listcomp>, code: v if is_tensor_like(v) else torch.tensor(v, **options) for v in values
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]         tensor_2: "f32[][]cpu" = torch.tensor(2, dtype = torch.float32)
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]         tensor_3: "f32[][]cpu" = torch.tensor(1, dtype = torch.float32)
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]          # File: /data/users/anijain/pytorch2/torch/distributions/utils.py:53 in broadcast_all, code: return torch.broadcast_tensors(*new_values)
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]         broadcast_tensors_1 = torch.functional.broadcast_tensors(tensor_2, tensor_3);  tensor_2 = tensor_3 = None
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]         getitem_2: "f32[][]cpu" = broadcast_tensors_1[0]
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]         getitem_3: "f32[][]cpu" = broadcast_tensors_1[1];  broadcast_tensors_1 = None
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]         return (getitem_1, getitem, getitem_3, getitem_2)
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]
V0617 15:31:15.531000 140565508674688 torch/_dynamo/output_graph.py:1293] [0/0_1] [__graph_code]
