Fast path detach()/alias() in FakeTensor #128281

Closed
ezyang opened this issue Jun 8, 2024 · 3 comments
Labels
actionable, high priority, module: dynamic shapes, module: fakeTensor, module: pt2-dispatcher, oncall: pt2, triaged

Comments


ezyang commented Jun 8, 2024

🐛 Describe the bug

We call detach()/alias() for a variety of administrative purposes, typically because we need to get a copy of the metadata of a tensor that won't be modified by subsequent metadata mutation. This is currently implemented quite inefficiently:

  File "/data/users/ezyang/b/pytorch/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 135, in <lambda>
    torch.Tensor, lambda t: t.detach(), updated_flat_args_subclasses_desugared
  File "/data/users/ezyang/b/pytorch/torch/utils/_stats.py", line 20, in wrapper           
    return fn(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_subclasses/fake_tensor.py", line 1060, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)      
  File "/data/users/ezyang/b/pytorch/torch/_subclasses/fake_tensor.py", line 1449, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_subclasses/fake_tensor.py", line 1144, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_subclasses/fake_tensor.py", line 1756, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_ops.py", line 666, in __call__
    return self_._op(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_prims_common/wrappers.py", line 265, in _fn
    result = fn(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_decomp/decompositions.py", line 2109, in nop_decomposition
    return aten.alias(x)
  File "/data/users/ezyang/b/pytorch/torch/_ops.py", line 1060, in __call__
    return self_._op(*args, **(kwargs or {}))
  File "/data/users/ezyang/b/pytorch/torch/_meta_registrations.py", line 3658, in meta_alias
    return self.view(self.shape)

We should have a fast path for this that bypasses the view() call entirely.
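For a sense of what that could look like, here is a minimal sketch; the helper name `_fast_detach`, the use of `no_dispatch()`, and rebuilding the result from metadata are all illustrative assumptions, not the actual implementation:

```
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._mode_utils import no_dispatch

# Hypothetical fast path: build the detached FakeTensor directly from
# the input's metadata instead of dispatching detach -> alias -> view
# through FakeTensorMode.
def _fast_detach(fake_mode, t):
    with no_dispatch():
        # empty_strided on the meta device allocates no data; it only
        # reproduces the metadata (sizes/strides/dtype) detach() needs.
        meta_t = torch.empty_strided(
            t.size(), t.stride(), dtype=t.dtype, device="meta"
        )
    # NB: a real implementation would also need to preserve storage
    # aliasing and autograd metadata; this sketch ignores both.
    return FakeTensor(fake_mode, meta_t, t.fake_device)
```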

High priority for compile time improvements.

Versions

main

cc @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @eellison

@zou3519 zou3519 added module: fakeTensor module: pt2-dispatcher PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op, labels Jun 10, 2024
@soulitzer soulitzer added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels Jun 11, 2024

zou3519 commented Jun 13, 2024

(screenshot: compile-time profile)

seems like this takes a long time


zou3519 commented Jul 8, 2024

Actionable: attempt the following approaches:

  1. The first step would be to stop going through so many hoops of dispatch (detach -> alias -> view) and see if that improves compilation time.
  2. The second step is to see if we can implement detach() by calling shallow_copy_and_detach directly on the FakeTensor, instead of going through the detach -> alias -> view dispatch (see the sketch below).
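As a rough Python-level illustration of step 2 (a sketch only: `shallow_copy_and_detach` itself lives on the C++ TensorImpl, and `no_dispatch()` stands in for however the dispatch bypass would actually be done):

```
import torch
from torch.utils._mode_utils import no_dispatch

def detach_via_shallow_copy(t):
    # With __torch_dispatch__ interception disabled, t.detach() goes
    # straight to the C++ kernel, which performs shallow_copy_and_detach
    # on the TensorImpl -- the detach -> alias -> view decomposition
    # never runs.
    with no_dispatch():
        return t.detach()
```

Whether the result keeps its FakeTensor identity when dispatch is bypassed like this is exactly the kind of detail a real change has to get right.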


bdhirsh commented Jul 26, 2024

A few findings. I used `TORCH_COMPILE_CPROFILE=1 python benchmarks/dynamo/huggingface.py --performance --timing --explain --backend aot_eager --device cuda --training --float32 --only BertForMaskedLM` as my benchmarking repro.

(1) I tried fast-pathing detach to avoid the decomps (only step 1 above, not step 2) by adding a fast path for FakeTensor.detach() that temporarily turns off the Python dispatcher (roughly as sketched below), and did not see much of an overall speedup.
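Roughly, that experiment looks like the following; that `torch._C._DisablePythonDispatcher` is the right guard here is my assumption about the internal context manager used, not something stated above:

```
import torch

def detach_no_py_dispatcher(t):
    # With the Python dispatcher off, the detach -> alias -> view
    # decomposition registered there is never consulted, so detach()
    # dispatches straight through the C++ paths.
    with torch._C._DisablePythonDispatcher():
        return t.detach()
```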

(2) Looking at the svg, you can see that the majority (~2/3) of the calls to `TensorBase::detach()` are flowing through `snapshot_fake()`.

(3) I updated `snapshot_fake()` to directly call the same `fast_detach()` (with no decomps), and I see a much larger speedup.
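Schematically, the change looks like this (not the actual diff; the definition and signature of `fast_detach` here are placeholders):

```
from torch.utils._mode_utils import no_dispatch

# Placeholder for the decomposition-free helper mentioned above.
def fast_detach(t):
    with no_dispatch():
        return t.detach()

def snapshot_fake(val):
    # before: return val.detach(), which dispatched
    # detach -> alias -> view under FakeTensorMode
    return fast_detach(val)
```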

BertForMaskedLM is a bit weird because it has several graph breaks, so there are a few tiny graphs, and two large graphs. Looking at the largest graph, I see:

compile time before: 23.820s
compile time after: 19.959s

detach before:
(screenshot: detach profile before the change)

detach after:
(screenshot: detach profile after the change)

bdhirsh added a commit that referenced this issue Jul 26, 2024
Fixes #128281, see investigation at #128281 (comment).

benchmark:
```
python benchmarks/dynamo/huggingface.py --performance --timing --explain --backend aot_eager --device cuda --training --float32 --only BertForMaskedLM
```

time before:
```
TIMING: entire_frame_compile:30.85435 backend_compile:23.98599 total_wall_time:30.85435
```

time after:
```
TIMING: entire_frame_compile:24.35898 backend_compile:18.15235 total_wall_time:24.35898
```

[ghstack-poisoned]