Torch.compile Error: RuntimeError: aten::_conj() Expected a value of type 'Tensor' for argument 'self' but instead found type 'complex'. #105290
Comments
Which version are you using?
Can you run `print(type(value))` to check the type of the value you are inserting?
Do you get the same error when you try
@williamwen42 No error with backend="eager", but with backend="aot_eager" I get this error.
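(For anyone triaging a similar failure, here is a rough, hypothetical sketch of how to narrow down which compile backend breaks; `fn` is just a stand-in for the complex-number op, not the code from this issue.)

```python
import torch

def fn(x):
    # stand-in for the complex-valued op in the real model
    return x * 1j

x = torch.randn(4, dtype=torch.cfloat, requires_grad=True)
for backend in ("eager", "aot_eager", "inductor"):
    try:
        out = torch.compile(fn, backend=backend)(x)
        out.real.sum().backward()
        print(backend, "ok")
    except Exception as e:
        print(backend, "failed:", type(e).__name__)
```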
@bdhirsh any thoughts?
Just want to add another data point to this discussion: we haven't been prioritizing complex numbers very highly, because the argument was that people can rewrite their code to avoid complex numbers using Euler's formula. I tried to do that in #105665 (comment) for rotary embeddings and got 10x slowdowns, so in theory the rewrite argument may be fine, but then we should be more prescriptive about how to do those rewrites. A good starting point is probably rotary positional embeddings.
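(For context, the kind of rewrite being discussed keeps the real and imaginary parts in separate real channels and applies Euler's formula by hand; a minimal sketch, not the rotary-embedding code from the linked comment:)

```python
import torch

def rotate_by_angle(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Multiply values by exp(i * theta) without complex dtypes.
    x stores (real, imag) pairs in its last dimension of size 2."""
    cos, sin = torch.cos(theta), torch.sin(theta)
    real, imag = x[..., 0], x[..., 1]
    # (real + i*imag) * (cos + i*sin), expanded via Euler's formula
    return torch.stack((real * cos - imag * sin, real * sin + imag * cos), dim=-1)

x = torch.randn(8, 16, 2)   # (..., 2) real/imag pairs
theta = torch.rand(8, 16)   # rotation angles, broadcastable to x[..., 0]
out = rotate_by_angle(x, theta)
```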
Hi @williamwen42 @bdhirsh, do you have any further comments on getting torch.compile to work with my code?
It seems to me that there's an op using complex numbers that we don't support yet. It's hard for me to investigate further without a runnable code example (can you provide code with a hardcoded config and random input?). You can try inspecting the output after setting the envvar
Hey @lyndonlauder, if this is still breaking for you, do you mind including a self-contained repro? That will make diagnosing easier. The repro above isn't runnable (e.g. `args` doesn't exist).
This problem also affects AudioLM. https://gist.github.com/ezyang/6914fb5dedff6598dc673288898a7498 has the graph going into AOTAutograd and a stack trace. Full repro code: gist.github.com/ezyang/64c24c9fc5529f3afed4ee4266f6adc5, but it fails before you get to this error; you need to bypass a different error by disabling compile in
See also https://docs.google.com/document/d/14uWNDXa10I_Dpq1KxYShJwYmjHT0xbmv38ZWYCN40NE/edit
Putting this back into the queue for a bit. |
Very simple min repro:
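(The original snippet isn't preserved here; below is a hedged reconstruction of the kind of minimal repro described, a complex tensor multiplied by a complex Python scalar under the `aot_eager` backend.)

```python
import torch

def f(x):
    return x * 1j  # complex tensor * complex Python scalar

x = torch.randn(4, dtype=torch.cfloat, requires_grad=True)
out = torch.compile(f, backend="aot_eager")(x)
# On affected builds, compiling/tracing the backward hits
# "aten::_conj() Expected a value of type 'Tensor' ... found type 'complex'".
out.real.sum().backward()
```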
The problem is roughly: (1) the python arg parser converts 1j to a tensor, and (2) the autograd engine sees a call to `conj()` on that wrapped scalar tensor, which gets dispatched as `aten._conj`. Still thinking about what the right thing to do here is... somewhere, we want to arrange for the scalar to become a tensor again.
#105290 The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)`
(2) the python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor` and calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit `__torch_dispatch__`, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually: (1) convert the scalar tensor back into a scalar, (2) call `scalar.conj()` directly, and (3) convert the result back into a wrapped tensor.

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway). Notably, I did **not** add e.g. a new `aten._conj.Scalar` overload. That would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly; we would also need to change all `__torch_dispatch__` call sites to convert python scalars back into tensors directly.

Pull Request resolved: #131482
Approved by: https://github.com/zou3519, https://github.com/ezyang
ghstack dependencies: #131403
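(For intuition only: the change described above, expressed as a rough Python sketch rather than the actual C++ in `TensorBase::conj()`; `is_wrapped_number` below is a hypothetical stand-in for the internal wrapped-number flag.)

```python
import torch

def conj_with_wrapped_number_fast_path(t: torch.Tensor, is_wrapped_number: bool) -> torch.Tensor:
    """Sketch of the fix: if `t` merely wraps a Python scalar, conjugate the
    scalar directly and re-wrap it, instead of dispatching aten._conj on the
    wrapped 0-dim tensor (which tracing trips over once it is unwrapped)."""
    if is_wrapped_number and t.is_complex():
        scalar = t.item()                                           # unwrap to a Python complex
        return torch.scalar_tensor(scalar.conjugate(), dtype=t.dtype)  # conj and re-wrap
    return t.conj()                                                 # normal path: aten._conj

# toy usage
wrapped = torch.scalar_tensor(1j)  # stands in for the wrapped 1j
print(conj_with_wrapped_number_fast_path(wrapped, is_wrapped_number=True))
```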
🐛 Describe the bug
Training code
if __name__ == "__main__":
    train()
Model code
Error logs
Minified repro
No response
Versions
cc @ezyang @gchanan @zou3519 @kadeng @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @amjames @bdhirsh @msaroufim @anijain2305 @chauhang @wconstab