import torch
import fbgemm_gpu
import torch.nn.functional as F

torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(mode="default", fullgraph=False)
def forward(
    x: torch.Tensor,
    x_offsets: torch.Tensor,
):
    # dense_to_jagged returns a tensor whose first dim is data-dependent
    x = torch.ops.fbgemm.dense_to_jagged(x, [x_offsets])[0]
    x = F.layer_norm(x, normalized_shape=[x.shape[-1]], eps=1e-6)
    return x

for i in range(10):
    x = torch.rand(128, 200, 256).to("cuda")
    # random, sorted offsets with fixed endpoints 0 and 12000
    y = torch.randint(0, 12000, (129,))
    y = torch.sort(y)[0]
    y[0] = 0
    y[-1] = 12000
    y = y.to('cuda')
    print(forward(x=x, x_offsets=y))
Traceback (most recent call last):
File "/home/test.py", line 23, in <module>
print(forward(x=x, x_offsets=y))
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 403, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 979, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 820, in _convert_frame
result = inner_convert(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 411, in _convert_frame_assert
return _compile(
File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 70, in wrapper_function
return function(*args, **kwargs)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 701, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 568, in compile_inner
out_code = transform_code_object(code, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 173, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 515, in transform
tracer.run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2234, in run
super().run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 884, in run
while self.step():
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 799, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2424, in RETURN_VALUE
self._return(inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2409, in _return
self.output.compile_subgraph(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 1078, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 1295, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 1386, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 1367, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 1745, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py", line 1454, in compile_fx
return aot_autograd(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/backends/common.py", line 65, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py", line 958, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py", line 685, in create_aot_dispatcher_function
compiled_fn = compiler_fn(
File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 470, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 672, in aot_wrapper_synthetic_base
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 169, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py", line 1358, in fw_compiler_base
return inner_compile(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/debug.py", line 304, in inner
return fn(*args, **kwargs)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py", line 483, in compile_fx_inner
compiled_graph = fx_codegen_and_compile(
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py", line 779, in fx_codegen_and_compile
compiled_fn = graph.compile_to_fn()
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py", line 1695, in compile_to_fn
return self.compile_to_module().call
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py", line 1638, in compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py", line 1590, in codegen
self.scheduler = Scheduler(self.buffers)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/scheduler.py", line 1353, in __init__
self.fuse_nodes()
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/scheduler.py", line 1743, in fuse_nodes
self.fuse_nodes_once()
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/scheduler.py", line 1970, in fuse_nodes_once
for node1, node2 in self.get_possible_fusions():
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/scheduler.py", line 2027, in get_possible_fusions
check_all_pairs(node_grouping)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/scheduler.py", line 2014, in check_all_pairs
if self.can_fuse(node1, node2):
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/scheduler.py", line 2193, in can_fuse
return self.get_backend(device).can_fuse_vertical(node1, node2)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 39, in can_fuse_vertical
return self._triton_scheduling.can_fuse_vertical(node1, node2)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/triton.py", line 3249, in can_fuse
return self.can_fuse_horizontal(node2, node1)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/triton.py", line 3222, in can_fuse
if not all(
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/triton.py", line 3223, in <genexpr>
TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/triton.py", line 1538, in is_compatible
cls._split_iteration_ranges(groups, lengths)
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/triton.py", line 1501, in _split_iteration_ranges
and sv.size_hint(remaining[current_group]) == 1
File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/sizevars.py", line 423, in size_hint
return int(out)
File "/usr/local/lib/python3.10/dist-packages/sympy/core/expr.py", line 320, in __int__
raise TypeError("Cannot convert symbols to int")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: Cannot convert symbols to int
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Inserting print(x.shape) between x = torch.ops.fbgemm.dense_to_jagged(x, [x_offsets])[0] and x = F.layer_norm(x, normalized_shape=[x.shape[-1]], eps=1e-6) makes the code run without error.
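Presumably the print forces a graph break before the layer_norm, so the data-dependent shape never reaches the failing fusion pass. A minimal sketch of the same workaround with an explicit break (assuming torch._dynamo.graph_break() splits the compiled region at that point just like the print does):

@torch.compile(mode="default", fullgraph=False)
def forward(x: torch.Tensor, x_offsets: torch.Tensor):
    x = torch.ops.fbgemm.dense_to_jagged(x, [x_offsets])[0]
    # assumed workaround: break the graph so layer_norm is compiled separately
    torch._dynamo.graph_break()
    x = F.layer_norm(x, normalized_shape=[x.shape[-1]], eps=1e-6)
    return x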
This looks like a case where code in Inductor assumes a value is constant when it can actually be a dynamic symbol. The fix is to adjust that code so it can handle a dynamic value.
Thanks for your kind response! Can I ask what Inductor assumes is constant? Would it be something like x.shape?
Also, could you provide a guide or some example code for making something dynamic?
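As a general reference (this is a frontend hint and may not address the Inductor-internal assumption discussed above), torch._dynamo.mark_dynamic can be used to declare a tensor dimension dynamic before calling the compiled function, for example:

x = torch.rand(128, 200, 256).to("cuda")
# tell the compiler that dim 0 may change between calls, so it stays symbolic
torch._dynamo.mark_dynamic(x, 0)
print(forward(x=x, x_offsets=y))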
🐛 Describe the bug
Inserting print(x.shape) between x = torch.ops.fbgemm.dense_to_jagged(x, [x_offsets])[0] and x = F.layer_norm(x, normalized_shape=[x.shape[-1]], eps=1e-6) makes the code run without error.
Versions
Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] onnx==1.16.0
[pip3] optree==0.11.0
[pip3] pytorch-quantization==2.1.2
[pip3] pytorch-triton==3.0.0
[pip3] torch==2.4.0a0
[pip3] torch-tensorrt==2.4.0a0
[pip3] torchvision==0.19.0a0
cc @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire