Nested tensor subclass support #127431
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127431
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (6 Unrelated Failures) As of commit 4b0160f with merge base 78e40b2:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/_subclasses/meta_utils.py (outdated)

    sub = t.type.__tensor_unflatten__(
        transformed_tensors_dict, t.ctx, outer_size, outer_stride
    )
    todo = plain_meta_tensors
Can you just do it recursively? I don't think you'll stack overflow and I think it will be a lot easier to understand
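The suggestion above — replacing a worklist with direct recursion — can be sketched in miniature. This is a hypothetical illustration, not the actual `meta_utils.py` code: `Wrapper` stands in for a tensor subclass and plain values stand in for real tensors that get converted to "meta" equivalents.

```python
# Hypothetical sketch of a recursive conversion over nested wrappers.
# `Wrapper` mimics a tensor subclass; leaves mimic plain tensors.
from dataclasses import dataclass

@dataclass
class Wrapper:
    inner: dict  # attr name -> Wrapper or plain leaf value

def to_meta(t):
    """Recursively rebuild a nested wrapper, converting each leaf."""
    if isinstance(t, Wrapper):
        # Recurse into every inner tensor first, then "unflatten" back up.
        converted = {attr: to_meta(v) for attr, v in t.inner.items()}
        return Wrapper(converted)
    # Base case: a plain tensor becomes its meta equivalent.
    return ("meta", t)

nested = Wrapper({"a": Wrapper({"x": 1, "y": 2}), "b": 3})
result = to_meta(nested)
print(result.inner["a"].inner["x"])  # ('meta', 1)
```

The recursion depth is bounded by the nesting depth of subclasses, which in practice is small — hence the reviewer's point that stack overflow is not a concern.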
Any update here?
Done
When we have nested tensor subclasses, we need to recurse down to access the underlying real tensor, wrap it in a FakeTensor, and recursively build the nested tensor subclasses back up. I am not sure if I am passing around the SymbolicContext correctly? cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang [ghstack-poisoned]
torch/_subclasses/meta_utils.py (outdated)

    return inner_t

    attr_fqn = prefix + "." + attr if prefix != "" else attr
    attr_list = attr_fqn.split(".")
If all you're going to do to the `attr_fqn` is split it, why not just pass around a list?
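As a toy illustration of that suggestion (made-up helper names, not the real code): threading the attribute path through as a list of components avoids the join-then-split round-trip on a dotted string.

```python
# Two equivalent ways to track the attribute path while recursing.
def visit_with_fqn(attr_fqn: str, attr: str) -> list:
    # String version: build "outer.inner", then split it again downstream.
    fqn = attr_fqn + "." + attr if attr_fqn else attr
    return fqn.split(".")

def visit_with_path(attr_path: list, attr: str) -> list:
    # List version: just extend the path; no join/split round-trip.
    return attr_path + [attr]

assert visit_with_fqn("a.b", "c") == visit_with_path(["a", "b"], "c") == ["a", "b", "c"]
```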
    current_context = symbolic_context.inner_contexts[attr]

    current_source = AttrSource(source, attr)
    new_empty_tensor = _empty_create_subclass(
You don't have to fix it here, but there's a somewhat prevalent antipattern in this file of doing small recursions on helper functions, rather than calling all the way back to the very top level. I think it should be OK to recurse to the very top call function, and that makes things more general since you can handle composition of things with other things the small helpers don't help. Just calling attention to this.
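The structural point above can be sketched with mock types (none of these names are from the real file): if every helper recurses back into the single top-level entry point instead of into a sibling helper, arbitrary compositions of wrapper kinds are handled uniformly.

```python
# Sketch: one top-level entry point that is the only recursion target.
def convert(t):
    """Dispatch on type; every branch recurses back to this top level."""
    if isinstance(t, TwoTensor):
        return TwoTensor(convert(t.a), convert(t.b))      # recurse to the top
    if isinstance(t, ConstantMetadataTensor):
        return ConstantMetadataTensor(convert(t.elem))    # recurse to the top
    return ("converted", t)  # plain leaf

class TwoTensor:
    def __init__(self, a, b):
        self.a, self.b = a, b

class ConstantMetadataTensor:
    def __init__(self, elem):
        self.elem = elem

# A composition that a helper handling only one wrapper kind would miss:
mixed = TwoTensor(ConstantMetadataTensor(1), TwoTensor(2, 3))
out = convert(mixed)
assert out.a.elem == ("converted", 1)
```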
-    sub = t.type.__tensor_unflatten__(
-        transformed_tensors_dict, t.ctx, outer_size, outer_stride
+    sub = _empty_create_subclass(
+        t, outer_size, outer_stride, symbolic_context, callback, source
ACKing this part, I'll let Brian do the rest
@@ -1726,10 +1727,19 @@ def backward(ctx, *flat_args):
    # TODO: figure out how to refactor the backward properly
    # so I can use aot_dispatch_subclass_wrapper() here.
    if CompiledFunction.maybe_subclass_metadata is not None:
        tangents = all_args[tangents_start_idx:tangents_end_idx]
This seems like a kind of weird way of detecting wrong tangents, but I guess this is the best we can do?
    curr_start_idx = self.flat_tensor_start_idx
    for attr, creation_meta in self.attrs.items():
        if creation_meta is None:
            subclass = all_args[curr_start_idx]
nit: the variable name `subclass` here seems misleading, since it may or may not actually be a subclass (my understanding is that if `creation_meta` is None, this is guaranteed to be a plain tensor). Maybe `inner_tensor`?
@@ -171,47 +171,56 @@ class SubclassCreationMeta:
    flat_tensor_start_idx: int
    # The number of tensors that live in this subclass wrapper
    arg_count: int
After reading the code, some invariants that I think are worth explicitly mentioning in the comments: `arg_count` is inclusive of the arg_counts of any inner tensor subclasses. If I have a TwoTensor and both of its inner elements are TwoTensors, then the `arg_count` of the outermost subclass will be 4.
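That invariant can be modeled in a few lines. This is a toy stand-in (tuples tagged `"subclass"` instead of the real `SubclassCreationMeta`), just to make the counting rule concrete.

```python
# Toy model of the invariant: arg_count of a wrapper is the total number
# of *plain* tensors underneath it, inclusive of inner subclasses.
def arg_count(t) -> int:
    if isinstance(t, tuple) and t[0] == "subclass":
        # A subclass wrapper: sum over its inner tensors.
        return sum(arg_count(inner) for inner in t[1])
    return 1  # a plain tensor contributes one flat arg

plain = "plain"
two = ("subclass", [plain, plain])   # TwoTensor(plain, plain)
outer = ("subclass", [two, two])     # TwoTensor(TwoTensor, TwoTensor)
assert arg_count(outer) == 4         # matches the invariant in the comment
```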
    curr_start_idx += creation_meta.arg_count
    inner_tensors[attr] = subclass

    rebuilt = type(self.original_subclass).__tensor_unflatten__(
All of the indices in this reconstruction are definitely non-trivial. It would be great if we had some runtime debug-asserts we could run that would tell us if we messed up the indexing somewhere, so we get a less cryptic error if we get this wrong 🤔. I can't think of a great way to do this though, unless we do something like save all of the shapes of the inner tensors at trace time and assert that our reconstructed inner tensors have the same shape at runtime.
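One possible shape of that debug-assert idea, with entirely hypothetical helper names: record the inner-tensor shapes at trace time, then verify the reconstructed inner tensors against the recording at runtime so an indexing bug fails loudly instead of cryptically.

```python
# Hypothetical trace-time/runtime shape check for subclass reconstruction.
def record_inner_shapes(inner_shapes: dict) -> dict:
    # Trace time: save attr -> shape.
    return {attr: tuple(shape) for attr, shape in inner_shapes.items()}

def check_inner_shapes(expected: dict, inner_shapes: dict) -> None:
    # Runtime: compare each reconstructed shape against the recording.
    for attr, shape in inner_shapes.items():
        exp = expected.get(attr)
        assert exp == tuple(shape), (
            f"subclass reconstruction mismatch for {attr!r}: "
            f"expected shape {exp}, got {tuple(shape)}"
        )

expected = record_inner_shapes({"a": (2, 3), "b": (2, 3)})
check_inner_shapes(expected, {"a": (2, 3), "b": (2, 3)})  # passes
try:
    check_inner_shapes(expected, {"b": (3, 2)})
except AssertionError:
    pass  # a mis-indexed reconstruction is caught with a readable message
```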
    z = x.clone().detach().requires_grad_()
    z_compile = z.clone().detach().requires_grad_()

    out_eager = f(x_nested, y_nested, z)
hmm... more out of paranoia than anything else, I'm worried about more complicated sets of inputs. The inputs to this test are something like `Two(Two(plain, plain), Two(plain, plain)), plain, plain`.
Some more testing ideas:
(1) Add a fourth argument that is an unbalanced TwoTensor, e.g. `Two(plain, Two(plain, plain))`.
(2) Add different subclass types into the test: e.g. make one input `ConstantMetadataTensor(plain)`, and another a `TwoTensor(plain, ConstantMetadataTensor(plain))`.
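The input shapes the reviewer asks for can be sketched with plain-Python stand-ins (the real test would use torch's TwoTensor and a metadata-carrying subclass; these mocks only illustrate the nesting structures, not tensor semantics).

```python
# Mock subclasses just to show the requested nesting patterns.
class Two:
    def __init__(self, a, b):
        self.a, self.b = a, b

class ConstMeta:
    def __init__(self, elem):
        self.elem = elem

plain = object()  # stands in for a plain tensor

# Current test's input: balanced nesting.
balanced = Two(Two(plain, plain), Two(plain, plain))

# (1) Unbalanced nesting: one side plain, one side nested.
unbalanced = Two(plain, Two(plain, plain))

# (2) Mixed subclass types in one set of inputs.
mixed_a = ConstMeta(plain)
mixed_b = Two(plain, ConstMeta(plain))

assert isinstance(unbalanced.b, Two) and not isinstance(unbalanced.a, Two)
```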
left some more nits and more tests would be great, pre-emptively stamping!
Curious, are we landing this PR soon? It's helpful in addressing IMA issues when compiling DTensor(local=fp8). Super valuable work! Sharing my 2 cents on perf: for CPU overhead and GPU time, computing fp8 amax in eager is still faster than torch.compile (#129457).
When we have nested tensor subclasses, we need to recursively flatten/unflatten in Fake tensor creation and AOTAutograd. Most of the PR is a mechanical change that turns today's single-level flatten logic into a recursive one. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang Differential Revision: [D58533224](https://our.internmc.facebook.com/intern/diff/D58533224) [ghstack-poisoned]
This pull request was exported from Phabricator. Differential Revision: D58533224
Will try to land today :)
Pull Request resolved: #127431. When we have nested tensor subclasses, we need to recursively flatten/unflatten in Fake tensor creation and AOTAutograd. Most of the PR is a mechanical change that turns today's single-level flatten logic into a recursive one. cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @imported-using-ghimport Differential Revision: [D58533224](https://our.internmc.facebook.com/intern/diff/D58533224/) ghstack-source-id: 21cebdbb6f197c68dec314feadbb0bae7f8081ba
@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary: `unwrap_tensor_subclass` is incorporated in the export stack natively after pytorch/pytorch#127431, so we can remove this workaround now. Test Plan: python test/quantization/test_quant_api.py; python test/integration/test_integration.py
Stack from ghstack (oldest at bottom):
When we have nested tensor subclasses, we need to recursively flatten/unflatten in Fake tensor creation and AOTAutograd. Most of the PR is a mechanical change that turns today's single-level flatten logic into a recursive one.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang
Differential Revision: D58533224
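The recursive flatten/unflatten the description refers to can be sketched with mock types (this is an illustration of the idea, not the real AOTAutograd code): flatten collapses a nest of subclass wrappers into a flat list of plain tensors plus a spec describing the nesting, and unflatten rebuilds the same nesting from that list.

```python
# Minimal recursive flatten/unflatten over mock subclass wrappers.
class Wrapper:
    def __init__(self, *inner):
        self.inner = list(inner)
    def __eq__(self, other):
        return isinstance(other, Wrapper) and self.inner == other.inner

def flatten(t, out: list):
    """Append plain leaves to `out`; return a spec describing the nesting."""
    if isinstance(t, Wrapper):
        return ("wrapper", [flatten(x, out) for x in t.inner])
    out.append(t)
    return "leaf"

def unflatten(spec, leaves):
    """Consume `leaves` (an iterator) according to the recorded spec."""
    if spec == "leaf":
        return next(leaves)
    return Wrapper(*(unflatten(s, leaves) for s in spec[1]))

nested = Wrapper(Wrapper(1, 2), 3)
flat = []
spec = flatten(nested, flat)
assert flat == [1, 2, 3]
assert unflatten(spec, iter(flat)) == nested  # round-trips the nesting
```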