
[WIP] Warn on future divergent behavior for conditional views #126129

Draft
wants to merge 25 commits into
base: gh/kurtamohler/30/base

Conversation

kurtamohler
Collaborator

@kurtamohler kurtamohler commented May 14, 2024


pytorch-bot bot commented May 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126129

Note: Links to docs will display an error until the docs builds have been completed.

❌ 87 New Failures, 1 Unrelated Failure

As of commit 2addcd3 with merge base f2552dc:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kurtamohler added a commit that referenced this pull request May 14, 2024
ghstack-source-id: 16e443f188e379ec6cb2ea6bc5be14343a1833a5
Pull Request resolved: #126129
…ews"


Part of #109833

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request May 14, 2024
ghstack-source-id: 46765f09395d7b370de30fdd61d8a36363ae072f
Pull Request resolved: #126129
…ews"


Part of #109833

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Jun 10, 2024
ghstack-source-id: d6678c5a8fcaf99b5c0a2acf979e9b82366ed141
Pull Request resolved: #126129
[ghstack-poisoned]
[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: 92ca4f38b33f8a13f6d3f5b990d5b9f0a370e2a9
Pull Request resolved: #126129
@kurtamohler
Collaborator Author

I'm getting an odd segfault that I'm having trouble understanding.

$ python test/functorch/test_ops.py -k test_vmapjvpall_has_batch_rule_max_pool2d_with_indices_backward_cuda_float32
Segmentation fault (core dumped)

From the backtrace at the end of this comment, this happens when lazy_clone_storage checks if the given storage has a simple data pointer by calling has_simple_data_ptr. Apparently, in this case, the Allocator* returned by storage.allocator() is non-null, and yet the allocator->is_simple_data_ptr call causes a segfault, indicating that the Allocator* is an invalid pointer.

I don't know why that would happen. Maybe the Allocator was deleted for some reason, and if so, I'm not sure how to detect that in has_simple_data_ptr.

One way to work around this segfault, and prevent it in some cases, is to change has_simple_data_ptr from this:

   if (allocator != nullptr) {
     return allocator->is_simple_data_ptr(data_ptr);
   } else {
     return ctx == data;
   }

to this:

   if (ctx == data) {
     return true;
   } else if (allocator != nullptr) {
     return allocator->is_simple_data_ptr(data_ptr);
   } else {
     return false;
   }

But that change just covers up the issue rather than really solving it. There could potentially be a case where we have a simple data pointer, but ctx == data + some_offset (which is the case for CPUAllocator) and the Allocator* is invalid.

I guess I'll have to just go with the workaround until I know more.

Backtrace:

#0  0x00007fffdacfc1eb in c10::impl::cow::has_simple_data_ptr (storage=...)
    at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COW.cpp:44
#1  0x00007fffdacfc545 in c10::impl::cow::lazy_clone_storage (storage=...)
    at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COW.cpp:91
#2  0x00007fffe48e8eb3 in at::native::_lazy_clone (self=...)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/native/AutogradComposite.cpp:96
#3  0x00007fffe5aeba6c in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd___lazy_clone (self=...)
    at /home/kurtamohler/develop/pytorch-1/build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:1605
#4  c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&), at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd___lazy_clone>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&> >::operator() (args#0=..., this=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13
#5  c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&), at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd___lazy_clone>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&> >, at::Tensor(const at::Tensor&)>::call (args#0=...,
    functor=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:468
#6  c10::impl::call_functor_with_args_from_stack_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&), at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd___lazy_clone>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&> >, false, 0, const at::Tensor&> (functor=<optimized out>,
    dispatchKeySet=..., stack=0x7fffffff4330)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:506
#7  c10::impl::call_functor_with_args_from_stack<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&), at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd___lazy_clone>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&> >, false> (stack=0x7fffffff4330, dispatchKeySet=...,
    functor=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:518
#8  c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&), at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd___lazy_clone>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&> >, false>::call(c10::OperatorKernel *, const c10::OperatorHandle &, c10::DispatchKeySet, c10::Stack *) (functor=<optimized out>, dispatchKeySet=...,
    stack=0x7fffffff4330)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:584
#9  0x00007fffe518dd49 in c10::BoxedKernel::callBoxed (stack=0x7fffffff4330,
    dispatchKeySet=..., opHandle=..., this=0x555556760198)
    at /home/kurtamohler/develop/pytorch-1/c10/util/intrusive_ptr.h:414
#10 c10::impl::BoxedKernelWrapper<at::Tensor (at::Tensor const&), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&) (
    boxed_kernel_func=..., opHandle=..., dispatchKeySet=dispatchKeySet@entry=..., args#0=...)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/boxing.h:236
#11 0x00007fffe54ef840 in c10::KernelFunction::call<at::Tensor, at::Tensor const&> (
    dispatchKeySet=..., opHandle=..., this=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/KernelFunction_impl.h:114
#12 c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&)> const&, c10::DispatchKeySet, at::Tensor const&) const (
    this=<optimized out>, currentDispatchKeySet=..., op=...)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/dispatch/Dispatcher.h:716
#13 c10::TypedOperatorHandle<at::Tensor (at::Tensor const&)>::redispatch(c10::DispatchKeySet, at::Tensor const&) const (args#0=..., currentDispatchKeySet=...,
    this=0x7fffee66a3d0 <at::_ops::_lazy_clone::redispatch(c10::DispatchKeySet, at::Tensor const&)::op>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/dispatch/Dispatcher.h:536
#14 at::_ops::_lazy_clone::redispatch (dispatchKeySet=..., self=...)
    at /home/kurtamohler/develop/pytorch-1/build/aten/src/ATen/Operators_2.cpp:1784
#15 0x00007fffe7773286 in operator() (__closure=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/torch/csrc/autograd/generated/ADInplaceOrViewType_0.cpp:854
#16 torch::ADInplaceOrView::(anonymous namespace)::_lazy_clone (ks=..., self=...)
    at /home/kurtamohler/develop/pytorch-1/torch/csrc/autograd/generated/ADInplaceOrViewType_0.cpp:855
#17 0x00007fffe77735e3 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&), torch::ADInplaceOrView::(anonymous namespace)::_lazy_clone>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&> >::operator() (args#1=..., args#0=..., this=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:12
#18 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&), torch::ADInplaceOrView::(anonymous namespace)::_lazy_clone>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&> >, at::Tensor(c10::DispatchKeySet, const at::Tensor&)>::call(c10::OperatorKernel *, c10::DispatchKeySet, const at::Tensor &) (functor=<optimized out>,
    dispatchKeySet=..., args#0=...)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:485
#19 0x00007fffe54ef793 in c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&> (
    dispatchKeySet=..., functor=<optimized out>, unboxed_kernel_func=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/KernelFunction_impl.h:50
#20 c10::KernelFunction::call<at::Tensor, at::Tensor const&> (dispatchKeySet=...,
    opHandle=..., this=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/KernelFunction_impl.h:105
#21 c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&)> const&, c10::DispatchKeySet, at::Tensor const&) const (
    this=<optimized out>, currentDispatchKeySet=..., op=...)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/dispatch/Dispatcher.h:716
#22 c10::TypedOperatorHandle<at::Tensor (at::Tensor const&)>::redispatch(c10::DispatchKeySet, at::Tensor const&) const (args#0=..., currentDispatchKeySet=...,
    this=0x7fffee66a3d0 <at::_ops::_lazy_clone::redispatch(c10::DispatchKeySet, at::Tensor const&)::op>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/dispatch/Dispatcher.h:536
#23 at::_ops::_lazy_clone::redispatch (dispatchKeySet=dispatchKeySet@entry=..., self=...)
    at /home/kurtamohler/develop/pytorch-1/build/aten/src/ATen/Operators_2.cpp:1784
#24 0x00007fffe7060292 in at::redispatch::_lazy_clone (self=..., dispatchKeySet=...)
    at /home/kurtamohler/develop/pytorch-1/build/aten/src/ATen/RedispatchFunctions.h:1377
#25 operator() (__closure=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/torch/csrc/autograd/generated/VariableType_2.cpp:4436
#26 torch::autograd::VariableType::(anonymous namespace)::_lazy_clone (ks=..., self=...)
    at /home/kurtamohler/develop/pytorch-1/torch/csrc/autograd/generated/VariableType_2.cpp:4437
#27 0x00007fffe70607c3 in c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&), torch::autograd::VariableType::(anonymous namespace)::_lazy_clone>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&> >::operator() (args#1=..., args#0=..., this=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:12
#28 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&), torch::autograd::VariableType::(anonymous namespace)::_lazy_clone>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&> >, at::Tensor(c10::DispatchKeySet, const at::Tensor&)>::call(c10::OperatorKernel *, c10::DispatchKeySet, const at::Tensor &) (
    functor=<optimized out>, dispatchKeySet=..., args#0=...)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:485
#29 0x00007fffe5557fe7 in c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&> (
    dispatchKeySet=..., functor=<optimized out>, unboxed_kernel_func=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/KernelFunction_impl.h:50
#30 c10::KernelFunction::call<at::Tensor, at::Tensor const&> (dispatchKeySet=...,
    opHandle=..., this=0x555556760218)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/KernelFunction_impl.h:105
#31 c10::Dispatcher::call<at::Tensor, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&)> const&, at::Tensor const&) const (op=..., this=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/dispatch/Dispatcher.h:698
#32 c10::TypedOperatorHandle<at::Tensor (at::Tensor const&)>::call(at::Tensor const&) const (
    args#0=..., this=0x7fffee66a3f0 <at::_ops::_lazy_clone::call(at::Tensor const&)::op>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/dispatch/Dispatcher.h:531
#33 at::_ops::_lazy_clone::call (self=...)
    at /home/kurtamohler/develop/pytorch-1/build/aten/src/ATen/Operators_2.cpp:1777
#34 0x00007fffe4d2a27d in at::Tensor::_lazy_clone (this=0x7fffffff4698)
    at /home/kurtamohler/develop/pytorch-1/build/aten/src/ATen/core/TensorBody.h:1944
#35 at::native::reshape_symint (self=..., proposed_shape=...)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/native/TensorShape.cpp:1670
#36 0x00007fffe5d357dd in at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__reshape (shape=..., self=...)
    at /home/kurtamohler/develop/pytorch-1/build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:3196
#37 c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&, c10::ArrayRef<c10::SymInt>), at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__reshape>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&, c10::ArrayRef<c10::SymInt> > >::operator() (args#1=..., args#0=...,
    this=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13
#38 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&, c10::ArrayRef<c10::SymInt>), at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__reshape>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&, c10::ArrayRef<c10::SymInt> > >, at::Tensor(const at::Tensor&, c10::ArrayRef<c10::SymInt>)>::call (args#1=..., args#0=...,
    functor=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:468
#39 c10::impl::call_functor_with_args_from_stack_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&, c10::ArrayRef<c10::SymInt>), at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__reshape>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&, c10::ArrayRef<c10::SymInt> > >, false, 0, 1, const at::Tensor&, c10::ArrayRef<c10::SymInt> > (functor=<optimized out>,
    dispatchKeySet=..., stack=0x7fffffff4c90)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:506
#40 c10::impl::call_functor_with_args_from_stack<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&, c10::ArrayRef<c10::SymInt>), at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__reshape>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&, c10::ArrayRef<c10::SymInt> > >, false> (stack=0x7fffffff4c90, dispatchKeySet=..., functor=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:518
#41 c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&, c10::ArrayRef<c10::SymInt>), at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__reshape>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&, c10::ArrayRef<c10::SymInt> > >, false>::call(c10::OperatorKernel *, const c10::OperatorHandle &, c10::DispatchKeySet, c10::Stack *)
    (functor=<optimized out>, dispatchKeySet=..., stack=0x7fffffff4c90)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:584
#42 0x00007fffe47c4a58 in c10::BoxedKernel::callBoxed (stack=0x7fffffff4c90,
    dispatchKeySet=..., opHandle=..., this=0x555556410e08)
    at /home/kurtamohler/develop/pytorch-1/c10/util/intrusive_ptr.h:414
#43 c10::KernelFunction::callBoxed (stack=0x7fffffff4c90, dispatchKeySet=..., opHandle=...,
    this=0x555556410e08)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/KernelFunction_impl.h:46
#44 c10::Dispatcher::callBoxed (op=..., stack=0x7fffffff4c90, this=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/dispatch/Dispatcher.h:751
#45 0x00007fffe47ace86 in at::functorch::Interpreter::sendToNextInterpreter (
    this=this@entry=0x7fffffff4bf0, op=..., stack=stack@entry=0x7fffffff4c90,
    grad_special_case=grad_special_case@entry=false)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/functorch/Interpreter.cpp:127
#46 0x00007fffe47aa133 in at::functorch::dynamicLayerBack (op=..., stack=0x7fffffff4c90,
    grad_special_case=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/functorch/DynamicLayer.h:56
#47 0x00007fffe5395de7 in c10::BoxedKernel::callBoxed (stack=0x7fffffff4c90,
    dispatchKeySet=..., opHandle=..., this=0x555556410ce8)
    at /home/kurtamohler/develop/pytorch-1/c10/util/intrusive_ptr.h:414
#48 c10::impl::BoxedKernelWrapper<at::Tensor (at::Tensor const&, c10::ArrayRef<c10::SymInt>), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>) (boxed_kernel_func=..., opHandle=...,
    dispatchKeySet=dispatchKeySet@entry=..., args#0=..., args#1=...)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/impl/boxing.h:236
#49 0x00007fffe573fb7f in c10::KernelFunction::call<at::Tensor, at::Tensor const&, c10::ArrayRef<c10::SymInt> > (dispatchKeySet=..., opHandle=..., this=0x555556410ce8)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/boxing/KernelFunction_impl.h:114
#50 c10::Dispatcher::call<at::Tensor, at::Tensor const&, c10::ArrayRef<c10::SymInt> >(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, c10::ArrayRef<c10::SymInt>)> const&, at::Tensor const&, c10::ArrayRef<c10::SymInt>) const (op=..., this=<optimized out>)
    at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/dispatch/Dispatcher.h:698
#51 c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, c10::ArrayRef<c10::SymInt>)>::call
(at::Tensor const&, c10::ArrayRef<c10::SymInt>) const (args#1=..., args#0=...,
    this=0x7fffee671e50 <at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>)::op>) at /home/kurtamohler/develop/pytorch-1/aten/src/ATen/core/dispatch/Dispatcher.h:531
#52 at::_ops::reshape::call (self=..., shape=...)
    at /home/kurtamohler/develop/pytorch-1/build/aten/src/ATen/Operators_3.cpp:4324
#53 0x00007fffe46fdbe6 in at::reshape (shape=..., self=...)
    at /home/kurtamohler/develop/pytorch-1/build/aten/src/ATen/ops/reshape.h:27

@kurtamohler
Collaborator Author

kurtamohler commented Jun 13, 2024

Another problem is that this PR currently breaks sharing tensors between processes.

import torch
import torch.multiprocessing as mp

def test(q):
    x = torch.randn(10)
    x.share_memory_()
    q.put(x)

if __name__ == "__main__":
    manager = mp.Manager()
    q = manager.Queue()
    p = mp.Process(target=test, args=(q,))
    p.start()
    p.join()

    print(q.get())

Output:

  File "/home/kurtamohler/tmp/lazy_clone_multiprocess.py", line 17, in <module>
    print(q.get())
  File "/home/kurtamohler/develop/pytorch-1/torch/_tensor.py", line 463, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/home/kurtamohler/develop/pytorch-1/torch/_tensor_str.py", line 698, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/home/kurtamohler/develop/pytorch-1/torch/_tensor_str.py", line 618, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/home/kurtamohler/develop/pytorch-1/torch/_tensor_str.py", line 350, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/home/kurtamohler/develop/pytorch-1/torch/_tensor_str.py", line 130, in __init__
    tensor_view = tensor.reshape(-1)
RuntimeError: Expected storage != nullptr to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

That failed check happens in at::native::_lazy_clone. lazy_clone_storage does not support the case where the storage was allocated by a separate process, so I will need to fix that.

I'm not sure, but maybe we should just forbid sharing COW tensors between processes? So if _lazy_clone is given a shared tensor, future behavior would just be an immediate clone, and simulated behavior would just be a view that raises a deprecation warning immediately. @ezyang, does that sound good? I can't remember if this was discussed at any point.
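
To make that proposal concrete, here is a rough Python-level sketch of the fallback (purely illustrative; the real change would live in at::native::_lazy_clone, and lazy_clone_with_fallback / storage_is_simple are hypothetical stand-ins for the C++ logic):

import warnings
import torch

def lazy_clone_with_fallback(t: torch.Tensor, future: bool, storage_is_simple: bool) -> torch.Tensor:
    # `storage_is_simple` plays the role of c10::impl::cow::has_simple_data_ptr.
    if not storage_is_simple:   # e.g. a shared-memory storage
        if future:
            return t.clone()    # future behavior: immediate clone
        warnings.warn(
            "This operation currently returns a view, but will clone in the future.",
            FutureWarning)
        return t.view(t.shape)  # simulated behavior: plain view plus an immediate warning
    return t._lazy_clone()      # normal COW / COWSim path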

@kurtamohler
Collaborator Author

kurtamohler commented Jun 14, 2024

Yet another problem--lazy_clone_storage doesn't work for numpy tensors, or potentially anything else created by at::from_blob:

import torch
import numpy as np
a = torch.from_numpy(np.random.randn(10))
b = a._lazy_clone()
Traceback (most recent call last):
  File "/home/kurtamohler/tmp/lazy_clone_numpy.py", line 4, in <module>
    b = a._lazy_clone()
RuntimeError: Expected storage != nullptr to be true, but got false.

We could sort of support COW for numpy tensors, but the problem is that there would be no way to prevent numpy from mutating the data underneath, as far as I know. Should we do the same as I mentioned above for shared tensors?

So if _lazy_clone is given a [numpy] tensor, future behavior would just be an immediate clone, and simulated behavior would just be a view that raises a deprecation warning immediately.
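
To make the mutation concern concrete, here is a minimal illustration using only existing PyTorch/NumPy APIs; writes that NumPy performs on the shared buffer never pass through PyTorch's dispatcher, so any COW bookkeeping on the tensor side would simply be bypassed:

import numpy as np
import torch

arr = np.zeros(4)
t = torch.from_numpy(arr)  # t aliases arr's buffer; PyTorch does not own the allocation
arr[0] = 1.0               # mutation happens entirely outside PyTorch
print(t[0])                # tensor(1., dtype=torch.float64) -- the tensor sees the write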

Also, if we don't support COW for numpy-based tensors, what exactly should the warning message about the future behavior say? Ideally, it would say something like "This operation creates a conditional view of a numpy-based tensor. In the future it will unconditionally create a clone". That would give the user a clue about why this is happening. But from _lazy_clone's perspective, I don't think there is currently any way for it to check whether the tensor was created with from_numpy--all it knows is that the tensor does not have a simple context (i.e., context != data). The deleter function for a data pointer created by from_numpy is set to a lambda here, and the address of the lambda will be different each time, so we can't use the address of the deleter to check if it's numpy-based either.

And I'm sure there are going to be other kinds of cases where we don't have a simple data pointer. We could think about creating a special deleter and context to wrap around numpy data pointers and other non-simple data pointers to add some information about what these data pointers are, for the sole purpose of making the warning messages emitted by _lazy_clone more transparent. I'm not sure if it's really worth the extra complexity, though.

@ezyang
Contributor

ezyang commented Jun 19, 2024

Yeah, the safe thing seems to be to unconditionally clone, and then adjusting the warnings accordingly also seems good.

[ghstack-poisoned]
@kurtamohler added the keep-going label (Don't stop on first failure, keep running tests until the end) Jun 21, 2024
[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Jun 21, 2024
ghstack-source-id: 54d1377955d87b06dbc6250722182de24160016b
Pull Request resolved: #126129
@kurtamohler
Collaborator Author

I'm getting failures for gradgradcheck that I don't understand. For instance:

$ python test/test_ops_gradients.py -k test_fn_gradgrad_trapz_cpu_float64
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: FutureWarning: Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.
  result = vmap(vjp)(torch.stack(grad_outputs))
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: FutureWarning: Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.
  result = vmap(vjp)(torch.stack(grad_outputs))
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: FutureWarning: Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.
  result = vmap(vjp)(torch.stack(grad_outputs))
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: FutureWarning: Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.
  result = vmap(vjp)(torch.stack(grad_outputs))
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: FutureWarning: Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.
  result = vmap(vjp)(torch.stack(grad_outputs))
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: FutureWarning: Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.
  result = vmap(vjp)(torch.stack(grad_outputs))
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: FutureWarning: Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.
  result = vmap(vjp)(torch.stack(grad_outputs))
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: FutureWarning: Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.
  result = vmap(vjp)(torch.stack(grad_outputs))
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:908: UserWarning: Detected divergent behavior on read (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  grad_inputs = vjp_fn(v.reshape(sample_output.shape))
E
======================================================================
ERROR: test_fn_gradgrad_trapz_cpu_float64 (__main__.TestBwdGradientsCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/kurtamohler/develop/pytorch-1/torch/testing/_internal/common_device_type.py", line 1121, in test_wrapper
    return test(*args, **kwargs)
  File "/home/kurtamohler/develop/pytorch-1/test/test_ops_gradients.py", line 75, in test_fn_gradgrad
    self._check_helper(device, dtype, op, op.get_op(), "bwgrad_bwgrad")
  File "/home/kurtamohler/develop/pytorch-1/torch/testing/_internal/common_utils.py", line 4949, in _check_helper
    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
  File "/home/kurtamohler/develop/pytorch-1/torch/testing/_internal/common_utils.py", line 4514, in gradgradcheck
    return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)
  File "/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py", line 2255, in gradgradcheck
    return gradcheck(
  File "/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py", line 2053, in gradcheck
    return _gradcheck_helper(**args)
  File "/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py", line 2082, in _gradcheck_helper
    _gradcheck_real_imag(
  File "/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py", line 1492, in _gradcheck_real_imag
    gradcheck_fn(
  File "/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py", line 1922, in _fast_gradcheck
    analytical_vJu = _get_analytical_vJu_backward_mode(
  File "/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py", line 805, in _get_analytical_vJu_backward_mode
    all_vJ = _check_analytical_jacobian_attributes(
  File "/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py", line 791, in _check_analytical_jacobian_attributes
    raise GradcheckError(
torch.autograd.gradcheck.GradcheckError: Backward is not reentrant, i.e., running backward with same input and grad_output multiple times gives different values, although analytical gradient matches numerical gradient.The tolerance for nondeterminism was 0.0.

NOTE: If your op relies on non-deterministic operations i.e., it is listed here:
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
this failure might be expected.

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `nondet_tol=<tol>` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `gradcheck_nondet_tol=<tol>`.
- is a Module test (e.g., in common_nn.py), then modify the corresponding
  module_test entry to have `gradcheck_nondet_tol=<tol>`


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/kurtamohler/develop/pytorch-1/torch/testing/_internal/common_utils.py", line 2748, in wrapper
    method(*args, **kwargs)
  File "/home/kurtamohler/develop/pytorch-1/torch/testing/_internal/common_device_type.py", line 446, in instantiated_test
    result = test(self, **param_kwargs)
  File "/home/kurtamohler/develop/pytorch-1/torch/testing/_internal/common_utils.py", line 1363, in wrapper
    fn(*args, **kwargs)
  File "/home/kurtamohler/develop/pytorch-1/torch/testing/_internal/common_device_type.py", line 1127, in test_wrapper
    raise Exception(  # noqa: TRY002
Exception: Caused by sample input at index 2: SampleInput(input=Tensor[size=(6,), device="cpu", dtype=torch.float64], args=TensorList[Tensor[size=(6,), device="cpu", dtype=torch.float64]], kwargs={}, broadcasts_input=False, name='')

To execute this test, run the following from the base repo dir:
     python test/test_ops_gradients.py -k TestBwdGradientsCPU.test_fn_gradgrad_trapz_cpu_float64

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 1 test in 0.048s

FAILED (errors=1)

If I change _lazy_clone to just return a normal view of the input, the error goes away, but as soon as I add a tensor.unsafeGetTensorImpl()->set_storage_keep_dtype(new_storage), where new_storage is the storage returned by lazy_clone_storage, the error comes back. I don't understand why setting the tensor's storage to a different storage that points to the exact same underlying data would change numerical results for double grad. Apparently, single grad is working fine.
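
For context, the "Backward is not reentrant" failure means gradcheck ran the same backward twice with identical inputs and grad_outputs and got different values. A first-order sketch of that idea (this is not gradcheck's actual code, and the real failure here is in the batched/double-grad path):

import torch

y = torch.randn(6, dtype=torch.float64, requires_grad=True)
x = torch.randn(6, dtype=torch.float64, requires_grad=True)
out = torch.trapezoid(y, x)
v = torch.randn_like(out)
# Run the same vjp twice with identical grad_outputs and compare the results.
g1 = torch.autograd.grad(out, (y, x), v, retain_graph=True)
g2 = torch.autograd.grad(out, (y, x), v, retain_graph=True)
print(all(torch.equal(a, b) for a, b in zip(g1, g2)))  # reentrant backward => True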

[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Jun 26, 2024
ghstack-source-id: b673187821d5b34ca0a1001ae2efb91e49f9e8a1
Pull Request resolved: #126129
@ezyang
Contributor

ezyang commented Jun 28, 2024

Is it possible we are accidentally deallocating the data too early?

@kurtamohler
Collaborator Author

kurtamohler commented Jul 9, 2024

Is it possible we are accidentally deallocating the data too early?

I'm not sure yet. But I was able to write a minimal reproducer:

import torch
y = torch.randn(6, dtype=torch.float64, requires_grad=True)
x = torch.randn(6, dtype=torch.float64, requires_grad=True)
inputs = (y, x)
torch.autograd.gradgradcheck(
    torch.trapezoid,
    inputs,
    check_batched_grad=True,
)

A few observations:

  • gradgradcheck only fails for the batched grad check. If I turn it off, it passes.
  • If I set torch.set_future_lazy_clone(True), the check passes. So the failure is only happening with _lazy_clone_alias/COWSim, and not with _lazy_clone_future/COW.
  • If I change the implementation of _lazy_clone_alias to be exactly the same as _lazy_clone_future, the above code snippet still fails. So the issue must be with how _lazy_clone_alias is registered/dispatched.

[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Jul 17, 2024
ghstack-source-id: 003b7f106109d89094b0b3c0ab46b3b3fa52f481
Pull Request resolved: #126129
@kurtamohler
Collaborator Author

kurtamohler commented Jul 17, 2024

I've still been trying to debug the failure for python test/test_ops_gradients.py -k test_fn_gradgrad_trapezoid_cpu_float64.

I have the following script which just runs gradgradcheck on trapezoid four times. Two of the runs have future lazy clone disabled so that COWSim will be used and the other two have future lazy clone enabled so that real COW will be used. For both of these pairs of runs, the check_batched_grad setting for gradgradcheck is toggled on and off.

import torch
import warnings
from itertools import product
warnings.simplefilter('always')
warnings.filterwarnings('ignore', r'.*torch\.vmap.*', FutureWarning)
for future, check_batched_grad in product([False, True], [False, True]):
    print('--------------------------------------------')
    print(f'future={future}, check_batched_grad={check_batched_grad}')
    torch.set_future_lazy_clone(future)
    y = torch.randn(6, dtype=torch.float64, requires_grad=True)
    x = torch.randn(6, dtype=torch.float64, requires_grad=True)
    try:
        torch.autograd.gradgradcheck(
            torch.trapezoid,
            (y, x),
            check_batched_grad=check_batched_grad)
    except RuntimeError:
        print('FAIL')
    else:
        print('PASS')

If I run it, I get the following output:

--------------------------------------------
future=False, check_batched_grad=False
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:892: UserWarning: Detected divergent behavior on write (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  flat_grad_out.zero_()
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:892: UserWarning: Detected divergent behavior on write (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  flat_grad_out.zero_()
PASS
--------------------------------------------
future=False, check_batched_grad=True
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:892: UserWarning: Detected divergent behavior on write (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  flat_grad_out.zero_()
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:892: UserWarning: Detected divergent behavior on write (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  flat_grad_out.zero_()
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: UserWarning: Detected divergent behavior on read (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  result = vmap(vjp)(torch.stack(grad_outputs))
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: UserWarning: Detected divergent behavior on read (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  result = vmap(vjp)(torch.stack(grad_outputs))
FAIL
--------------------------------------------
future=True, check_batched_grad=False
PASS
--------------------------------------------
future=True, check_batched_grad=True
PASS

This shows that everything passes, except in the case where future lazy clone is turned off and batched grads are being checked.

So then I thought, what if I change the implementation of aten::_lazy_clone_future to do the same thing that aten::_lazy_clone_alias does? That way, we're still calling either aten::_lazy_clone_alias or aten::_lazy_clone_future depending on the setting for torch.set_future_lazy_clone, but they will both apply COWSim (return a view of the input with a new StorageImpl pointing to the same data as the input, with COWSim applied to both the input and the output). So the only real difference between _lazy_clone_alias and _lazy_clone_future would be how they are registered/dispatched.

This is the diff I applied on top of my current PR:

diff --git a/aten/src/ATen/native/AutogradComposite.cpp b/aten/src/ATen/native/AutogradComposite.cpp
index eb8fc453507..f15c142abe3 100644
--- a/aten/src/ATen/native/AutogradComposite.cpp
+++ b/aten/src/ATen/native/AutogradComposite.cpp
@@ -141,7 +141,7 @@ Tensor _lazy_clone_alias(Tensor const& self) {
 }
 
 Tensor _lazy_clone_future(Tensor const& self) {
-  return _lazy_clone_impl(self, /*future=*/true);
+  return _lazy_clone_impl(self, /*future=*/false);
 }
 
 Tensor _lazy_clone(Tensor const& self) {

Here's what I get when I rerun the script above:

--------------------------------------------
future=False, check_batched_grad=False
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:892: UserWarning: Detected divergent behavior on write (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  flat_grad_out.zero_()
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:892: UserWarning: Detected divergent behavior on write (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  flat_grad_out.zero_()
PASS
--------------------------------------------
future=False, check_batched_grad=True
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:892: UserWarning: Detected divergent behavior on write (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  flat_grad_out.zero_()
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:892: UserWarning: Detected divergent behavior on write (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  flat_grad_out.zero_()
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: UserWarning: Detected divergent behavior on read (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  result = vmap(vjp)(torch.stack(grad_outputs))
/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py:1154: UserWarning: Detected divergent behavior on read (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  result = vmap(vjp)(torch.stack(grad_outputs))
FAIL
--------------------------------------------
future=True, check_batched_grad=False
/home/kurtamohler/develop/pytorch-1/torch/autograd/graph.py:768: UserWarning: Detected divergent behavior on read (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
/home/kurtamohler/develop/pytorch-1/torch/autograd/graph.py:768: UserWarning: Detected divergent behavior on read (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
PASS
--------------------------------------------
future=True, check_batched_grad=True
/home/kurtamohler/develop/pytorch-1/torch/autograd/graph.py:768: UserWarning: Detected divergent behavior on read (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
/home/kurtamohler/develop/pytorch-1/torch/autograd/graph.py:768: UserWarning: Detected divergent behavior on read (Triggered internally at /home/kurtamohler/develop/pytorch-1/c10/core/impl/COWDeleter.cpp:65.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
PASS

Even though the implementation is the same for set_future_lazy_clone(False) and set_future_lazy_clone(True), we're still only getting a failure in the future=False case. Also, I notice that the divergent behavior warnings (emitted by the COWSim machinery) are raised at different places in the above printout for future=False versus future=True. So something about how _lazy_clone_alias is registered must be causing the issue.
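
One generic way to pin down exactly where these warnings are being raised (standard warnings machinery, nothing specific to this PR) is to promote them to errors so the Python traceback points at the triggering line:

import warnings

# Turn the COWSim "divergent behavior" warnings into errors; the resulting
# traceback shows the exact Python call site that triggered them.
warnings.filterwarnings("error", message=".*divergent behavior.*", category=UserWarning)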

I was wondering why the future=True, check_batched_grad=False case has two more warnings than the other cases. I figured I could try to avoid the future divergent behavior warning at gradcheck.py:1154 with the following diff:

diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index 5bf74afacb6..2203c948a09 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -1142,7 +1142,7 @@ def _test_batched_grad(input, output, output_idx) -> bool:
 
     grad_outputs = [torch.randn_like(output) for _ in range(2)]
 
-    expected = [vjp(gO) for gO in grad_outputs]
+    expected = [vjp(gO.clone()) for gO in grad_outputs]
     expected = [torch.stack(shards) for shards in zip(*expected)]
 
     # Squash warnings since these are expected to happen in most cases

That did indeed avoid the extra warning, and it also made the script pass in all four cases. Of course, this is not a real solution, since we need to make sure this PR's changes are backward compatible.

I also tried removing the alias annotations and other related things from _lazy_clone_alias, so that the function is registered/dispatched exactly the same way as _lazy_clone_future. This is the diff for that:

diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp
index fa30199bb97..a1cf449cde7 100644
--- a/aten/src/ATen/FunctionalInverses.cpp
+++ b/aten/src/ATen/FunctionalInverses.cpp
@@ -443,14 +443,6 @@ Tensor FunctionalInverses::alias_inverse(const Tensor& base, const Tensor& mutat
     }
 }
 
-Tensor FunctionalInverses::_lazy_clone_alias_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
-  if (inverse_return_mode != InverseReturnMode::NeverView) {
-    return at::_lazy_clone_alias(mutated_view);
-  } else {
-    return at::_lazy_clone_alias_copy(mutated_view);
-  }
-}
-
 Tensor FunctionalInverses::chunk_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim) {
     // TODO: Can the logic from TensorShape.cpp be reused here somehow?
     const auto dim_size = base.sym_size(dim);
diff --git a/aten/src/ATen/functorch/BatchRulesViews.cpp b/aten/src/ATen/functorch/BatchRulesViews.cpp
index 93a3033e0b4..ed0553e915b 100644
--- a/aten/src/ATen/functorch/BatchRulesViews.cpp
+++ b/aten/src/ATen/functorch/BatchRulesViews.cpp
@@ -279,13 +279,6 @@ std::tuple<Tensor, optional<int64_t>> _reshape_alias_batch_rule(const Tensor& se
   return std::make_tuple(at::reshape_symint(self_, new_shape), 0);
 }
 
-std::tuple<Tensor, optional<int64_t>> _lazy_clone_alias_batch_rule(
-    const Tensor &self, optional<int64_t> bdim) {
-  TORCH_INTERNAL_ASSERT(bdim.has_value());
-  auto self_ = moveBatchDimToFront(self, bdim);
-  return std::make_tuple(at::_lazy_clone_alias(self), 0);
-}
-
 std::tuple<Tensor, optional<int64_t>> roll_batch_rule(const Tensor& self, optional<int64_t> bdim, SymIntArrayRef shifts, IntArrayRef dims) {
   TORCH_INTERNAL_ASSERT(bdim.has_value());
 
@@ -572,7 +565,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT2(squeeze, dim, squeeze_dim_batch_rule);
   VMAP_SUPPORT2(squeeze, dims, squeeze_dims_batch_rule);
   VMAP_SUPPORT(_reshape_alias, _reshape_alias_batch_rule);
-  VMAP_SUPPORT(_lazy_clone_alias, _lazy_clone_alias_batch_rule);
   VMAP_SUPPORT(roll, roll_batch_rule);
   VMAP_SUPPORT(permute, permute_batching_rule);
   VMAP_SUPPORT(diagonal, diagonal_batching_rule);
diff --git a/aten/src/ATen/native/AutogradComposite.cpp b/aten/src/ATen/native/AutogradComposite.cpp
index eb8fc453507..f15c142abe3 100644
--- a/aten/src/ATen/native/AutogradComposite.cpp
+++ b/aten/src/ATen/native/AutogradComposite.cpp
@@ -141,7 +141,7 @@ Tensor _lazy_clone_alias(Tensor const& self) {
 }
 
 Tensor _lazy_clone_future(Tensor const& self) {
-  return _lazy_clone_impl(self, /*future=*/true);
+  return _lazy_clone_impl(self, /*future=*/false);
 }
 
 Tensor _lazy_clone(Tensor const& self) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b79c8d8ff28..0912a86801f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1255,7 +1255,7 @@
   dispatch:
     CompositeExplicitAutograd: _lazy_clone_future
 
-- func: _lazy_clone_alias(Tensor(a) self) -> Tensor(a)
+- func: _lazy_clone_alias(Tensor self) -> Tensor
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: _lazy_clone_alias
@@ -1266,12 +1266,6 @@
   dispatch:
     CompositeExplicitAutogradNonFunctional: _lazy_clone_copy
 
-- func: _lazy_clone_alias_copy(Tensor self) -> Tensor
-  variants: function
-  tags: view_copy
-  dispatch:
-    CompositeExplicitAutogradNonFunctional: _lazy_clone_alias_copy
-
 - func: logical_not(Tensor self) -> Tensor
   device_check: NoCheck   # TensorIterator
   variants: function, method
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index a614ba419ca..a978a2546ed 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -442,7 +442,7 @@
   self: grad
   result: auto_linear
 
-- name: _lazy_clone_alias(Tensor(a) self) -> Tensor(a)
+- name: _lazy_clone_alias(Tensor self) -> Tensor
   self: grad
   result: auto_linear
 
@@ -454,10 +454,6 @@
   self: grad
   result: auto_linear
 
-- name: _lazy_clone_alias_copy(Tensor self) -> Tensor
-  self: grad
-  result: auto_linear
-
 - name: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
   self: _to_copy_backward(grad, self.options())
   result: _to_copy(self_t, dtype, layout, device, pin_memory, non_blocking, memory_format)
diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py
index 671b726a719..d1392f5407c 100644
--- a/tools/autograd/gen_inplace_or_view_type.py
+++ b/tools/autograd/gen_inplace_or_view_type.py
@@ -98,7 +98,6 @@ VIEW_FUNCTIONS = {
     # FIXME: clone indices on construction.
     "sparse_coo_tensor_with_dims_and_tensors": "values",
     "_reshape_alias": "self",
-    "_lazy_clone_alias": "self",
     "_test_autograd_multiple_dispatch_view": "self",
 }
 
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 6e5984e5322..8604f2ba61c 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -135,7 +135,6 @@ _SKIP_PYTHON_BINDINGS = [
     "_reshape_copy",
     "_reshape_copy_out",
     "_lazy_clone_alias",
-    "_lazy_clone_alias_copy",
     "_lazy_clone_copy",
     "_lazy_clone_future",
     "copy_sparse_to_sparse_",
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 0c45bfad1cb..b2158fecb27 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -182,7 +182,6 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     "clone",
     "_lazy_clone",
     "_lazy_clone_alias",
-    "_lazy_clone_alias_copy",
     "_lazy_clone_copy",
     "_lazy_clone_future",
     "block_diag",
diff --git a/torch/masked/maskedtensor/passthrough.py b/torch/masked/maskedtensor/passthrough.py
index a3a2460cc78..4a2e79456c8 100644
--- a/torch/masked/maskedtensor/passthrough.py
+++ b/torch/masked/maskedtensor/passthrough.py
@@ -28,7 +28,6 @@ PASSTHROUGH_FNS = [
     torch.ops.aten.view,
     torch.ops.aten._unsafe_view,
     torch.ops.aten._reshape_alias,
-    torch.ops.aten._lazy_clone_alias,
     torch.ops.aten.cat,
     torch.ops.aten.unsqueeze,
 ]

That made the script pass in all four cases. But this is also not an acceptable solution--the whole reason why I have a separate _lazy_clone_alias op is that we need the alias annotations in the future=False case, and we need the output of _lazy_clone_alias to have ._is_view() == True.
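
For reference, the ._is_view() distinction being relied on here can be seen with ordinary ops that already exist (a view op versus clone):

import torch

x = torch.randn(4, requires_grad=True)
v = x.view(4)   # a registered view op (has alias annotations)
c = x.clone()   # not a view
print(v._is_view())  # True  -- what the output of _lazy_clone_alias also needs to report
print(c._is_view())  # False -- what it would report without the alias annotations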

I still have some avenues to investigate, but I'm wondering if anything sticks out to anyone. The specific question is, what is wrong with the way I've registered _lazy_clone_alias such that its behavior would change with double batched grad?

EDIT:

Here is the error message for the failing case of future=False, check_batched_grad=True if I remove the try/except from the script above:

Traceback (most recent call last):
  File "/home/kurtamohler/tmp/tmp20.py", line 13, in <module>
    torch.autograd.gradgradcheck(
  File "/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py", line 2255, in gradgradcheck
    return gradcheck(
  File "/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py", line 2053, in gradcheck
    return _gradcheck_helper(**args)
  File "/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py", line 2107, in _gradcheck_helper
    _test_batched_grad(tupled_inputs, o, i)
  File "/home/kurtamohler/develop/pytorch-1/torch/autograd/gradcheck.py", line 1167, in _test_batched_grad
    raise GradcheckError(
torch.autograd.gradcheck.GradcheckError: For output 1 and input 0:

gradcheck or gradgradcheck failed while testing batched gradient computation.
This could have been invoked in a number of ways (via a test that calls
gradcheck/gradgradcheck directly or via an autogenerated test).

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `check_batched_grad=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `check_batched_grad=False` and/or `check_batched_gradgrad=False`.

If you're modifying an existing operator that supports batched grad computation,
or wish to make a new operator work with batched grad computation, please read
the following.

To compute batched grads (e.g., jacobians, hessians), we vmap over the backward
computation. The most common failure case is if there is a 'vmap-incompatible
operation' in the backward pass. Please see
NOTE: [How to write vmap-compatible backward formulas]
in the codebase for an explanation of how to fix this.

Got:
tensor([[ 0.0051,  0.0442,  0.0354, -0.0930, -0.0006,  0.0887],
        [ 0.0640, -0.0238, -0.0391,  0.0469, -0.0245, -0.0226]],
       dtype=torch.float64)

Expected:
tensor([[-0.0075,  0.0241,  0.0596, -0.0334, -0.0340,  0.0274],
        [ 0.0474,  0.0070, -0.0321,  0.0148, -0.0098, -0.0162]],
       dtype=torch.float64)

@ezyang
Contributor

ezyang commented Jul 18, 2024

I'm sorry, I've been swamped recently. I will read this carefully when I get a chance.

Labels: ciflow/inductor, keep-going, module: dynamo, open source