
[dynamo,torch_function] __torch_function__ does not respect kwargs #117971

Open
lezcano opened this issue Jan 22, 2024 · 0 comments
Labels: dynamo-tensor-subclasses, dynamo-torch-function, dynamo-triage-june2024, module: dynamo, module: __torch_function__, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

lezcano (Collaborator) commented Jan 22, 2024

🐛 Describe the bug

Found while debugging some new failures coming from #117625.

Our Dynamo implementation of `__torch_function__` does not trace the call inside the `if has_torch_function...` branch the way eager executes it. As a result, kwargs are not populated the way they are in eager. For example, in `pytorch/torch/nn/functional.py`, lines 2232 to 2243 at commit c378001:

```python
if has_torch_function_variadic(input, weight):
    return handle_torch_function(
        embedding,
        (input, weight),
        input,
        weight,
        padding_idx=padding_idx,
        max_norm=max_norm,
        norm_type=norm_type,
        scale_grad_by_freq=scale_grad_by_freq,
        sparse=sparse,
    )
```

we see that the remaining parameters are passed to `__torch_function__` as kwargs, even when the caller passes every argument positionally. When this is not reproduced, things break a few calls further down, where the implementation expects exactly two positional args:

```python
ctx.input, ctx.weight = expanded_args
```
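The mismatch can be sketched in a torch-free, simplified model. Everything below is hypothetical: `embedding_model` mimics only the argument handling of the `F.embedding` branch quoted above, `ExpandedWeightLike` stands in for a subclass whose `__torch_function__` unpacks exactly two positional args, and the second call is an assumed model of a dispatch that forwards every argument positionally instead of populating kwargs. It is not Dynamo's actual code path.

```python
from typing import Any, Callable, Dict, Optional


class ExpandedWeightLike:
    """Hypothetical stand-in for a tensor subclass overriding __torch_function__.

    Its override unpacks exactly two positional args, mimicking
    `ctx.input, ctx.weight = expanded_args` from the report.
    """

    last_kwargs: Optional[Dict[str, Any]] = None

    @classmethod
    def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
        ctx_input, ctx_weight = args  # raises ValueError if len(args) != 2
        cls.last_kwargs = dict(kwargs or {})
        return ("handled", ctx_input, ctx_weight)


def embedding_model(input, weight, padding_idx=None, max_norm=None,
                    norm_type=2.0, scale_grad_by_freq=False, sparse=False):
    """Simplified model of the F.embedding branch quoted above (no torch)."""
    overloaded = [a for a in (input, weight) if isinstance(a, ExpandedWeightLike)]
    if overloaded:  # plays the role of has_torch_function_variadic(input, weight)
        override_cls = type(overloaded[0])
        # Like the handle_torch_function call: two tensors positionally,
        # every other parameter by keyword, however the caller passed them.
        return override_cls.__torch_function__(
            embedding_model,
            (override_cls,),
            (input, weight),
            dict(padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type,
                 scale_grad_by_freq=scale_grad_by_freq, sparse=sparse),
        )
    return "plain result"


weight = ExpandedWeightLike()
call_args = ([0, 1], weight, None, None, 2.0, False, False)

# Eager-style dispatch: works even though the caller passed everything positionally.
print(embedding_model(*call_args)[0], sorted(ExpandedWeightLike.last_kwargs))

# Assumed model of a dispatch that forwards all seven args positionally
# and leaves kwargs empty: the two-value unpacking in the override fails.
try:
    ExpandedWeightLike.__torch_function__(
        embedding_model, (ExpandedWeightLike,), call_args, {}
    )
except ValueError as err:
    print("ValueError:", err)
```

The `__torch_function__(cls, func, types, args=(), kwargs=None)` signature follows the documented override protocol; the real plumbing (`handle_torch_function`, ordering across several overloaded arguments) is deliberately left out to keep the focus on how args and kwargs are split.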

cc @hameerabbasi @rgommers @ezyang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @peterbell10 @aakhundov @mlazos

Versions

master

lezcano added the `triaged` label on Jan 22, 2024.