
[dynamo] 'torch._C.ScriptFunction' object has no attribute '__defaults__' #93698

Closed
msaroufim opened this issue Sep 27, 2022 · 4 comments
Labels
dynamo-must-fix These bugs affect TorchDynamo reliability. dynamo-triage-june2024 high priority module: dynamo module: guards oncall: pt2 triage review triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module


msaroufim (Member) commented Sep 27, 2022

Repro

Another lucidrains model
pip install imagen-pytorch

import torch
from imagen_pytorch import Unet, Imagen
import torchdynamo

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super-resolution ones)

imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

@torchdynamo.optimize('inductor')
def train():
    for i in (1, 2):
        loss = imagen(images, text_embeds = text_embeds, unet_number = i)
        loss.backward()

train()

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = imagen.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.)

images.shape # (3, 3, 256, 256)

Logs

(dynamo) ubuntu@ip-172-31-31-152:~/tests$ python imggen.py 
Downloading: 100%|█████████████████████| 605/605 [00:00<00:00, 484kB/s]
The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/
torchdynamo.symbolic_convert: [WARNING] Graph break: call_function in skip_files /opt/conda/envs/dynamo/lib/python3.8/functools.py from user code at   File "imggen.py", line 45, in train
    loss = imagen(images, text_embeds = text_embeds, unet_number = i)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2451, in forward
    cond_images = maybe(cast_uint8_images_to_float)(cond_images)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 44, in maybe
    @wraps(fn)

torchdynamo.symbolic_convert: [WARNING] Graph break: call_function BuiltinVariable(callable) [NestedUserFunctionVariable()] {} from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2457, in <graph break in forward>
    unet = default(unet, lambda: self.get_unet(unet_number))
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 67, in default
    return d() if callable(d) else d

torchdynamo.symbolic_convert: [WARNING] Graph break: call_function BuiltinVariable(delattr) [NNModuleVariable(), ConstantVariable(str)] {} from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2457, in <lambda>
    unet = default(unet, lambda: self.get_unet(unet_number))
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 1928, in get_unet
    delattr(self, 'unets')

torchdynamo.symbolic_convert: [WARNING] Graph break: non-const NNModule method to from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 1933, in <graph break in get_unet>
    unet.to(self.device if unet_index == index else 'cpu')

torchdynamo.symbolic_convert: [WARNING] Graph break: missing: BUILD_MAP_UNPACK_WITH_CALL from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2470, in <graph break in forward>
    check_shape(images, 'b c ...', c = self.channels)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/einops_exts/einops_exts.py", line 12, in check_shape
    return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/einops/einops.py", line 487, in rearrange
    return reduce(tensor, pattern, reduction='rearrange', **axes_lengths)

torchdynamo.symbolic_convert: [WARNING] Graph break: call_function UserDefinedObjectVariable(_lru_cache_wrapper) [UserDefinedObjectVariable(TransformRecipe), ShapeVariable()] {} from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/einops/einops.py", line 233, in _apply_recipe
    _reconstruct_from_shape(recipe, backend.shape(tensor))

torchinductor.graph: [WARNING] Creating implicit fallback for:
  target: aten.uniform_.default
  args[0]: TensorBox(StorageBox(
    Pointwise(
      'cuda',
      torch.float32,
      constant(0, torch.float32),
      ranges=[4],
      origins={zeros}
    )
  ))
  args[1]: 0.0
  args[2]: 0.999
torchinductor.ir: [WARNING] Using FallbackKernel: torch.ops.aten.uniform_.default
torchinductor.graph: [WARNING] Creating implicit fallback for:
  target: aten.randn_like.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s1, s2, s2], stride=[s1*s2**2, s2**2, s2, 1]))
  ))
  kwargs: {'dtype': torch.float32, 'layout': torch.strided, 'device': device(type='cuda', index=0), 'pin_memory': False}
torchinductor.ir: [WARNING] Using FallbackKernel: torch.ops.aten.randn_like.default
torchdynamo.symbolic_convert: [WARNING] Graph break: call_function UserDefinedObjectVariable(ScriptFunction) [TensorVariable()] {} from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2361, in <graph break in p_losses>
    x_noisy, log_snr = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 251, in q_sample
    log_snr = self.log_snr(t).type(dtype)

torchdynamo.symbolic_convert: [WARNING] Graph break: call_function UserDefinedObjectVariable(ScriptFunction) [TensorVariable()] {} from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 214, in <graph break in get_condition>
    return maybe(self.log_snr)(times)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 48, in inner
    return fn(x)

torchdynamo.variables.builtin: [WARNING] incorrect arg count <bound method BuiltinVariable.call_dict of BuiltinVariable(dict)> missing a required argument: 'arg'
torchdynamo.symbolic_convert: [WARNING] Graph break: call_function BuiltinVariable(dict) [] {'text_embeds': TensorVariable(), 'text_mask': TensorVariable(), 'cond_images': ConstantVariable(NoneType), 'lowres_noise_times': ConstantVariable(NoneType), 'lowres_cond_img': ConstantVariable(NoneType), 'cond_drop_prob': ConstantVariable(float)} from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2377, in <graph break in p_losses>
    unet_kwargs = dict(

torchdynamo.convert_frame: [ERROR] WON'T CONVERT <graph break in p_losses> /opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py line 2377 
due to: 
Traceback (most recent call last):
  File "/home/ubuntu/torchdynamo/torchdynamo/variables/nn_module.py", line 53, in unpack_var_sequence
    assert isinstance(
AssertionError: Unet

from user code:
   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2393, in <graph break in p_losses>
    if self_cond and random() < 0.5:

Set torchdynamo.config.verbose=True for more information
==========
torchdynamo.symbolic_convert: [WARNING] Graph break: non-const NNModule method _get_item_by_idx from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/torch/nn/modules/container.py", line 107, in __getitem__
    return self._get_item_by_idx(self._modules.values(), idx)

torchinductor.ir: [WARNING] Using FallbackKernel: torch.ops.aten.uniform_.default
torchdynamo.symbolic_convert: [WARNING] Graph break: missing: BUILD_MAP_UNPACK_WITH_CALL from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 1586, in <graph break in forward>
    text_tokens = self.attn_pool(text_tokens)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 444, in forward
    latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/einops/einops.py", line 537, in repeat
    return reduce(tensor, pattern, reduction='repeat', **axes_lengths)

torchdynamo.convert_frame: [ERROR] WON'T CONVERT forward /opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py line 363 
due to: 
Traceback (most recent call last):
  File "/home/ubuntu/torchdynamo/torchdynamo/symbolic_convert.py", line 1637, in LOAD_CLOSURE
    self.push(self.closure_cells[inst.argval])
KeyError: 'fn'

from user code:
   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 375, in forward
    q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

Set torchdynamo.config.verbose=True for more information
==========
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchdynamo.symbolic_convert: [WARNING] Graph break: call_function UserDefinedObjectVariable(Always) [TensorVariable()] {} from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 713, in <graph break in forward>
    h = h * self.gca(h)

torchdynamo.convert_frame: [ERROR] WON'T CONVERT <graph break in forward> /opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py line 702 
due to: 
Traceback (most recent call last):
  File "/home/ubuntu/torchdynamo/torchdynamo/symbolic_convert.py", line 1637, in LOAD_CLOSURE
    self.push(self.closure_cells[inst.argval])
KeyError: 'fn'

from user code:
   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 713, in <graph break in forward>
    h = h * self.gca(h)

Set torchdynamo.config.verbose=True for more information
==========
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchdynamo.convert_frame: [ERROR] WON'T CONVERT forward /opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py line 923 
due to: 
Traceback (most recent call last):
  File "/home/ubuntu/torchdynamo/torchdynamo/symbolic_convert.py", line 1637, in LOAD_CLOSURE
    self.push(self.closure_cells[inst.argval])
KeyError: 'fn'

from user code:
   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 925, in forward
    x, context = rearrange_many((x, context), 'b n ... -> b n (...)')

Set torchdynamo.config.verbose=True for more information
==========
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchdynamo.convert_frame: [ERROR] WON'T CONVERT forward /opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py line 750 
due to: 
Traceback (most recent call last):
  File "/home/ubuntu/torchdynamo/torchdynamo/symbolic_convert.py", line 1637, in LOAD_CLOSURE
    self.push(self.closure_cells[inst.argval])
KeyError: 'fn'

from user code:
   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 758, in forward
    q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

Set torchdynamo.config.verbose=True for more information
==========
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchdynamo.convert_frame: [ERROR] WON'T CONVERT <graph break in forward> /opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py line 497 
due to: 
Traceback (most recent call last):
  File "/home/ubuntu/torchdynamo/torchdynamo/symbolic_convert.py", line 1637, in LOAD_CLOSURE
    self.push(self.closure_cells[inst.argval])
KeyError: 'fn'

from user code:
   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 502, in <graph break in forward>
    nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)

Set torchdynamo.config.verbose=True for more information
==========
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.graph: [WARNING] Creating implicit fallback for:
  target: aten.pixel_shuffle.default
  args[0]: TensorBox(StorageBox(
    Pointwise(
      'cuda',
      torch.float32,
      silu(load(buf0, i3 + 8 * i2 + 64 * i1 + 32768 * i0) + load(arg1_1, i1)),
      ranges=[s2, 512, 8, 8],
      origins={silu}
    )
  ))
  args[1]: 2
torchinductor.ir: [WARNING] Using FallbackKernel: torch.ops.aten.pixel_shuffle.default
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.ir: [WARNING] Using FallbackKernel: torch.ops.aten.pixel_shuffle.default
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.ir: [WARNING] Using FallbackKernel: torch.ops.aten.pixel_shuffle.default
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchdynamo.convert_frame: [WARNING] torchdynamo hit config.cache_size_limit (64)
   function: 'forward' (/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py:640)
   reasons:  ['___check_obj_id(self, 140496746038752)']
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
torchdynamo.symbolic_convert: [WARNING] Graph break: call_function UserDefinedObjectVariable(ScriptFunction) [TensorVariable()] {} from user code at   File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 280, in predict_start_from_noise
    log_snr = self.log_snr(t)

torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchdynamo.convert_frame: [WARNING] torchdynamo hit config.cache_size_limit (64)
   function: 'forward' (/opt/conda/envs/dynamo/lib/python3.8/site-packages/torch/nn/modules/container.py:202)
   reasons:  ['___check_obj_id(self, 140496750308800)']
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchdynamo.convert_frame: [WARNING] torchdynamo hit config.cache_size_limit (64)
   function: '<graph break in _apply_recipe>' (/opt/conda/envs/dynamo/lib/python3.8/site-packages/einops/einops.py:233)
   reasons:  ['len(___stack0[0]) == 3']
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchinductor.compile_fx: [WARNING] skipping cudagraphs due to complex input striding
torchdynamo.convert_frame: [WARNING] torchdynamo hit config.cache_size_limit (64)
   function: 'reduce' (/opt/conda/envs/dynamo/lib/python3.8/site-packages/einops/einops.py:355)
   reasons:  ["set(axes_lengths.keys()) == {'c'}"]
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
torchdynamo.convert_frame: [WARNING] torchdynamo hit config.cache_size_limit (64)
   function: '_apply_recipe' (/opt/conda/envs/dynamo/lib/python3.8/site-packages/einops/einops.py:229)
   reasons:  ["tensor 'tensor' requires_grad mismatch. expected requires_grad=0"]
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
torchdynamo.convert_frame: [WARNING] torchdynamo hit config.cache_size_limit (64)
   function: 'rearrange' (/opt/conda/envs/dynamo/lib/python3.8/site-packages/einops/einops.py:428)
   reasons:  ["set(axes_lengths.keys()) == {'c'}"]
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
torchinductor.ir: [WARNING] Using FallbackKernel: torch.ops.aten.pixel_shuffle.default
torchdynamo.convert_frame: [WARNING] torchdynamo hit config.cache_size_limit (64)
   function: 'reshape' (/opt/conda/envs/dynamo/lib/python3.8/site-packages/einops/_backends.py:83)
   reasons:  ['len(shape) == 4']
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
Traceback (most recent call last):
  File "imggen.py", line 48, in <module>
    train()
  File "/home/ubuntu/torchdynamo/torchdynamo/eval_frame.py", line 166, in _fn
    return fn(*args, **kwargs)
  File "imggen.py", line 42, in train
    @torchdynamo.optimize('inductor')
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2435, in forward
    def forward(
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2451, in <graph break in forward>
    cond_images = maybe(cast_uint8_images_to_float)(cond_images)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2457, in <graph break in forward>
    unet = default(unet, lambda: self.get_unet(unet_number))
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2470, in <graph break in forward>
    check_shape(images, 'b c ...', c = self.channels)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2487, in <graph break in forward>
    text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2313, in p_losses
    def p_losses(
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2333, in <graph break in p_losses>
    noise = default(noise, lambda: torch.randn_like(x_start))
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2338, in <graph break in p_losses>
    lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2361, in <graph break in p_losses>
    x_noisy, log_snr = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2373, in <graph break in p_losses>
    noise_cond = noise_scheduler.get_condition(times)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2381, in <graph break in p_losses>
    lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 2407, in <graph break in p_losses>
    pred = unet.forward(
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 1475, in forward
    def forward(
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 1524, in <graph break in forward>
    time_hiddens = self.to_time_hiddens(time)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 1552, in <graph break in forward>
    text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 1552, in <graph break in forward>
    text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 1574, in <graph break in forward>
    text_mask = rearrange(text_mask, 'b n -> b n 1')
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 1651, in <graph break in forward>
    x = init_block(x, t, c)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 697, in forward
    def forward(self, x, time_emb = None, cond = None):
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 709, in <graph break in forward>
    h = self.cross_attn(h, context = cond) + h
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/einops_exts/torch.py", line 17, in forward
    x = self.fn(x, **kwargs)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/dynamo/lib/python3.8/site-packages/imagen_pytorch/imagen_pytorch.py", line 764, in forward
    k = torch.cat((nk, k), dim = -2)
RuntimeError: Output 0 of ReshapeAliasBackward0 is a view and its base or another view of its base has been modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

cc @ezyang @gchanan @zou3519 @kadeng @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @wconstab @aakhundov @soumith @ngimel

malfet transferred this issue from pytorch/torchdynamo Feb 1, 2023
albanD added the triaged label Feb 7, 2023
ydwu4 (Contributor) commented Nov 27, 2023

Appears to be a dynamo issue. Repro:

pip install imagen-pytorch
pip install einops_exts

and run the following:

import torch
from imagen_pytorch import Unet, Imagen

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super-resolution ones)

imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

def train():
    for i in (1, 2):
        loss = imagen(images, text_embeds = text_embeds, unet_number = i)
        loss.backward()

torch.compile(train, backend="eager")()

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = imagen.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.)

images.shape # (3, 3, 256, 256)

The error message:

/home/yidi/local/miniconda3/envs/repro/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/yidi/local/miniconda3/envs/repro/lib/python3.10/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR] Error while creating guard:
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR] Name: "L['noise_scheduler'].log_snr.__defaults__[0]"
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]     Source: local_nn_module
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]     Create Function: CONSTANT_MATCH
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]     Guard Types: None
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]     Code List: None
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]     Object Weakref: None
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]     Guarded Class Weakref: None
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR] Created at:
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 28, in wrap_bound_arg
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]     return VariableBuilder(tx, source=source)(val)
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 239, in __call__
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]     vt = self._wrap(value).clone(**self.options())
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 383, in _wrap
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]     return type_dispatch(self, value)
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 941, in wrap_literal
[2023-11-27 10:03:00,313] [9/0_1] torch._guards: [ERROR]     self.install_guards(GuardBuilder.CONSTANT_MATCH)
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/repro6.py", line 45, in <module>
    torch.compile(train, backend="eager")()
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/repro6.py", line 42, in train
    loss = imagen(images, text_embeds = text_embeds, unet_number = i)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "<@beartype(imagen_pytorch.imagen_pytorch.Imagen.forward) at 0x7f0cfafbb130>", line 12, in forward
  File "<@beartype(imagen_pytorch.imagen_pytorch.Imagen.forward) at 0x7f0cfafbb130>", line 63, in resume_in_forward
  File "/home/yidi/local/miniconda3/envs/repro/lib/python3.10/site-packages/imagen_pytorch/imagen_pytorch.py", line 2665, in forward
    unet = default(unet, lambda: self.get_unet(unet_number))
  File "/home/yidi/local/miniconda3/envs/repro/lib/python3.10/site-packages/imagen_pytorch/imagen_pytorch.py", line 2731, in resume_in_forward
    return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, min_snr_gamma = min_snr_gamma, random_crop_size = random_crop_size, **kwargs)
  File "<@beartype(imagen_pytorch.imagen_pytorch.Imagen.p_losses) at 0x7f0cfafe2290>", line 34, in p_losses
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 721, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 664, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 645, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 625, in compile_inner
    check_fn = CheckFunctionManager(
  File "/home/yidi/local/pytorch/torch/_dynamo/guards.py", line 1011, in __init__
    guard.create(builder)
  File "/home/yidi/local/pytorch/torch/_guards.py", line 246, in create
    return self.create_fn(builder, self)
  File "/home/yidi/local/pytorch/torch/_dynamo/guards.py", line 448, in CONSTANT_MATCH
    val = self.get(guard.name)
  File "/home/yidi/local/pytorch/torch/_dynamo/guards.py", line 258, in get
    return eval(name, self.scope, CLOSURE_VARS)
  File "<string>", line 1, in <module>
torch._dynamo.exc.InternalTorchDynamoError: 'torch._C.ScriptFunction' object has no attribute '__defaults__'


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

@anijain2305 anijain2305 added the dynamo-must-fix These bugs affect TorchDynamo reliability. label Jan 31, 2024
@anijain2305
Contributor

A simpler repro is:

import torch

@torch.jit.script
def fast_cos(x, c=None):
    if c is None:
        return torch.sin(x)
    return torch.cos(x)

class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fast_cos = fast_cos

    def forward(self, x):
        return self.fast_cos(x)

mod = Mod()
opt_mod = torch.compile(mod, backend="eager")
opt_mod(torch.randn(4))

@anijain2305 anijain2305 changed the title imagen inductor errors [dynamo] 'torch._C.ScriptFunction' object has no attribute '__defaults__' Jun 10, 2024
@anijain2305
Contributor

Dynamo needs to better handle things like _torchdynamo_inline, where we monkeypatch self.source of UserDefinedFunctionVariable to point at the unwrapped value. Instead of monkeypatching, we should delay the unwrapping until we actually inline (InliningInstructionTranslator). This ensures that we have the right sources for both the unwrapped and the wrapped object.
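To make the source mismatch concrete, here is a toy model in plain Python (no torch). `ScriptLike`, `Tracked`, and `check_defaults_guard` are hypothetical names, not Dynamo's classes; the sketch only assumes, as the traceback above suggests (`guards.py` calling `eval(name, self.scope, CLOSURE_VARS)`), that a guard re-evaluates a source string and reads `__defaults__` from the result. Pairing the unwrapped function with the wrapper's source reproduces the same kind of `AttributeError`; giving the unwrapped function a source of its own avoids it.

```python
from dataclasses import dataclass
from typing import Any


def fast_cos(x, c=None):
    # Plain Python function, so it carries __defaults__ == (None,)
    return x if c is None else -x


class ScriptLike:
    """Hypothetical stand-in for a scripted wrapper such as a ScriptFunction.

    It is callable and keeps the original function on `_torchdynamo_inline`,
    but, like the object in the error above, it has no `__defaults__`.
    """

    def __init__(self, fn):
        self._torchdynamo_inline = fn

    def __call__(self, *args, **kwargs):
        return self._torchdynamo_inline(*args, **kwargs)


class Mod:
    def __init__(self):
        self.fast_cos = ScriptLike(fast_cos)


@dataclass
class Tracked:
    """Hypothetical, simplified pairing of a traced object with its source."""
    value: Any   # the object that actually gets inlined
    source: str  # expression a guard re-evaluates to fetch that object


def check_defaults_guard(tracked, scope):
    # Mimics a constant-match guard: re-evaluate `<source>.__defaults__` in
    # `scope` and compare it with the defaults observed at trace time.
    observed = tracked.value.__defaults__
    return eval(f"{tracked.source}.__defaults__", {}, scope) == observed


scope = {"self": Mod()}

# Mismatch: the unwrapped function is paired with the wrapper's source.
buggy = Tracked(value=fast_cos, source="self.fast_cos")

# Separate source: the unwrapped function is reached through the wrapper.
fixed = Tracked(value=fast_cos, source="self.fast_cos._torchdynamo_inline")

try:
    check_defaults_guard(buggy, scope)
except AttributeError as exc:
    print("guard failed:", exc)

print("guard ok:", check_defaults_guard(fixed, scope))
```

The point of the model is that a source must always re-fetch exactly the object whose properties were recorded; once the wrapper and the unwrapped function share one source, some guard will read the wrong object.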

@anijain2305 anijain2305 added module: molly-guard Features which help prevent users from committing common mistakes module: guards and removed module: molly-guard Features which help prevent users from committing common mistakes labels Jun 10, 2024
@ezyang
Contributor

ezyang commented Jun 11, 2024

anijain2305 added a commit that referenced this issue Jun 20, 2024
…ources for jit scripted functions"



Fixes #93698

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
anijain2305 added a commit that referenced this issue Jun 20, 2024
… attributes of torch.jit.* and lru_cache_wrapper"


Earlier we were taking the vt for `obj` and then monkeypatching that `vt.source` to be `obj._torchdynamo_inline`. If one accesses `obj.attr_a`, this would cause problems because Dynamo would then look it up as `obj._torchdynamo_inline.attr_a`. This PR makes it more functional, so that we have different vts for `obj` and `obj._torchdynamo_inline`.

Fixes #93698

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
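The `obj.attr_a` problem described in this commit can be sketched with a similar toy model. All names here are hypothetical: `Tracked` stands in for a vt, `Wrapper` for a scripted (or `lru_cache`-wrapped) object exposing `_torchdynamo_inline`, and `resolve` for a guard re-evaluating a source. Mutating a single vt's source sends the `attr_a` lookup to the unwrapped function, while keeping separate vts resolves each attribute on the object that owns it.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Tracked:
    """Hypothetical, simplified stand-in for a Dynamo variable tracker (vt)."""
    value: Any
    source: str  # expression that re-fetches `value` from the guard scope

    def getattr(self, name):
        # A child's source is derived from the parent's source.
        return Tracked(getattr(self.value, name), f"{self.source}.{name}")


def inner(x):
    return x


class Wrapper:
    """Hypothetical wrapper: exposes `_torchdynamo_inline` plus its own attribute."""

    def __init__(self, fn):
        self._torchdynamo_inline = fn
        self.attr_a = 42


def resolve(tracked, scope):
    # What a guard would do: re-evaluate the source against the frame's scope.
    return eval(tracked.source, {}, scope)


scope = {"obj": Wrapper(inner)}

# Old behaviour: one vt whose source is monkeypatched to the unwrapped function.
old_vt = Tracked(scope["obj"], "obj")
old_vt.source = "obj._torchdynamo_inline"
old_attr = old_vt.getattr("attr_a")  # value 42, but source points at inner

# New behaviour: separate vts for the wrapper and the unwrapped function.
obj_vt = Tracked(scope["obj"], "obj")
inline_vt = obj_vt.getattr("_torchdynamo_inline")
new_attr = obj_vt.getattr("attr_a")

print(old_attr.source)  # obj._torchdynamo_inline.attr_a
print(new_attr.source)  # obj.attr_a
print(resolve(new_attr, scope))  # 42
```

Re-evaluating `old_attr.source` raises `AttributeError`, because plain functions have no `attr_a`; that is the same failure shape as the `__defaults__` error in this issue.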
facebook-github-bot pushed a commit to pytorch/benchmark that referenced this issue Jun 24, 2024
…ache_wrapper (#128336)

Summary:
Earlier we were taking the vt for `obj` and then monkeypatching that `vt.source` to be `obj._torchdynamo_inline`. If one accesses `obj.attr_a`, this would cause problems because Dynamo would then look it up as `obj._torchdynamo_inline.attr_a`. This PR makes it more functional, so that we have different vts for `obj` and `obj._torchdynamo_inline`.

Fixes pytorch/pytorch#93698

X-link: pytorch/pytorch#128336
Approved by: https://github.com/jansel, https://github.com/yanboliang
ghstack dependencies: #129117

Reviewed By: fbgheith

Differential Revision: D58918358

Pulled By: anijain2305

fbshipit-source-id: 1143f7066132ea63c84faca01c505ef33ab87578