
Wrap output with FakeTensor if input FakeTensor is not preserved #128206

Closed
wants to merge 2 commits

Conversation

jerrychenhf
Contributor

Fixes #128202

For cases where the input is a FakeTensor with a "meta" device, the output may not preserve its fakeness and comes out as a plain Tensor (see _convert_element_type_meta). In that case, the original device type is lost in the output.

This PR picks up the work from #104689 and applies the check of input and output fakeness originally shown in #119868.
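
Roughly, the intended check could be sketched like this (a sketch for discussion, not the actual diff in this PR; maybe_rewrap_output is a hypothetical helper name, while FakeTensor, FakeTensorMode, and fake_device are the existing names in torch._subclasses.fake_tensor):

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def maybe_rewrap_output(fake_mode: FakeTensorMode, inp, out):
    # Hypothetical helper: if the input was a FakeTensor but the op's output came
    # back as a plain (meta) Tensor, re-wrap it as a FakeTensor carrying the
    # input's original fake device so the device information is not lost downstream.
    if (
        isinstance(inp, FakeTensor)
        and isinstance(out, torch.Tensor)
        and not isinstance(out, FakeTensor)
    ):
        return FakeTensor(fake_mode, out, inp.fake_device)
    return out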


pytorch-bot bot commented Jun 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128206

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 6492290 with merge base 70724bd:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jerrychenhf
Contributor Author

jerrychenhf commented Jun 7, 2024

@ezyang @laithsakka
@ezyang @laithsakka
Since I lack context on the concrete concern behind "only wrap_output_with_input_device_ if dtype conversion changed a fake tensor to tensor", I have put up an initial solution here for further discussion. Please help review and suggest a direction.

I am also looking into how a test case can be added for the case shown here.

@jerrychenhf
Contributor Author

I did a deep dive into what is happening and why fakeness is not preserved during type conversion for this case:

import torch

def fn(a):
    b = a.t()
    b.mul_(1.0)
    return b

x = torch.arange(6).reshape([2, 3]).to('cpu')

print("x ", x.cpu())

compiled_fn = torch.compile(fn)
y = compiled_fn(x)

print("y ", y.cpu())

The key part is the mul_ call on the tensor (integer dtype) with the argument 1.0 (float):
b.mul_(1.0)

Here is what happens for this line during compilation:

  1. _dispatch_impl of FakeTensorMode is called to handle func: aten.mul.Tensor.
  2. In _dispatch_impl, the following code ends up handling the func call. It runs under the in_kernel_invocation_manager context, whereas I saw that other calls made under the self (FakeTensorMode) context preserve fakeness. (I assume the in_kernel_invocation_manager context will not preserve fakeness, am I right?)
        # run kernel registered to meta for func, which include
        # python meta registrations, prims, decomps, and c++ meta fns (structured kernels)
        try:
            with in_kernel_invocation_manager(self):
                r = func(*args, **kwargs)
        except NotImplementedError as not_implemented_error:
            return maybe_run_unsafe_fallback(not_implemented_error)

        return self.wrap_meta_outputs_with_default_device_logic(
            r, func, flat_args, device=kwargs.get("device")
        )
  3. Within the func(*args, **kwargs) call, the output wrapper for the func tries to do type promotion. Fakeness is not preserved under this context, and the code that follows (such as fn(**bound.arguments)) may lack the real device context because of this.
            promoted_args = {
                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            }
            # x's device type is lost in the type conversion
            bound.arguments.update(promoted_args)

            result = fn(**bound.arguments)
  4. wrap_meta_outputs_with_default_device_logic, at the end of _dispatch_impl, wraps the result back into a FakeTensor with the right device when necessary. But this happens after fn(**bound.arguments) in the wrapper shown in step 3.
  5. fn(**bound.arguments) may use the promoted arguments and check their device type. In my case, I saw it call into _make_elementwise_binary_reference in torch/_refs/__init__.py, which may use the device for checks (for example is_noncontiguous_supported) and thus make the wrong decision because the real device is not there.
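
To make step 2 concrete, here is a small standalone illustration (a sketch assuming the internal helpers FakeTensorMode and in_kernel_invocation_manager in torch._subclasses.fake_tensor keep their current names and behavior) of a FakeTensor reporting the meta device while the kernel-invocation context is active:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, in_kernel_invocation_manager

mode = FakeTensorMode()
with mode:
    t = torch.empty(2, 3, device="cpu")  # a FakeTensor whose fake device is cpu

print(type(t).__name__, t.device)  # FakeTensor cpu: the fake device is visible here

# Inside the kernel-invocation context the same tensor behaves like a meta tensor,
# so device-dependent checks made at this point see "meta" instead of cpu.
with in_kernel_invocation_manager(mode):
    print(t.device)  # meta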

@ezyang
Contributor

ezyang commented Jun 10, 2024

This change is probably not correct, but I need to study the bug report in more detail.

@jerrychenhf
Contributor Author

jerrychenhf commented Jun 10, 2024

Thanks @ezyang.

The key area causing the problem under the hood is the following piece of code in elementwise_type_promotion_wrapper in torch/_prims_common/wrappers.py:

            promoted_args = {
                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            }
            # x's device type is lost in the type conversion
            bound.arguments.update(promoted_args)

            result = fn(**bound.arguments)

_maybe_convert_to_dtype does the type conversion and finally calls _to_copy in the decompositions. Fakeness is not preserved under this context (in_kernel_invocation_manager), so the promoted argument is a plain tensor while the original argument was a FakeTensor. During fn(**bound.arguments), the promoted argument may be used; if a device check is done anywhere in that code path, there is a problem. For the CPU device, the above test case has no problem. For the HPU device, there is a device check in the is_noncontiguous_supported function of torch/_refs.

So the problem results from the combination of two things:

  1. The FakeTensor argument becomes a "meta" tensor during type promotion (and thus the real device is no longer wrapped into the tensor).
  2. Right after the type promotion, within fn(**bound.arguments), the device type is needed for some device-specific logic.

For the bug's test case running on CPU, only condition 1 happens and no problem is visible. When running on HPU, both condition 1 and condition 2 happen and thus cause a problem.
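
To illustrate how the two conditions combine, here is a simplified sketch (is_noncontiguous_supported_sketch is a made-up stand-in for the real device check in torch/_refs, which branches on the device type):

import torch

def is_noncontiguous_supported_sketch(device):
    # Made-up, simplified stand-in for the check in torch/_refs:
    # HPU cannot take noncontiguous outputs, other devices can.
    return device is None or device.type != "hpu"

real_device = torch.device("hpu")       # what the FakeTensor actually represents
promoted_device = torch.device("meta")  # what the promoted (no longer fake) argument reports

print(is_noncontiguous_supported_sketch(real_device))      # False: the HPU path is required
print(is_noncontiguous_supported_sketch(promoted_device))  # True: wrong decision once fakeness is lost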

As mentioned in the earlier "deep dive" comment, in fake_tensor.py the wrap_meta_outputs_with_default_device_logic function is called over the returned meta tensor at a later stage. After that, the result is wrapped into a FakeTensor with the real device again and there is no problem from then on.

@ezyang ezyang requested review from ezyang and laithsakka June 10, 2024 15:09
@ezyang
Contributor

ezyang commented Jun 10, 2024

@laithsakka did we ever end up talking about what we thought the "correct" way to fix this problem was? I read through all of the old issues/PRs but all it says is that this was "pending discussion" at #104689 (comment) and we never followed up.

Contributor

@ezyang ezyang left a comment


Pending discussion

@ezyang ezyang requested a review from eellison June 10, 2024 15:16
@bdhirsh bdhirsh added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Jun 10, 2024
@eellison
Contributor

eellison commented Jun 13, 2024

The example in your linked issue fails to run in eager:

import torch

def fn(a):
    b = a.t()
    b.mul_(1.0)
    return b

x = torch.arange(6).reshape([2, 3]).to('cpu')

print("x ", x.cpu())
fn(x)

Gives:

  File "/home/eellison/pytorch/work_dir/tmp10.py", line 13, in <module>
    fn(x)
  File "/home/eellison/pytorch/work_dir/tmp10.py", line 5, in fn
    b.mul_(1.0)
RuntimeError: result type Float can't be cast to the desired output type Long

Do you have another example of this which does not error? If so, please re-request review.

@jerrychenhf
Contributor Author

@eellison
The following case shows the same behavior and also runs in eager (just change x from a long tensor to a bfloat16 tensor):

import torch

def fn(a):
    b = a.t()
    b.mul_(1.0)
    return b

x = torch.arange(6).reshape([2, 3]).bfloat16()

print("x ", x.cpu())

compiled_fn = torch.compile(fn)
y = compiled_fn(x)

print("y ", y.cpu())

Please note that the problematic behavior mentioned in the issue shows up only in compile mode (it is closely related to FakeTensor behavior under compilation). Eager mode does not go down the problematic path.

@jerrychenhf jerrychenhf requested a review from ezyang June 14, 2024 03:31
@jerrychenhf
Contributor Author

jerrychenhf commented Jun 14, 2024

@eellison
You don't show up in the reviewer list and I don't have permission to request review from people who are not in the list.
Please review with the updated example.

@eellison
Contributor

eellison commented Jun 14, 2024

I'm still not sure what the error is:

import torch

def fn(a):
    b = a.t()
    b.mul_(1.0)
    return b

x = torch.arange(6).reshape([2, 3]).to('cpu')

print("x ", x.cpu())
out_c = torch.compile(fn)(x.clone())
out_eager = fn(x.clone())
# no error here
torch.testing.assert_close(out_c, out_eager)

assert x.untyped_storage().data_ptr() == out_c.untyped_storage().data_ptr()
with torch._subclasses.fake_utils.CrossRefFakeMode():
    # no error here
    fn(x)

with torch._subclasses.FakeTensorMode():
    x = torch.arange(6).reshape([2, 3]).to('cpu')
    b = x.t()  
    out = b.mul_(1.0)
    # just to verify... same storage
    assert out.untyped_storage()._cdata == x.untyped_storage()._cdata

@jerrychenhf
Contributor Author

jerrychenhf commented Jun 15, 2024

@eellison Thank you for checking this.

There is no error showing up when running on CPU (which is why I use the term "behavior"). When I run with the HPU device, the data of the result tensor is incorrect, which is caused by the combination of this behavior (same as on CPU) and another HPU-specific check that depends on it.

Since there is no error on CPU, the behavior on CPU is not directly observable without analyzing the internal code path. See my comments #128206 (comment) and #128206 (comment) for the code path analysis.

Here is a summary of the unexpected behavior:

  1. At some point a FakeTensor goes through a type conversion and comes back as a meta Tensor without the original device.
  2. Immediately after this, some code uses the type-converted tensor and checks its device type (but the real device type is lost).

This is triggered by exactly the following line of code (the tensor data types matter, of course) during compilation:
b.mul_(1.0)

Because b is a bfloat16 tensor and 1.0 is a float, an implicit type conversion happens when dispatching aten.mul.Tensor in FakeTensorMode. This triggers the unexpected behavior.
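
For reference, the promotion that forces the argument conversion can be inspected with the helper the reference implementations use (a sketch assuming elementwise_dtypes and ELEMENTWISE_TYPE_PROMOTION_KIND in torch._prims_common; the computation dtype is expected to be upcast to float32 for a bfloat16 input, which is what sends the argument through _maybe_convert_to_dtype / _to_copy):

import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes

b = torch.empty(2, 3, dtype=torch.bfloat16)
# Computation dtype and result dtype for `b * 1.0` under the default promotion kind.
compute_dtype, result_dtype = elementwise_dtypes(
    b, 1.0, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
print(compute_dtype, result_dtype)  # expected: torch.float32 torch.bfloat16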

To surface this for analysis, one simple thing to do is to print some tensor info in _to_copy, defined in torch/_decomp/decompositions.py, as follows (two prints, before and after torch._prims.convert_element_type):

@register_decomposition(aten._to_copy)
@out_wrapper()
def _to_copy(
    x: Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    layout=None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    non_blocking: bool = False,
    memory_format: Optional[torch.memory_format] = None,
):
    assert not layout or layout == torch.strided, "TODO"
    assert not pin_memory, "TODO"
    if device is None and dtype is None and memory_format is None:
        return x.clone()
    dtype_converted = False
    common_device = device_hint(x)

    if device is not None and device != x.device:
        # avoid conversions on cpu
        if dtype is not None and device.type == "cpu":
            x = torch._prims.convert_element_type(x, dtype)
            dtype_converted = True
        x = torch._prims.device_put(x, device)

    if dtype is not None and not dtype_converted:
        print("Tensor before convert_element_type: " x)
        x = torch._prims.convert_element_type(x, dtype)
        print("Tensor after convert_element_type: " x)
        dtype_converted = True

    if memory_format is not None:  # no ref/prim for memory format
        return torch.clone(x, memory_format=memory_format)
    return x

You will see the FakeTensor become a meta Tensor here. My comment #128206 (comment) has the details of how it gets here from the FakeTensorMode dispatch function.

The behavior doesn't show up when you do something explicitly under a FakeTensorMode context (as in your test code above; for that kind of usage you will see FakeTensor in and FakeTensor out, before and after the type conversion). It happens only in the very specific program pattern I showed above (compile mode, with a function whose arguments go through implicit type conversion).
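
For contrast, a quick check of the explicit FakeTensorMode usage described above (a sketch using torch._subclasses.FakeTensorMode, as in your snippet) shows the dtype conversion staying FakeTensor-to-FakeTensor:

import torch
from torch._subclasses import FakeTensorMode

with FakeTensorMode():
    t = torch.empty(2, 3, dtype=torch.bfloat16)
    u = t.to(torch.float32)
    # Both sides of the dtype conversion remain FakeTensors under the explicit mode.
    print(type(t).__name__, type(u).__name__)  # FakeTensor FakeTensor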

I hope this helps as a starting point for understanding the problem.

@jerrychenhf
Contributor Author

@eellison Any chance you could take another look at this?

@eellison
Contributor

@jerrychenhf

We are invoking _to_copy as a meta kernel under in_kernel_invocation_manager. It is expected that any intermediary tensor will be on the meta device.

The outputs of the meta kernel then get wrapped back to the appropriate device with wrap_meta_outputs_with_default_device_logic. If the default device logic does not apply, we have device-specific logic in fake_impls.

@jerrychenhf
Contributor Author

@eellison Thank you for checking this and for the response.

Yes, you are right that _to_copy is called under in_kernel_invocation_manager and the meta device tensor is wrapped back by wrap_meta_outputs_with_default_device_logic.

So the problematic state I am hitting is just after _to_copy and before wrap_meta_outputs_with_default_device_logic, which is the following piece of code in torch/_prims_common/wrappers.py:

            promoted_args = {
                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            }
            # x's device type is lost in the type conversion
            bound.arguments.update(promoted_args)

            result = fn(**bound.arguments)

_maybe_convert_to_dtype calls _to_copy to do the type conversion of the arguments, and fn(**bound.arguments) may then use arguments that lack the real device. wrap_meta_outputs_with_default_device_logic happens afterwards, so it doesn't help here.

Any thoughts on how to move forward?

@eellison
Contributor

Is it just pointwise ops that use is_noncontiguous_supported? Why don't you move the contiguous wrapping to FakeTensor, where we more reliably have the device?

@jerrychenhf
Contributor Author

jerrychenhf commented Jun 30, 2024

@eellison

Is it just pointwise ops that use is_noncontiguous_supported?

I checked with the original author: the is_noncontiguous_supported handling is not strictly limited to pointwise ops.
Could you give a hint on how this affects the implementation?

Why don't you move the contiguous wrapping to FakeTensor, where we more reliably have the device?

I am not sure whether there are equivalent places in FakeTensor to add these device checks and contiguous wrappings. I saw quite a few places handling this in torch/_refs/__init__.py.
The original noncontiguous handling for outputs was discussed and reviewed by Edward, so maybe we should ask him to check this point; he has the best knowledge and background to guide a design choice if we need to change this part.

@ezyang Would you please help check this part? I am not sure whether what @eellison suggested (moving the contiguous wrapping to FakeTensor) is feasible. There are quite a few layers of wrappers and intermediate logic between the FakeTensor dispatch function and the noncontiguous handling in _refs.

@ezyang
Contributor

ezyang commented Jul 3, 2024

So, certainly if we reimplement all the pointwise rules directly in fake tensor, we can ensure we never decay to meta (since we're always operating on fake tensors) and then you would be guaranteed to have accurate device info. It does sound a bit duplicatey. But it definitely will work.

@jerrychenhf
Contributor Author

@ezyang Thank you for checking on this.
I am not an expert on all these pointwise rules. Do you have a list of the rules to reimplement? And we would have to duplicate the code of the rule implementations, right (or maybe there is a way to reuse it)?
Also, could you give a hint on where in FakeTensor these rules should be implemented?

@eellison
Contributor

eellison commented Jul 8, 2024

@jerrychenhf you can check whether the operator has the Pointwise tag.
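
For example, a quick way to check this (assuming the tags attribute on ATen op overloads, which exposes torch.Tag values):

import torch

# aten.mul.Tensor (the op in the failing pattern) carries the pointwise tag.
print(torch.Tag.pointwise in torch.ops.aten.mul.Tensor.tags)  # True
# A non-pointwise op such as aten.mm does not.
print(torch.Tag.pointwise in torch.ops.aten.mm.default.tags)  # False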

@jerrychenhf
Contributor Author

Since PT 2.4 improves reconstructing a view on another base in AOT autograd (#121007), the noncontiguous-output handling for some specific devices is no longer needed (#104689).

So this change is no longer needed in PT 2.4 and later.

Labels
open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Successfully merging this pull request may close these issues.

Wrong result for Inplace tensor update on transpose for some devices with torch 2.3.0