
Nested tensor subclass support #127431

Closed

Conversation

@tugsbayasgalan (Contributor) commented May 29, 2024

Stack from ghstack (oldest at bottom):

When we have nested tensor subclasses, we need to recursively flatten/unflatten during fake tensor creation and in AOTAutograd. Most of the PR is a mechanical change that turns today's single-level flatten logic into a recursive one.
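To make the shape of that change concrete, here is a minimal sketch (not the PR's literal implementation) of the recursive pattern over the traceable-wrapper-subclass protocol; `convert_leaf` is a hypothetical stand-in for whatever produces the fake leaf tensor:

```python
# Hedged sketch of the recursive flatten/unflatten idea; `convert_leaf` is a
# hypothetical callback (e.g. FakeTensor creation), not an API from this PR.
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def transform_subclass_recursive(t: torch.Tensor, convert_leaf):
    # Base case: a plain tensor is handed straight to the leaf converter.
    if not is_traceable_wrapper_subclass(t):
        return convert_leaf(t)
    # A wrapper subclass exposes its inner tensors via __tensor_flatten__ ...
    attrs, ctx = t.__tensor_flatten__()
    transformed = {
        attr: transform_subclass_recursive(getattr(t, attr), convert_leaf)
        for attr in attrs
    }
    # ... and is rebuilt via __tensor_unflatten__, so nested subclasses are
    # handled by the same recursion at every level.
    return type(t).__tensor_unflatten__(transformed, ctx, t.size(), t.stride())
```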

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang

Differential Revision: D58533224

pytorch-bot bot commented May 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127431

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (6 Unrelated Failures)

As of commit 4b0160f with merge base 78e40b2:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

sub = t.type.__tensor_unflatten__(
    transformed_tensors_dict, t.ctx, outer_size, outer_stride
)
todo = plain_meta_tensors
Contributor

Can you just do it recursively? I don't think you'll stack overflow and I think it will be a lot easier to understand

Contributor

Any update here?

Contributor Author

Done

When we have nested tensor subclasses, we need to recurse down to access the underlying real tensor, wrap it in a FakeTensor, and then recursively build the nested tensor subclasses back up. I am not sure if I am passing the SymbolicContext around correctly.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
tugsbayasgalan added a commit that referenced this pull request May 29, 2024
ghstack-source-id: 7622d01854522c05294159ae7acf0f156ce070f3
Pull Request resolved: #127431
tugsbayasgalan added a commit that referenced this pull request Jun 6, 2024
ghstack-source-id: 932b944dba4b95fea3f414ecac0611df451c18be
Pull Request resolved: #127431
return inner_t

attr_fqn = prefix + "." + attr if prefix != "" else attr
attr_list = attr_fqn.split(".")
Contributor

If all you're going to do to the attr_fqn is split it, why not just pass around a list
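A tiny sketch of what passing a list could look like (all names here are illustrative, not the PR's actual helpers):

```python
# Hypothetical illustration of the suggestion: thread the attribute path as a
# list instead of a dotted FQN string, avoiding the join/split round-trip.
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def visit_inner_tensors(t, attr_path, on_leaf):
    if not is_traceable_wrapper_subclass(t):
        # Consumers that really need an FQN can join once, at the leaf.
        on_leaf(t, attr_path)
        return
    attrs, _ctx = t.__tensor_flatten__()
    for attr in attrs:
        # Extend the path structurally; no string parsing needed later.
        visit_inner_tensors(getattr(t, attr), attr_path + [attr], on_leaf)
```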

current_context = symbolic_context.inner_contexts[attr]

current_source = AttrSource(source, attr)
new_empty_tensor = _empty_create_subclass(
Contributor

You don't have to fix it here, but there's a somewhat prevalent antipattern in this file of doing small recursions inside helper functions, rather than calling all the way back to the very top level. I think it should be OK to recurse to the very top-level call function, and that makes things more general, since you can handle compositions of things that the small helpers don't cover. Just calling attention to this.

- sub = t.type.__tensor_unflatten__(
-     transformed_tensors_dict, t.ctx, outer_size, outer_stride
+ sub = _empty_create_subclass(
+     t, outer_size, outer_stride, symbolic_context, callback, source
Contributor

ACKing this part, I'll let Brian do the rest

@tugsbayasgalan changed the title from "[DRAFT] nested tensor subclass support" to "Nested tensor subclass support" on Jun 12, 2024
@@ -1726,10 +1727,19 @@ def backward(ctx, *flat_args):
    # TODO: figure out how to refactor the backward properly
    # so I can use aot_dispatch_subclass_wrapper() here.
    if CompiledFunction.maybe_subclass_metadata is not None:
        tangents = all_args[tangents_start_idx:tangents_end_idx]
Contributor Author

This seems like a kind of weird way of detecting wrong tangents, but I guess this is the best we can do?

curr_start_idx = self.flat_tensor_start_idx
for attr, creation_meta in self.attrs.items():
    if creation_meta is None:
        subclass = all_args[curr_start_idx]
Contributor

nit: the variable name subclass here seems misleading, since it may or may not actually be a subclass (my understanding is that if creation_meta is None, this is guaranteed to be a plain tensor).

Maybe inner_tensor?

@@ -171,47 +171,56 @@ class SubclassCreationMeta:
    flat_tensor_start_idx: int
    # The number of tensors that live in this subclass wrapper
    arg_count: int
Contributor

after reading the code, some invariants that I think are worth explicitly mentioning in the comments:

  • arg_count is inclusive of the arg_counts of any inner tensor subclasses: if I have a TwoTensor and both of its inner elements are TwoTensors, then the arg_count of the outer-most subclass will be 4 (see the sketch below)
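A small sketch of that invariant (the helper name is made up, not from the PR):

```python
# Illustration of the invariant above: arg_count counts every plain tensor
# reachable from a wrapper, so nested subclasses contribute their own counts
# to the outer one. `count_flat_args` is a hypothetical helper.
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def count_flat_args(t) -> int:
    if not is_traceable_wrapper_subclass(t):
        return 1  # a plain tensor occupies exactly one flat-args slot
    attrs, _ctx = t.__tensor_flatten__()
    return sum(count_flat_args(getattr(t, attr)) for attr in attrs)


# e.g. TwoTensor(TwoTensor(a, b), TwoTensor(c, d)) -> 4 flat tensor slots
```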

        curr_start_idx += creation_meta.arg_count
    inner_tensors[attr] = subclass

rebuilt = type(self.original_subclass).__tensor_unflatten__(
Contributor

All of the indices in this reconstruction are definitely non-trivial. It would be great if we had some runtime debug-asserts we could run that would tell us if we messed up the indexing somewhere, so we get a less cryptic error if we get this wrong 🤔. I can't think of a great way to do this, though, unless we do something like save all of the shapes of the inner tensors at trace time and assert that our reconstructed inner tensors have the same shape at runtime.
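A rough sketch of what such a debug-assert could look like (the function and its call site are hypothetical, not part of this PR):

```python
# Hypothetical debug helper: compare reconstructed inner tensors against
# shapes recorded at trace time, so a bad flat-arg index fails loudly instead
# of surfacing as a cryptic downstream error.
def check_reconstructed_inner_tensors(inner_tensors, expected_shapes):
    for attr, tensor in inner_tensors.items():
        expected = tuple(expected_shapes[attr])
        actual = tuple(tensor.shape)
        assert actual == expected, (
            f"inner tensor {attr!r}: got shape {actual}, expected {expected}; "
            "subclass flat-arg indexing is likely off"
        )
```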

z = x.clone().detach().requires_grad_()
z_compile = z.clone().detach().requires_grad_()

out_eager = f(x_nested, y_nested, z)
Contributor

hmm... more out of paranoia than anything else, I'm worried about more complicated sets of inputs. The inputs to this test are something like Two(Two(plain, plain), Two(plain, plain)), plain, plain.

Some more testing ideas (a rough sketch follows below):
(1) Add a fourth argument that is an unbalanced TwoTensor, e.g. Two(plain, Two(plain, plain))
(2) Add different subclass types into the test: e.g. make one input ConstantMetadataTensor(plain), and another a TwoTensor(plain, ConstantMetadataTensor(plain)).
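A hedged sketch of the suggested inputs, using the TwoTensor test subclass (the import path is assumed from torch's internal test utilities; the second wrapper type is left as a placeholder):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor  # assumed test helper


def plain():
    return torch.randn(3, 4)


# (1) An "unbalanced" nesting: one inner element plain, the other nested.
unbalanced = TwoTensor(plain(), TwoTensor(plain(), plain()))

# (2) Mixing wrapper types would replace one leaf with a different subclass,
#     e.g. TwoTensor(plain(), SomeOtherWrapper(plain()))  # hypothetical type
```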

@bdhirsh (Contributor) left a comment

Left some more nits; more tests would be great. Pre-emptively stamping!

@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jun 25, 2024
@weifengpy (Contributor) commented Jun 25, 2024

Curious, are we landing this PR soon? It's helpful in addressing IMA issues when compiling DTensor(local=fp8). Super valuable work!

Sharing my two cents on perf: for CPU overhead and GPU time, computing fp8 amax in eager is still faster than torch.compile (#129457).

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D58533224

@tugsbayasgalan (Contributor Author)

> Curious, are we landing this PR soon? It's helpful in addressing IMA issues when compiling DTensor(local=fp8). Super valuable work!
>
> Sharing my two cents on perf: for CPU overhead and GPU time, computing fp8 amax in eager is still faster than torch.compile (#129457).

Will try to land today :)

tugsbayasgalan added a commit that referenced this pull request Jun 25, 2024
Pull Request resolved: #127431

@imported-using-ghimport

Differential Revision: [D58533224](https://our.internmc.facebook.com/intern/diff/D58533224/)
ghstack-source-id: 21cebdbb6f197c68dec314feadbb0bae7f8081ba
@tugsbayasgalan (Contributor Author)

@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

jerryzh168 added a commit to jerryzh168/ao that referenced this pull request Jul 1, 2024
Summary:
`unwrap_tensor_subclass` is incorporated into the export stack natively after pytorch/pytorch#127431, so we can remove this workaround now.

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

@github-actions bot deleted the gh/tugsbayasgalan/220/head branch on July 27, 2024 at 01:56
Labels: ciflow/inductor, ciflow/trunk, fb-exported, Merged, module: dynamo, release notes: fx