
[aota] Needs autograd if an input requires_grad, agnostic to enable_grad #128890

Closed
wants to merge 12 commits

Conversation


@IvanKobzarev (Contributor) commented on Jun 17, 2024

Stack from ghstack (oldest at bottom):

Original issue: #114338

Reland of: #128016

Summary from previous PR:
We assume only two possible, mutually exclusive scenarios:

1. Running the compiled region for training (any input has requires_grad): produced differentiable outputs should have requires_grad.

2. Running the compiled region for inference (no input has requires_grad): no output has requires_grad.

Even if the user runs the region under no_grad(), an input Tensor with requires_grad puts us in training scenario (1).

In the current state that means:
1/ needs_autograd should not check torch.is_grad_enabled(), only whether any input has requires_grad
2/ if needs_autograd => trace_joint (we are in training scenario 1) => always run the compiled region under torch.enable_grad()
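
A minimal sketch of that dispatch decision (hypothetical helper names, not the actual AOTAutograd code):

```python
import torch

def needs_autograd(flat_args) -> bool:
    # Deliberately do NOT consult torch.is_grad_enabled(); only the inputs matter.
    return any(isinstance(a, torch.Tensor) and a.requires_grad for a in flat_args)

def run_compiled_region(compiled_fn, flat_args):
    if needs_autograd(flat_args):
        # Training scenario (1): always execute with grad enabled,
        # even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            return compiled_fn(*flat_args)
    # Inference scenario (2): outputs will not require grad.
    return compiled_fn(*flat_args)
```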

Changes in the partitioner?

Inference and training graphs used to differ in their return container (list vs. tuple).
The partitioner changes unify this so that a tuple is always returned.
As a result, some expected graph contents in test_aotdispatch.py change from list to tuple (see the illustration below).
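
As a toy illustration of the return container the expecttests capture (this traces an ordinary function with make_fx and is not the partitioner code itself):

```python
# Toy illustration only: FX graphs return their outputs in a container;
# after this PR both inference and training graphs use a tuple.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x + 1, x * 2

gm = make_fx(f)(torch.randn(3))
print(gm.code)  # the final line is a tuple return, e.g. `return (add, mul)`
```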

Why was it reverted?

There was an inference regression for the hf_Reformer model:

```
TORCHINDUCTOR_FX_GRAPH_CACHE=0 python benchmarks/dynamo/torchbench.py --performance --inference --bfloat16 --backend inductor --device cuda --only hf_Reformer --cold-start-latency --use-eval-mode
```

One of the compiled graphs contained outputs that are aliases of inputs which are nn.Parameter(requires_grad=True).

Even though the torchbench inference benchmarks run inside torch.no_grad(), alias ops (specifically expand, in hf_Reformer's case) preserve requires_grad.

As a result, we started compiling a training graph instead of an inference graph.
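
A small repro of the aliasing behavior described above (plain PyTorch, no compilation involved):

```python
import torch

p = torch.nn.Parameter(torch.randn(2, 3))  # requires_grad=True

with torch.no_grad():
    out = p.expand(4, 2, 3)  # expand is a view/alias op

# Views/aliases keep the base tensor's requires_grad even under no_grad(),
# so this output looks like a "training" output to the old check.
print(out.requires_grad)  # True
```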

Fix for view ops:

If outputs are aliases of inputs that require grad, their requires_grad is not by itself a reason to generate a training graph.

This is handled in aot_autograd.py, where output_and_mutation_safe is computed.
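
A hedged sketch of that filtering; the field names below are illustrative, the real logic lives in aot_autograd.py where output_and_mutation_safe is computed:

```python
# Illustrative only: ignore outputs that merely alias an input when deciding
# whether requires_grad on an output forces a training graph. AOTAutograd
# regenerates such view outputs at runtime via view-replay, which sets their
# grad_fn correctly, so they can safely be left out of this decision.
def any_output_forces_training_graph(output_info) -> bool:
    return any(
        o.requires_grad and not o.is_alias_of_input  # hypothetical flag
        for o in output_info
    )
```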

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse


pytorch-bot bot commented Jun 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128890

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 9c5de20 with merge base e6cddc9:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

IvanKobzarev added a commit that referenced this pull request Jun 17, 2024
ghstack-source-id: 35d40650db81040f8eca0476194ffc7f918be565
Pull Request resolved: #128890

ezyang commented Jun 17, 2024

I had an instinctive dislike of the approach taken here, but after thinking about it I couldn't figure out any other way to do it lol. We should make sure we have hit all the paths that actually need enable_grad, and I guess we need to make sure there aren't funny no_grad / view interactions (cc @albanD). This could be right, but a more detailed case analysis of the various situations would make me feel more confident about it.

@IvanKobzarev added the `topic: not user facing` label on Jun 20, 2024
@IvanKobzarev (Contributor, Author):

About funny view interactions:

Just found that view-type ops preserve requires_grad even under no_grad. That caused one of the torchbench models to start generating a training graph under no_grad, since it had outputs that were torch.expand of nn.Parameter(requires_grad=True) inputs :)

```
@@ -346,6 +346,8 @@ def load_model(
model.train()
else:
model.eval()
model.requires_grad_(False)
```
Contributor:

hmmm i'm not convinced we want to do this... since this will just paper over any bugs that we have. If torchbench gets bad inference performance unless we manually set all parameters to not require grad, it should show up as a regression so we can actually diagnose and fix it.

Contributor Author:

Oh, sorry, forgot to remove this. Will clean it.

@bdhirsh bdhirsh requested a review from a team June 20, 2024 15:11
IvanKobzarev added a commit that referenced this pull request Jun 20, 2024
ghstack-source-id: a5a0e9110ea3697c672eeba9bf13959b6b52ba90
Pull Request resolved: #128890
```
x.requires_grad for x in fw_metadata.output_info
x.requires_grad
# view operations preserve requires_grad even in no_grad.
# Do not count aliases of inputs with requires_grad as reason to make a training graph.
```
Contributor:

nit: I would mention in the comment that AOTAutograd will perform view-replay to regenerate the view outputs at runtime, setting their grad_fn properly.

Otherwise, it's not clear just from the comment why it's OK to "ignore" views that require grad when deciding that we can generate an inference graph here.

Contributor Author:

Thanks, added.

IvanKobzarev added a commit that referenced this pull request Jun 24, 2024
ghstack-source-id: d435aee0988324af359380c6b360ba0b1c8aa10d
Pull Request resolved: #128890

bdhirsh commented Jun 24, 2024

chatted offline, but failures are because you need to update the expecttests. you can do so with EXPECTTEST_ACCEPT=1

IvanKobzarev added a commit that referenced this pull request Jun 24, 2024
ghstack-source-id: 945f8691c094e78de2c8797a259f6ca5b3259cf8
Pull Request resolved: #128890
@bdhirsh bdhirsh requested a review from a team June 24, 2024 23:15
IvanKobzarev added a commit that referenced this pull request Jul 9, 2024
ghstack-source-id: 3ecaa4defa64e4917eb821d6645d5988b1ba6270
Pull Request resolved: #128890

zou3519 commented Jul 18, 2024

Seems to have broken tests on trunk. Example logs: https://ossci-raw-job-status.s3.amazonaws.com/log/27619471188


zou3519 commented Jul 18, 2024

@pytorchbot help


pytorch-bot bot commented Jul 18, 2024

❌ 🤖 pytorchbot command failed:

```
@pytorchbot: error: argument command: invalid choice: 'help' (choose from 'merge', 'revert', 'rebase', 'label', 'drci', 'cherry-pick', 'close')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick,close} ...

Try @pytorchbot --help for more info.
```


zou3519 commented Jul 18, 2024

@pytorchbot revert -m "broke trunk tests, probably a landrace" -c landrace

@pytorchmergebot (Collaborator):

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator):

@IvanKobzarev your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Jul 18, 2024
…enable_grad (#128890)"

This reverts commit e98135d.

Reverted #128890 on behalf of https://github.com/zou3519 due to broke trunk tests, probably a landrace ([comment](#128890 (comment)))
@IvanKobzarev added the `ciflow/periodic` label on Jul 18, 2024
IvanKobzarev added a commit that referenced this pull request Jul 18, 2024
ghstack-source-id: 0c9c536017361d0e14837b7b3abd8ad9b15f8f44
Pull Request resolved: #128890
DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024
…rad (pytorch#128890)

DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024
…enable_grad (pytorch#128890)"

bigfootjon pushed a commit that referenced this pull request Jul 24, 2024
…rad (#128890)


(cherry picked from commit e98135d)
bigfootjon pushed a commit that referenced this pull request Jul 24, 2024
…enable_grad (#128890)"

(cherry picked from commit 120fdf7)
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…rad (pytorch#128890)

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…rad (pytorch#128890)

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…enable_grad (pytorch#128890)"

IvanKobzarev added a commit that referenced this pull request Jul 30, 2024
ghstack-source-id: 8c1bb666dd096a59f00f85f675f2af509b9ef94b
Pull Request resolved: #128890
@IvanKobzarev (Contributor, Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels: ci-no-td (Do not run TD on this PR), ciflow/inductor, ciflow/periodic (Trigger jobs ran periodically on master (periodic.yml) on the PR), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: dynamo, module: inductor, oncall: distributed (Add this issue/PR to distributed oncall triage queue), release notes: AO frontend, Reverted, topic: not user facing
6 participants