
Torch.compile Error: RuntimeError: aten::_conj() Expected a value of type 'Tensor' for argument 'self' but instead found type 'complex'. #105290

Closed
lyndonlauder opened this issue Jul 16, 2023 · 11 comments
Assignees
bdhirsh
Labels
actionable, high priority, module: complex, module: functionalization, module: pt2-dispatcher, oncall: pt2, triaged

Comments

@lyndonlauder

lyndonlauder commented Jul 16, 2023

🐛 Describe the bug

Training code

manual_seed(args.seed)
torch.backends.cudnn.benchmark = True

with open(args.model_path+'/config.yaml') as f:
    config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
config.training.num_steps = args.num_steps

trainset = MSSDatasets(config, args.data_root)

train_loader = DataLoader(
    trainset, 
    batch_size=config.training.batch_size, 
    shuffle=True, 
    num_workers=args.num_workers, 
    pin_memory=args.pin_memory
)

model = TFC_TDF_net(config)
model = torch.compile(model)
model.train()

device_ids = args.device_ids
if type(device_ids)==int:
    device = torch.device(f'cuda:{device_ids}')
    model = model.to(device)
else:
    device = torch.device(f'cuda:{device_ids[0]}')
    model = nn.DataParallel(model, device_ids=device_ids).to(device)

optimizer = Adam(model.parameters(), lr=config.training.lr)

print('Train Loop')
scaler = GradScaler()    
for batch in tqdm(train_loader):   

    y = batch.to(device)
    x = y.sum(1)  # mixture   
    if config.training.target_instrument is not None:
        i = config.training.instruments.index(config.training.target_instrument)
        y = y[:,i]
    with torch.cuda.amp.autocast():        
        y_ = model(x)   
        loss = nn.MSELoss()(y_, y) 

    scaler.scale(loss).backward()
    if config.training.grad_clip:
        nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)  
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)


state_dict = model.state_dict() if type(device_ids)==int else model.module.state_dict()

torch.save(state_dict, args.model_path+'/ckpt')

if __name__ == "__main__":
    train()

Model code

> class STFT:
>     def __init__(self, config):
>         self.n_fft = config.n_fft
>         self.hop_length = config.hop_length
>         self.window = torch.hann_window(window_length=self.n_fft, periodic=True)        
>         self.dim_f = config.dim_f
>     
>     def __call__(self, x):
>         window = self.window.to(x.device)
>         batch_dims = x.shape[:-2]
>         c, t = x.shape[-2:]
>         x = x.reshape([-1, t])
>         x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True, return_complex=False)
>         x = x.permute([0,3,1,2])
>         x = x.reshape([*batch_dims,c,2,-1,x.shape[-1]]).reshape([*batch_dims,c*2,-1,x.shape[-1]])
>         return x[...,:self.dim_f,:]
> 
>     def inverse(self, x):
>         window = self.window.to(x.device)
>         batch_dims = x.shape[:-3]
>         c,f,t = x.shape[-3:]
>         n = self.n_fft//2+1
>         f_pad = torch.zeros([*batch_dims,c,n-f,t]).to(x.device)
>         x = torch.cat([x, f_pad], -2)
>         x = x.reshape([*batch_dims,c//2,2,n,t]).reshape([-1,2,n,t])
>         x = x.permute([0,2,3,1])
>         x = x[...,0] + x[...,1] * 1.j
>         x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
>         x = x.reshape([*batch_dims,2,-1])
>         return x
> 
>     
> def get_norm(norm_type):
>     def norm(c, norm_type):   
>         if norm_type=='BatchNorm':
>             return nn.BatchNorm2d(c)
>         elif norm_type=='InstanceNorm':
>             return nn.InstanceNorm2d(c, affine=True)
>         elif 'GroupNorm' in norm_type:
>             g = int(norm_type.replace('GroupNorm', ''))
>             return nn.GroupNorm(num_groups=g, num_channels=c)
>         else:
>             return nn.Identity()
>     return partial(norm, norm_type=norm_type)
> 
> 
> def get_act(act_type):
>     if act_type=='gelu':
>         return nn.GELU()
>     elif act_type=='relu':
>         return nn.ReLU()
>     elif act_type[:3]=='elu':
>         alpha = float(act_type.replace('elu', ''))
>         return nn.ELU(alpha)
>     else:
>         raise Exception
> 
>         
> class Upscale(nn.Module):
>     def __init__(self, in_c, out_c, scale, norm, act):
>         super().__init__()
>         self.conv = nn.Sequential(
>             norm(in_c),
>             act,  
>             nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
>         )
>                                   
>     def forward(self, x):
>         return self.conv(x)
> 
> 
> class Downscale(nn.Module):
>     def __init__(self, in_c, out_c, scale, norm, act):
>         super().__init__()
>         self.conv = nn.Sequential(
>             norm(in_c),
>             act,   
>             nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
>         )
>                                   
>     def forward(self, x):
>         return self.conv(x)
> 
> 
> class TFC_TDF(nn.Module):
>     def __init__(self, in_c, c, l, f, bn, norm, act):        
>         super().__init__()
> 
>         self.blocks = nn.ModuleList()
>         for i in range(l): 
>             block = nn.Module()
>             
>             block.tfc1 = nn.Sequential(
>                 norm(in_c),
>                 act,
>                 nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
>             )
>             block.tdf = nn.Sequential(
>                 norm(c),
>                 act,
>                 nn.Linear(f, f//bn, bias=False),
>                 norm(c),
>                 act,
>                 nn.Linear(f//bn, f, bias=False),
>             )
>             block.tfc2 = nn.Sequential(
>                 norm(c),
>                 act,
>                 nn.Conv2d(c, c, 3, 1, 1, bias=False),
>             )
>             block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
>             
>             self.blocks.append(block)
>             in_c = c
>               
>     def forward(self, x):
>         for block in self.blocks:
>             s = block.shortcut(x)
>             x = block.tfc1(x)
>             x = x + block.tdf(x)
>             x = block.tfc2(x)
>             x = x + s
>         return x
> 
> 
> class TFC_TDF_net(nn.Module):
>     def __init__(self, config):
>         super().__init__()
>         self.config = config
>         
>         norm = get_norm(norm_type=config.model.norm)
>         act = get_act(act_type=config.model.act)
>         
>         self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
>         self.num_subbands = config.model.num_subbands
>         
>         dim_c = self.num_subbands * config.audio.num_channels * 2         
>         n = config.model.num_scales
>         scale = config.model.scale
>         l = config.model.num_blocks_per_scale 
>         c = config.model.num_channels
>         g = config.model.growth
>         bn = config.model.bottleneck_factor               
>         f = config.audio.dim_f // self.num_subbands
>         
>         self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
>  
>         self.encoder_blocks = nn.ModuleList()
>         for i in range(n):
>             block = nn.Module()
>             block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
>             block.downscale = Downscale(c, c+g, scale, norm, act) 
>             f = f//scale[1]
>             c += g
>             self.encoder_blocks.append(block)                
>                
>         self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)
>         
>         self.decoder_blocks = nn.ModuleList()
>         for i in range(n):                
>             block = nn.Module()
>             block.upscale = Upscale(c, c-g, scale, norm, act)
>             f = f*scale[1]
>             c -= g  
>             block.tfc_tdf = TFC_TDF(2*c, c, l, f, bn, norm, act)
>             self.decoder_blocks.append(block) 
>               
>         self.final_conv = nn.Sequential(
>             nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
>             act,
>             nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
>         )
>         
>         self.stft = STFT(config.audio)
>     
>     def cac2cws(self, x):
>         k = self.num_subbands
>         b,c,f,t = x.shape
>         x = x.reshape(b,c,k,f//k,t)
>         x = x.reshape(b,c*k,f//k,t)
>         return x
>     
>     def cws2cac(self, x):
>         k = self.num_subbands
>         b,c,f,t = x.shape
>         x = x.reshape(b,c//k,k,f,t)
>         x = x.reshape(b,c//k,f*k,t)
>         return x
>     
>     def forward(self, x):
>         
>         x = self.stft(x)
>         
>         mix = x = self.cac2cws(x)
>         
>         first_conv_out = x = self.first_conv(x)
> 
>         x = x.transpose(-1,-2)
>         
>         encoder_outputs = []
>         for block in self.encoder_blocks:  
>             x = block.tfc_tdf(x) 
>             encoder_outputs.append(x)
>             x = block.downscale(x)              
>             
>         x = self.bottleneck_block(x)
>         
>         for block in self.decoder_blocks:            
>             x = block.upscale(x)
>             x = torch.cat([x, encoder_outputs.pop()], 1)
>             x = block.tfc_tdf(x) 
>             
>         x = x.transpose(-1,-2)
>         
>         x = x * first_conv_out  # reduce artifacts
>         
>         x = self.final_conv(torch.cat([mix, x], 1))
>         
>         x = self.cws2cac(x)
>         
>         if self.num_target_instruments > 1:
>             b,c,f,t = x.shape
>             x = x.reshape(b,self.num_target_instruments,-1,f,t)
>         
>         x = self.stft.inverse(x)
>         
>         return x
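
For reference, the complex spectrogram in STFT.inverse above is assembled by multiplying with the Python complex scalar 1.j, which is the same pattern the error message points at. One possible alternative that avoids the Python scalar (an untested sketch, not verified to fix compilation here) is torch.view_as_complex on the trailing real/imag dimension:

import torch

# Untested sketch: assumes x has shape [..., n, t, 2] after the permute in
# STFT.inverse, with real and imaginary parts stacked in the last dimension.
def to_complex(x):
    # builds the complex tensor without multiplying by the Python scalar 1j
    return torch.view_as_complex(x.contiguous())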

Error logs

  0% 0/1000000 [00:00<?, ?it/s][2023-07-16 14:24:31,474] torch._inductor.utils: [WARNING] DeviceCopy in input program
  0% 0/1000000 [01:07<?, ?it/s]
Traceback (most recent call last):
  File "/content/sdx23/my_submission/src/train.py", line 120, in <module>
    train()
  File "/content/sdx23/my_submission/src/train.py", line 91, in train
    out = model(x)
       ^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1531, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 183, in forward
    return self.module(*inputs[0], **module_kwargs[0])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1531, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1531, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/sdx23/my_submission/src/tfc_tdf_v3.py", line 196, in forward
    def forward(self, x):
  File "/content/sdx23/my_submission/src/tfc_tdf_v3.py", line 198, in <resume in forward>
    x = self.stft(x)
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 447, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 128, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 364, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 179, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 434, in _compile
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1002, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 419, in transform
    tracer.run()
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2068, in run
    super().run()
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 727, in run
    and self.step()
        ^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 687, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 441, in wrapper
    self.output.compile_subgraph(self, reason=reason)
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 815, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/usr/local/envs/mdx-net/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 915, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 179, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 971, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 967, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/__init__.py", line 1548, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1045, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 3750, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 179, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 3289, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2098, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2278, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2686, in aot_dispatch_autograd
    fx_g = aot_dispatch_autograd_graph(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 2663, in aot_dispatch_autograd_graph
    fx_g = create_functionalized_graph(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1399, in create_functionalized_graph
    fx_g = make_fx(helper, decomposition_table=aot_config.decompositions)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 809, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 468, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 817, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 684, in flatten_fn
    tree_out = root_fn(*tree_args)
               ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 485, in wrapped
    out = f(*tensors)
          ^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1388, in joint_helper
    return functionalized_f_helper(primals, tangents)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1341, in functionalized_f_helper
    f_outs = fn(*f_args)
             ^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1312, in inner_fn_with_anomaly
    return inner_fn(*args)
           ^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1295, in inner_fn
    backward_out = torch.autograd.grad(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/autograd/__init__.py", line 319, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 555, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 580, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 361, in proxy_call
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_ops.py", line 437, in __call__
    return self._op(*args, **kwargs or {})
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: aten::_conj() Expected a value of type 'Tensor' for argument 'self' but instead found type 'complex'.
Position: 0
Value: 1j
Declaration: aten::_conj(Tensor(a) self) -> Tensor(a)
Cast error details: Unable to cast 1j to Tensor


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Minified repro

No response

Versions

# packages in environment at /usr/local/envs/mdx-net:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
absl-py                   1.4.0                    pypi_0    pypi
antlr4-python3-runtime    4.9.3                    pypi_0    pypi
attrs                     23.1.0                   pypi_0    pypi
bzip2                     1.0.8                h7f98852_4    conda-forge
ca-certificates           2023.5.7             hbcca054_0    conda-forge
certifi                   2023.5.7                 pypi_0    pypi
cffi                      1.15.1                   pypi_0    pypi
charset-normalizer        3.2.0                    pypi_0    pypi
click                     8.1.5                    pypi_0    pypi
cloudpickle               2.2.1                    pypi_0    pypi
cmake                     3.26.4                   pypi_0    pypi
contextlib2               21.6.0                   pypi_0    pypi
cython                    0.29.36                  pypi_0    pypi
demucs                    4.0.0                    pypi_0    pypi
diffq                     0.2.4                    pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
dora-search               0.1.12                   pypi_0    pypi
einops                    0.6.1                    pypi_0    pypi
ffmpeg-python             0.2.0                    pypi_0    pypi
filelock                  3.12.2                   pypi_0    pypi
fsspec                    2023.4.0                 pypi_0    pypi
future                    0.18.3                   pypi_0    pypi
gitdb                     4.0.10                   pypi_0    pypi
gitpython                 3.1.32                   pypi_0    pypi
idna                      3.4                      pypi_0    pypi
jinja2                    3.1.2                    pypi_0    pypi
jsonschema                4.18.3                   pypi_0    pypi
jsonschema-specifications 2023.6.1                 pypi_0    pypi
julius                    0.2.7                    pypi_0    pypi
lameenc                   1.5.1                    pypi_0    pypi
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
libexpat                  2.5.0                hcb278e6_1    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.1.0               he5830b7_0    conda-forge
libgomp                   13.1.0               he5830b7_0    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libsqlite                 3.42.0               h2797004_0    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
lit                       16.0.6                   pypi_0    pypi
markupsafe                2.1.3                    pypi_0    pypi
mir-eval                  0.7                      pypi_0    pypi
ml-collections            0.1.1                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
musdb                     0.4.0                    pypi_0    pypi
museval                   0.4.1                    pypi_0    pypi
ncurses                   6.4                  hcb278e6_0    conda-forge
networkx                  3.1                      pypi_0    pypi
numpy                     1.25.1                   pypi_0    pypi
nvidia-cublas-cu11        11.10.3.66               pypi_0    pypi
nvidia-cuda-cupti-cu11    11.7.101                 pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.7.99                  pypi_0    pypi
nvidia-cuda-runtime-cu11  11.7.99                  pypi_0    pypi
nvidia-cudnn-cu11         8.5.0.96                 pypi_0    pypi
nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
nvidia-curand-cu11        10.2.10.91               pypi_0    pypi
nvidia-cusolver-cu11      11.4.0.1                 pypi_0    pypi
nvidia-cusparse-cu11      11.7.4.91                pypi_0    pypi
nvidia-nccl-cu11          2.14.3                   pypi_0    pypi
nvidia-nvtx-cu11          11.7.91                  pypi_0    pypi
omegaconf                 2.3.0                    pypi_0    pypi
openssl                   3.1.1                hd590300_1    conda-forge
openunmix                 1.2.1                    pypi_0    pypi
pandas                    2.0.3                    pypi_0    pypi
pathtools                 0.1.2                    pypi_0    pypi
pip                       23.2               pyhd8ed1ab_0    conda-forge
promise                   2.3                      pypi_0    pypi
protobuf                  3.20.3                   pypi_0    pypi
psutil                    5.9.5                    pypi_0    pypi
pyaml                     23.7.0                   pypi_0    pypi
pycparser                 2.21                     pypi_0    pypi
python                    3.11.4          hab00c5b_0_cpython    conda-forge
python-dateutil           2.8.2                    pypi_0    pypi
pytorch-triton            2.1.0+3c400e7818          pypi_0    pypi
pytz                      2023.3                   pypi_0    pypi
pyyaml                    6.0                      pypi_0    pypi
readline                  8.2                  h8228510_1    conda-forge
referencing               0.29.1                   pypi_0    pypi
requests                  2.31.0                   pypi_0    pypi
retrying                  1.3.4                    pypi_0    pypi
rpds-py                   0.8.10                   pypi_0    pypi
scipy                     1.11.1                   pypi_0    pypi
sentry-sdk                1.28.1                   pypi_0    pypi
setproctitle              1.3.2                    pypi_0    pypi
setuptools                68.0.0             pyhd8ed1ab_0    conda-forge
shortuuid                 1.0.11                   pypi_0    pypi
simplejson                3.19.1                   pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
smmap                     5.0.0                    pypi_0    pypi
soundfile                 0.12.1                   pypi_0    pypi
stempeg                   0.2.3                    pypi_0    pypi
submitit                  1.4.5                    pypi_0    pypi
sympy                     1.12                     pypi_0    pypi
tk                        8.6.12               h27826a3_0    conda-forge
torch                     2.1.0.dev20230716+cu118          pypi_0    pypi
torchaudio                2.0.2                    pypi_0    pypi
tqdm                      4.65.0                   pypi_0    pypi
treetable                 0.2.5                    pypi_0    pypi
triton                    2.0.0                    pypi_0    pypi
typing-extensions         4.7.1                    pypi_0    pypi
tzdata                    2023.3                   pypi_0    pypi
urllib3                   2.0.3                    pypi_0    pypi
wandb                     0.13.2                   pypi_0    pypi
wheel                     0.40.0             pyhd8ed1ab_1    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge

cc @ezyang @gchanan @zou3519 @kadeng @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @amjames @bdhirsh @msaroufim @anijain2305 @chauhang @wconstab

@indrajeetapache

Which version are you using?

  1. torch.compile may not be available in your version.
  2. Check the value being passed where the error occurs: aten::_conj() expects an argument of type Tensor.

Can you add print(type(value)) to check the type of the value you are passing in?

@ezyang added the module: complex label Jul 16, 2023
@williamwen42
Member

Do you get the same error when you try backend="eager" or backend="aot_eager" in the torch.compile call?
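
For reference, switching the backend only changes the torch.compile call in the training script, e.g.:

# Sketch: try the lighter backends first to see which compilation stage fails.
model = torch.compile(model, backend="eager")        # Dynamo only
# model = torch.compile(model, backend="aot_eager")  # Dynamo + AOTAutograd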

@lyndonlauder
Author

@williamwen42 No error with backend="eager", but with backend="aot_eager" I get this error:

File "/usr/local/envs/mdx-net/lib/python3.11/site-packages/torch/_ops.py", line 437, in call
return self._op(*args, **kwargs or {})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
RuntimeError: aten::_conj() Expected a value of type 'Tensor' for argument 'self' but instead found type 'complex'.
Position: 0
Value: 1j
Declaration: aten::_conj(Tensor(a) self) -> Tensor(a)
Cast error details: Unable to cast 1j to Tensor

@williamwen42 added the module: functionalization and triaged labels Jul 18, 2023
@williamwen42
Member

@bdhirsh any thoughts?

@msaroufim
Member

Just want to add another data point to this discussion: we haven't been prioritizing complex numbers very highly, because the argument was that people can rewrite their code to avoid complex numbers using Euler's identity.

So I tried to do that in #105665 (comment) for rotary embeddings and got a 10x slowdown. In theory the rewrite argument may be fine, but then we should be more prescriptive about how to do those rewrites. A good starting point is probably rotary positional embeddings.
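
For concreteness, such a rewrite looks roughly like the sketch below (illustrative only, not the code from #105665): keep a trailing real/imag dimension and implement the complex multiply with real ops.

import torch

# Illustrative sketch: a and b hold (real, imag) pairs in their last dimension,
# e.g. as produced by torch.view_as_real(complex_tensor).
def complex_mul_as_real(a, b):
    ar, ai = a[..., 0], a[..., 1]
    br, bi = b[..., 0], b[..., 1]
    real = ar * br - ai * bi
    imag = ar * bi + ai * br
    return torch.stack([real, imag], dim=-1)

a = torch.view_as_real(torch.randn(4, dtype=torch.complex64))
b = torch.view_as_real(torch.randn(4, dtype=torch.complex64))
out = complex_mul_as_real(a, b)  # matches torch.view_as_real of the complex product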

@lyndonlauder
Author

Hi @williamwen42 @bdhirsh, do you have any further suggestions for getting torch.compile to work with my code?

@williamwen42
Member

williamwen42 commented Aug 16, 2023

It seems to me that there's an op using complex numbers that we don't support yet. It's hard for me to investigate further without a runnable code example (can you provide code with a hardcoded config and random input?). You can also inspect the output with the environment variable TORCH_LOGS=aot_graphs set to see if you can find the problematic op.
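
A minimal way to enable that logging from the training script (assuming the environment variable is set before torch is imported):

import os
os.environ["TORCH_LOGS"] = "aot_graphs"  # must be set before importing torch

import torch  # the AOTAutograd forward/backward graphs are then printed when compiling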

@penguinwu added the module: pt2-dispatcher label Nov 29, 2023
@bdhirsh
Contributor

bdhirsh commented Dec 12, 2023

Hey @lyndonlauder, if this is still breaking for you, do you mind including a self-contained repro? That would make diagnosing it easier. The repro above isn't runnable (e.g. args doesn't exist).

@bdhirsh added the needs reproduction label Dec 12, 2023
@ezyang removed the needs reproduction label Mar 6, 2024
@ezyang
Contributor

ezyang commented Mar 6, 2024

This problem also affects AudioLM.

https://gist.github.com/ezyang/6914fb5dedff6598dc673288898a7498 has the graph going into AOTAutograd and a stack trace.

Full repro code: gist.github.com/ezyang/64c24c9fc5529f3afed4ee4266f6adc5, but it fails before you get to this error; you need to bypass a different error by disabling compile in vector_quantize_pytorch/vector_quantize_pytorch.py in EuclideanCodebook.forward.

See also https://docs.google.com/document/d/14uWNDXa10I_Dpq1KxYShJwYmjHT0xbmv38ZWYCN40NE/edit

@zou3519
Contributor

zou3519 commented Jun 20, 2024

Putting this back into the queue for a bit.

@bdhirsh bdhirsh self-assigned this Jul 22, 2024
@bdhirsh
Contributor

bdhirsh commented Jul 22, 2024

Very simple min repro:

import torch

@torch.compile(backend="aot_eager")
def f(x):
    return torch.mul(x, 1j)


x = torch.randn(4, dtype=torch.complex64, requires_grad=True)
out = f(x)

The problem is roughly:

(1) the python arg parser converts 1j to a tensor, and the autograd engine sees a call to aten.mul.Tensor(self, other)
(2) the derivative rule says to compute other.conj()
(3) during tracing, this results in aten._conj(other)
(4) however: other here is a scalar-tensor. Today, we convert scalar-tensors back into scalars when plumbing them into torch_dispatch. The input to FunctionalTensorMode.__torch_dispatch__ becomes a python scalar (1j in the repro above), and it fails when we attempt to redispatch to `aten._conj` on the python scalar.

Still thinking about what the right thing to do here is... somewhere, we want to arrange for the scalar to become a tensor again.
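
One possible user-side workaround (an untested sketch, not the fix being considered for core) is to pass the complex constant as an explicit tensor, so that autograd's conj() runs on an actual Tensor argument rather than on a wrapped number:

import torch

@torch.compile(backend="aot_eager")
def f(x):
    # untested sketch: an explicit tensor constant instead of the Python scalar 1j
    j = torch.tensor(1j, dtype=x.dtype, device=x.device)
    return torch.mul(x, j)

x = torch.randn(4, dtype=torch.complex64, requires_grad=True)
out = f(x)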

bdhirsh added a commit that referenced this issue Jul 23, 2024
…k to python"

#105290

The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)`
(2) python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor`, calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit __torch_dispatch__, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually:
(1) convert the scalar tensor back into a scalar
(2) call scalar.conj() directly
(3) convert the result back into a wrapped tensor

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway).

Notably, I did **not** add e.g. a new `aten._conj.Scalar` overload. This would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly. We would also need to muck with all `__torch_dispatch__` call sites to know to convert python scalars back into tensors directly.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
bdhirsh added a commit that referenced this issue Jul 23, 2024
#105290

The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)
(2) python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor`, calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit __torch_dispatch__, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually:
(1) convert the scalar tensor back into a scalar
(2) call scalar.conj() directly
(3) convert the result back into a wrapped tensor

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway).

Notable, I did **not** add e.g. a new `aten._conj.Scalar` overload. This would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly. we would also need to muck with all `__torch_dispatch__` call sites to know to convert python scalars back into tensors directly.






cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
bdhirsh added a commit that referenced this issue Jul 24, 2024
…k to python"

#105290

The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)
(2) python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor`, calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit __torch_dispatch__, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually:
(1) convert the scalar tensor back into a scalar
(2) call scalar.conj() directly
(3) convert the result back into a wrapped tensor

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway).

Notable, I did **not** add e.g. a new `aten._conj.Scalar` overload. This would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly. we would also need to muck with all `__torch_dispatch__` call sites to know to convert python scalars back into tensors directly.






cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
bdhirsh added a commit that referenced this issue Jul 24, 2024
#105290

The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)
(2) python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor`, calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit __torch_dispatch__, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually:
(1) convert the scalar tensor back into a scalar
(2) call scalar.conj() directly
(3) convert the result back into a wrapped tensor

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway).

Notable, I did **not** add e.g. a new `aten._conj.Scalar` overload. This would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly. we would also need to muck with all `__torch_dispatch__` call sites to know to convert python scalars back into tensors directly.






cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
bdhirsh added a commit that referenced this issue Jul 24, 2024
…k to python"

#105290

The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)
(2) python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor`, calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit __torch_dispatch__, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually:
(1) convert the scalar tensor back into a scalar
(2) call scalar.conj() directly
(3) convert the result back into a wrapped tensor

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway).

Notable, I did **not** add e.g. a new `aten._conj.Scalar` overload. This would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly. we would also need to muck with all `__torch_dispatch__` call sites to know to convert python scalars back into tensors directly.






cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
bdhirsh added a commit that referenced this issue Jul 24, 2024
#105290

The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)
(2) python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor`, calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit __torch_dispatch__, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually:
(1) convert the scalar tensor back into a scalar
(2) call scalar.conj() directly
(3) convert the result back into a wrapped tensor

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway).

Notable, I did **not** add e.g. a new `aten._conj.Scalar` overload. This would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly. we would also need to muck with all `__torch_dispatch__` call sites to know to convert python scalars back into tensors directly.






cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
bdhirsh added a commit that referenced this issue Jul 25, 2024
…k to python"

#105290

The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)
(2) python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor`, calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit __torch_dispatch__, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually:
(1) convert the scalar tensor back into a scalar
(2) call scalar.conj() directly
(3) convert the result back into a wrapped tensor

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway).

Notable, I did **not** add e.g. a new `aten._conj.Scalar` overload. This would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly. we would also need to muck with all `__torch_dispatch__` call sites to know to convert python scalars back into tensors directly.






cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
bdhirsh added a commit that referenced this issue Jul 25, 2024
#105290

The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)
(2) python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor`, calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit __torch_dispatch__, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually:
(1) convert the scalar tensor back into a scalar
(2) call scalar.conj() directly
(3) convert the result back into a wrapped tensor

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway).

Notable, I did **not** add e.g. a new `aten._conj.Scalar` overload. This would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly. we would also need to muck with all `__torch_dispatch__` call sites to know to convert python scalars back into tensors directly.






cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
bdhirsh added a commit that referenced this issue Jul 25, 2024
…k to python"

#105290

The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)
(2) python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor`, calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit __torch_dispatch__, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually:
(1) convert the scalar tensor back into a scalar
(2) call scalar.conj() directly
(3) convert the result back into a wrapped tensor

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway).

Notable, I did **not** add e.g. a new `aten._conj.Scalar` overload. This would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly. we would also need to muck with all `__torch_dispatch__` call sites to know to convert python scalars back into tensors directly.






cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
bdhirsh added a commit that referenced this issue Jul 25, 2024
#105290

The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)
(2) python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor`, calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit __torch_dispatch__, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually:
(1) convert the scalar tensor back into a scalar
(2) call scalar.conj() directly
(3) convert the result back into a wrapped tensor

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway).

Notable, I did **not** add e.g. a new `aten._conj.Scalar` overload. This would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly. we would also need to muck with all `__torch_dispatch__` call sites to know to convert python scalars back into tensors directly.






cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]

bdhirsh added a commit that referenced this issue Jul 25, 2024
bdhirsh added a commit that referenced this issue Jul 25, 2024
bdhirsh added a commit that referenced this issue Jul 26, 2024
bdhirsh added a commit that referenced this issue Jul 26, 2024
pytorchmergebot pushed a commit that referenced this issue Jul 26, 2024
Pull Request resolved: #131482
Approved by: https://github.com/zou3519, https://github.com/ezyang
ghstack dependencies: #131403
ezyang closed this as completed Aug 12, 2024