[Dynamo] torch.cuda.device context manager doesn't work #128059

Open
yanboliang opened this issue Jun 5, 2024 · 0 comments
Labels: dynamo-ctx-manager, dynamo-triage-june2024, module: dynamo, oncall: pt2, triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)


yanboliang commented Jun 5, 2024

🐛 Describe the bug

Repro

```python
import torch

@torch.compile(fullgraph=True)
def fn(x):
    with torch.cuda.device(x.device.index):
        x = x + 1

    return x

x = torch.randn(10, device="cuda")
print(fn(x))
```

Error

```
Traceback (most recent call last):
  File "/data/users/ybliang/debug/debug7.py", line 11, in <module>
    print(fn(x))
  File "/home/ybliang/local/pytorch/torch/_dynamo/eval_frame.py", line 421, in _fn
    return fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 1078, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 456, in _convert_frame_assert
    return _compile(
  File "/home/ybliang/local/pytorch/torch/_utils_internal.py", line 83, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/home/ybliang/local/pytorch/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 799, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/ybliang/local/pytorch/torch/_dynamo/utils.py", line 232, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 618, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/ybliang/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1184, in transform_code_object
    transformations(instructions, code_options)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 177, in _fn
    return fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 564, in transform
    tracer.run()
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2450, in run
    super().run()
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 892, in run
    while self.step():
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 804, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 498, in wrapper
    return inner_fn(self, inst)
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1458, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 742, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/ybliang/local/pytorch/torch/_dynamo/variables/user_defined.py", line 392, in call_function
    var.call_method(tx, "__init__", args, kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/variables/user_defined.py", line 641, in call_method
    return UserMethodVariable(method, self, source=source).call_function(
  File "/home/ybliang/local/pytorch/torch/_dynamo/variables/functions.py", line 342, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/variables/functions.py", line 91, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 748, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2665, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2709, in inline_call_
    result = InliningInstructionTranslator.check_inlineable(func)
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2686, in check_inlineable
    unimplemented(
  File "/home/ybliang/local/pytorch/torch/_dynamo/exc.py", line 220, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: 'inline in skipfiles: device.__init__ | __init__ /home/ybliang/local/pytorch/torch/cuda/__init__.py, skipped according trace_rules.lookup SKIP_DIRS'

from user code:
   File "/data/users/ybliang/debug/debug7.py", line 5, in fn
    with torch.cuda.device(x.device.index):

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
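Until Dynamo can trace this context manager, one possible workaround is to keep `torch.cuda.device` in ordinary eager Python and compile only the tensor computation. In that arrangement Dynamo never has to inline `device.__init__`. This is an untested sketch, not a confirmed fix. The names `add_one` and `fn` are illustrative. `backend="eager"` is used only so the sketch runs without a Triton or C++ toolchain; drop it to get the default Inductor backend. The CPU branch exists only so the sketch can run on a machine without a GPU.

```python
import torch

# Only this function is traced by Dynamo; it contains no device context manager.
# backend="eager" is an assumption for portability; remove it to use Inductor.
@torch.compile(fullgraph=True, backend="eager")
def add_one(x):
    return x + 1

def fn(x):
    # fn is not compiled, so torch.cuda.device runs as plain Python
    # and the compiled call happens inside the selected device context.
    if x.is_cuda:
        with torch.cuda.device(x.device.index):
            return add_one(x)
    # CPU fallback so the sketch also runs without CUDA.
    return add_one(x)
```

Because `fn` itself is never handed to `torch.compile`, `fullgraph=True` still applies to the compiled tensor math in `add_one`.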

Versions

N/A

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

yanboliang added the triaged, oncall: pt2, and module: dynamo labels on Jun 5, 2024.