Accuracy mismatch with torch.compile(backend="eager") for float16 #121238

Open
janeyx99 opened this issue Mar 5, 2024 · 1 comment
Labels
dynamo-triage-june2024, module: dynamo, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

janeyx99 (Contributor) commented Mar 5, 2024

🐛 Describe the bug

For float16 (the repro passes if dtype is torch.float32), the two implementations produce tangibly different results when run through dynamo, even though they match in eager. Example results are shown below.

One thing to note is that if we switch the following so that -1 is no longer passed to addcdiv (folding the sign into denom instead), the results agree again:

    denom = exp_inf * bias_correction
    param.addcdiv_(exp_avg, denom, value=-1)

to

    denom = exp_inf * -bias_correction
    param.addcdiv_(exp_avg, denom)
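These two forms are algebraically identical: negation is exact in floating point, so moving the sign from addcdiv's value into denom should not change the result. A minimal standalone sketch of that equivalence in eager (assumes CUDA; the tensor names are local to this snippet, not the repro below):

    import torch

    # Sketch: the two addcdiv formulations should agree bit-for-bit in eager.
    dtype = torch.float16
    param = torch.rand(2, 3, dtype=dtype, device="cuda")
    exp_avg = torch.rand(2, 3, dtype=dtype, device="cuda")
    exp_inf = torch.rand(2, 3, dtype=dtype, device="cuda") + 0.5  # keep denom away from 0
    bias_correction = torch.tensor(1 - 0.9 ** 2, dtype=dtype, device="cuda")

    a = param.clone().addcdiv_(exp_avg, exp_inf * bias_correction, value=-1)
    b = param.clone().addcdiv_(exp_avg, exp_inf * -bias_correction)
    assert torch.equal(a, b)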

Error logs

(pytorch-3.10) [[email protected] ~/local/pytorch (1e63ab5f)]$ python playground2.py 
The following are the same param=tensor([[ -0.8535,   0.1125,  -0.9385],
        [ -3.8652, -59.9062,  -6.5234]], device='cuda:0', dtype=torch.float16) and pc=tensor([[ -0.8535,   0.1125,  -0.9385],
        [ -3.8652, -59.9062,  -6.5234]], device='cuda:0', dtype=torch.float16)
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Traceback (most recent call last):
  File "/data/users/janeyx/pytorch/playground2.py", line 54, in <module>
    assert torch.allclose(param, pc), f"results are not the same! {param=} {pc=}"
AssertionError: results are not the same! param=tensor([[ -0.7939, -22.8125,  -4.2969],
        [ -1.7324,  -7.2891,  -4.1484]], device='cuda:0', dtype=torch.float16) pc=tensor([[ -0.7944, -22.8125,  -4.2969],
        [ -1.7314,  -7.2852,  -4.1445]], device='cuda:0', dtype=torch.float16)

Minified repro

import torch
dtype = torch.float16

beta1 = 0.9
step_t = torch.tensor(2, dtype=dtype, device="cuda")
g_step_t = [torch.tensor(2, dtype=dtype, device="cuda")]


param = torch.rand(2, 3, dtype=dtype, device="cuda") 
pc = param.clone()
exp_inf = torch.rand(2, 3, dtype=dtype, device="cuda")
exp_inf_c = exp_inf.clone()
exp_avg = torch.rand(2, 3, dtype=dtype, device="cuda")
exp_avg_c = exp_avg.clone()

def reset():
    global beta1, step_t, g_step_t, param, pc, exp_inf, exp_inf_c, exp_avg, exp_avg_c
    beta1 = 0.9
    step_t = torch.tensor(2, dtype=dtype, device="cuda")
    g_step_t = [torch.tensor(2, dtype=dtype, device="cuda")]

    param = torch.rand(2, 3, dtype=dtype, device="cuda") 
    pc = param.clone()
    exp_inf = torch.rand(2, 3, dtype=dtype, device="cuda")
    exp_inf_c = exp_inf.clone()
    exp_avg = torch.rand(2, 3, dtype=dtype, device="cuda")
    exp_avg_c = exp_avg.clone()


def forloop_capturable():
    bias_correction = 1 - beta1 ** step_t
    denom = exp_inf * bias_correction
    param.addcdiv_(exp_avg, denom, value=-1)

def foreach_capturable():
    bias_corrections = torch._foreach_pow(beta1, g_step_t)
    # foreach_sub doesn't allow a scalar as the first arg, so compute
    # beta1**step - 1, i.e. the negated bias correction; the sign is thus
    # folded into denom, which is why no value=-1 is passed to addcdiv below
    torch._foreach_sub_(bias_corrections, 1)

    denom = torch._foreach_mul([exp_inf_c], bias_corrections)
    torch._foreach_addcdiv_([pc], [exp_avg_c], denom)

# these match in eager
forloop_capturable()
foreach_capturable()
assert torch.allclose(param, pc)
print(f"The following are the same {param=} and {pc=}")

reset()

# but not in compile
torch.compile(forloop_capturable, backend="eager")()
torch.compile(foreach_capturable, backend="eager")()
assert torch.allclose(param, pc), f"results are not the same! {param=} {pc=}"

Versions

main

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

@williamwen42 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and module: dynamo labels on Mar 5, 2024
mlazos (Contributor) commented Mar 6, 2024

@janeyx99 can you try comparing to float64 and see which one is closer? That will determine whether this is an actual bug or whether torch.compile simply has higher precision than eager. The effect is especially pronounced in bfloat16 because eager ops upcast to float32, perform the op, and then downcast; in torch.compile, since we generate a single kernel, this casting only happens once.
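A minimal sketch of that kind of check, running the repro's update in eager and compiled float16 against a float64 reference (illustrative only; assumes CUDA):

    import torch

    dtype = torch.float16
    beta1 = 0.9
    step_t = torch.tensor(2, dtype=dtype, device="cuda")
    param = torch.rand(2, 3, dtype=dtype, device="cuda")
    exp_avg = torch.rand(2, 3, dtype=dtype, device="cuda")
    exp_inf = torch.rand(2, 3, dtype=dtype, device="cuda") + 0.5  # keep denom away from 0

    # float64 reference for the same update
    ref = param.double().addcdiv(
        exp_avg.double(),
        exp_inf.double() * (1 - beta1 ** step_t.double()),
        value=-1,
    )

    def update(p, avg, inf, step):
        p.addcdiv_(avg, inf * (1 - beta1 ** step), value=-1)

    p_eager, p_compiled = param.clone(), param.clone()
    update(p_eager, exp_avg, exp_inf, step_t)
    torch.compile(update, backend="eager")(p_compiled, exp_avg, exp_inf, step_t)

    # whichever float16 result is closer to the float64 reference is the more accurate one
    print("eager    max err:", (p_eager.double() - ref).abs().max().item())
    print("compiled max err:", (p_compiled.double() - ref).abs().max().item())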

pytorchmergebot pushed a commit that referenced this issue Mar 7, 2024
Finishes the work started in #118697. Thanks @MarouaneMaatouk for the attempt, but due to inactivity I have opened this PR for Adamax. Note that the new capturable implementation is much simpler and I've modified the foreach capturable impl--it now calls fewer kernels and is more easily comparable to forloop.

Next steps:
* This PR discovered two bugs: #121178 and #121238.
* Move the now hefty graph optim tests in test_cuda to use OptimInfo.

Pull Request resolved: #121183
Approved by: https://github.com/albanD