-
Notifications
You must be signed in to change notification settings - Fork 21.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use return_and_correct_aliasing() for NJT + compatible storage setting #126552
base: gh/jbschlosser/144/base
Are you sure you want to change the base?
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126552
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 8 New Failures, 1 Cancelled Job, 21 Unrelated FailuresAs of commit db474e3 with merge base f86dbae (): NEW FAILURES - The following jobs have failed:
CANCELLED JOB - The following job was cancelled. Please retry:
UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
torch/csrc/Module.cpp
Outdated
@@ -2105,6 +2106,13 @@ Call this whenever a new thread is created in order to propagate values from | |||
"_set_conj", [](const at::Tensor& x, bool conj) { x._set_conj(conj); }); | |||
py_module.def( | |||
"_set_neg", [](const at::Tensor& x, bool neg) { x._set_neg(neg); }); | |||
py_module.def( | |||
"_set_storage", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(trying to bikeshed on the name.... _unsafe_set_storage
/_set_only_storage
?)
@albanD mostly wondering if you are unhappy with a dedicated API for this. Basically, x.set_(storage, size, stride)
: does 3 things right now:
(1) set the storage
(2) call `set_sizes_strides()
(3) potentially resize the storage (if the number of bytes required for the given size/stride is larger than the nbytes of the passed-in storage
And return_and_correct_aliasing
(used by subclasses) only ever calls a.set_(b.storage(), a.size(), a.stride())
, where we are guaranteed that the resize and the set_sizes_strides_
are not necessary.
Those extra two steps are (1) a pain for nested tensor (trying to compute nbytes), and (2) unnecessary work we are doing throughout tracing (since this is used by FunctionalTensor
as well)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, your suggested names are much better - do we not already have something like this? Just looked and there's _unsafe_reset_storage()
that seems to be a functionalization thing.
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch._C._set_storage()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch._C._set_storage()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch._C._set_storage()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch._C._set_storage()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch._C._set_storage()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch._C._set_storage()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
@bdhirsh this PR breaks some NJT stuff higher up in the stack. Hacking in the storage changes the reported I think the difficulty here comes from the fact that wrapper subclasses have null storages with the desired device hacked in: pytorch/torch/csrc/autograd/python_variable.cpp Lines 838 to 846 in 980f5ac
Combine that with this scenario during tracing:
Dense -> subclass views mix the two and cause problems. Not sure if this is easily addressable for NJT; might be a better approach to dodge the debug asserts without the use of I'm happy to break out the part that does storage setting without extra check logic to help #125977. |
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
…rage setting" Fixes #125503 Context: `return_and_correct_aliasing()` is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503). This PR: * Uses `return_and_correct_aliasing()` in NJT * Changes how storage setting is done in `return_and_correct_aliasing()` * Old way: use `set_.source_Storage_storage_offset()`, which has extra logic for storage resizing that we don't need * New way: `torch.ops.aten._unsafe_set_storage_()` that shoves in a storage without this extra logic. Notably, this avoids `computeStorageNbytes()` choking on nested ints in NJT's sizes / strides [ghstack-poisoned]
Stack from ghstack (oldest at bottom):
Fixes #125503
Context:
return_and_correct_aliasing()
is required for traceable wrapper subclasses so that aliasing relationships are correct. NJT has not been using this, but needs to for correct aliasing relationships, and to avoid tripping asserts when DEBUG=1 (e.g. #125503).This PR:
return_and_correct_aliasing()
in NJTreturn_and_correct_aliasing()
set_.source_Storage_storage_offset()
, which has extra logic for storage resizing that we don't needtorch.ops.aten._unsafe_set_storage_()
that shoves in a storage without this extra logic. Notably, this avoidscomputeStorageNbytes()
choking on nested ints in NJT's sizes / strides