
Move Memory Allocation for Autotuning out of the critical path #129258

Closed
shunting314 opened this issue Jun 21, 2024 · 8 comments
Assignees
Labels
module: inductor, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@shunting314
Contributor

shunting314 commented Jun 21, 2024

🐛 Describe the bug

While working on #129043 (comment), we found the memory saving for llm.c is
much larger than expected (we expected about a 6GB saving but actually saved about 15GB). It turns out the baseline uses more memory than expected due to autotuning happening during the first iteration.

[Screenshot attached in the original issue: 2024-06-21, 11:56 AM]

The kernel does mutation, so autotuning needs to allocate a clone of the mutated argument. Although this specific case has been fixed by the aforementioned PR, since we convert the Scatter to a Pointwise kernel, the general problem still exists.

I think one solution is to run autotuning for all of those mutating kernels right after we codegen the wrapper. This is a good time to autotune because we have not yet started allocating memory for activations, gradients, optimizer state, etc.
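
To make the clone concrete: below is a minimal, hypothetical sketch (not Inductor's actual autotuner code; `benchmark_mutating_kernel` and its arguments are made up for illustration) of why timing an in-place kernel forces an extra device allocation: the kernel cannot be run repeatedly on the caller's real buffers without corrupting them.

```python
import torch

def benchmark_mutating_kernel(kernel, args, mutated_indices, n_iters=10):
    # Timing an in-place kernel repeatedly would keep mutating the caller's
    # tensors, so the mutated arguments are cloned first. These clones are
    # the extra allocations that inflate peak memory during autotuning.
    bench_args = [
        a.clone() if i in mutated_indices else a for i, a in enumerate(args)
    ]
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(n_iters):
        kernel(*bench_args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_iters  # average milliseconds per call
```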

cc @ezyang @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @jansel @Chillee @eellison

Error logs

No response

Minified repro

No response

Versions

.


@Chillee
Contributor

Chillee commented Jun 21, 2024

@shunting314 But for the backwards pass, right after we codegen the wrapper aren't we at the peak memory location?

@shunting314
Contributor Author

shunting314 commented Jun 21, 2024

But for the backwards pass, right after we codegen the wrapper aren't we at the peak memory location?

Good point. We probably also need to 'eagerly' generate the backward graph. If we generate the bwd wrapper before we run the fwd wrapper, we can still avoid increasing peak memory due to autotuning.

shunting314 added a commit that referenced this issue Jun 24, 2024
…rse matrix in CE bwd"


Inductor currently materializes a large sparse matrix in the backward pass for CrossEntropyLoss and loads it to compute gradients of the Softmax input. If we could fuse the sparse matrix computation into its consumers, we would get both perf and memory-usage wins.

The FX graph snippet that constructs this sparse matrix looks like:
```
       full_default_3: "bf16[32768, 50257]" = torch.ops.aten.full.default([32768, 50257], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
       scatter: "bf16[32768, 50257]" = torch.ops.aten.scatter.value(full_default_3, 1, where_2, -1.0);  full_default_3 = where_2 = None
```
Leveraging the following observations:
- the scatter is applied to an all-zero (or, more generally, a constant) tensor
- the index tensor for the scatter has a single element along the scatter dimension (in this case it is the label tensor)

allows us to lower this 'scatter_upon_const_tensor' pattern to a pointwise kernel that can be easily fused with downstream kernels:

```
    def inner_fn(idx):
        selector_idx = list(idx)
        selector_idx[dim] = 0  # can do this since the index tensor has a single element on the scatter dimension

        selector = selector_loader(selector_idx)
        return ops.where(
            selector == ops.index_expr(idx[dim], torch.int64),
            ops.constant(val, dtype),
            ops.constant(background_val, dtype),
        )
```
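
To make the equivalence concrete, here is a small PyTorch-level illustration (not the Inductor lowering itself; the shapes are scaled down) showing that scattering -1.0 into an all-zero tensor at the label positions computes the same thing as a pointwise `where` against the label:

```python
import torch

B, V = 8, 17  # small stand-ins for the [32768, 50257] shapes above
labels = torch.randint(0, V, (B, 1))

# Original pattern: materialize a zero tensor, then scatter -1.0 at the labels.
dense = torch.zeros(B, V, dtype=torch.bfloat16).scatter(1, labels, -1.0)

# Pointwise formulation: each element is computed directly from the label,
# so no [B, V] intermediate has to be materialized ahead of the consumer.
cols = torch.arange(V)
pointwise = torch.where(
    cols == labels,
    torch.tensor(-1.0, dtype=torch.bfloat16),
    torch.tensor(0.0, dtype=torch.bfloat16),
)

assert torch.equal(dense, pointwise)
```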

## Test result on microbenchmark

For the microbenchmark added as `test_cross_entropy_loss`, we improve latency from 47.340ms to 42.768ms and memory footprint from 10.524GB to 7.227GB on A100 (on H100, we improve latency from 27.54ms to 23.51ms and memory footprint from 10.574GB to 7.354GB).

The saving matches the back-of-envelope calculation. We avoid storing a BF16 tensor of shape [~30K, ~50K], which is about 3GB. On A100, avoiding one store and one load of such a tensor saves roughly 3GB x 2 / 1.5TB/s ≈ 4ms.
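
As a sanity check on that estimate (the ~1.5 TB/s figure is an assumed approximate A100 HBM bandwidth):

```python
# Rough arithmetic behind the numbers above.
elems = 32768 * 50257            # tensor shape from the FX graph
tensor_bytes = elems * 2         # bf16 is 2 bytes per element
traffic = tensor_bytes * 2       # one store plus one load of the tensor
bandwidth = 1.5e12               # ~1.5 TB/s A100 HBM bandwidth (assumed)

print(tensor_bytes / 1e9)        # ~3.29 GB, the "about 3GB" above
print(traffic / bandwidth * 1e3) # ~4.4 ms, in line with the ~4ms estimate
```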

## Test result on llm.c

We also tested this on llm.c, and the saving is much larger, especially for memory footprint. The reason is that autotuning allocates extra memory for benchmarking (see #129258 and #129399 for more details).

For the llm.c PyTorch implementation on A100, we improve from 171K tokens/s with 33.6GB peak memory to 180K tokens/s with 18.6GB peak memory (a **45%** reduction in peak memory).

## Test on PyTorch 2.0 Dashboard

The optimization is quite general especially for transformers. We tested this on PyTorch2.0 dashboard. Here is the [result](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2017%20Jun%202024%2018%3A07%3A51%20GMT&stopTime=Mon%2C%2024%20Jun%202024%2018%3A07%3A51%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&lBranch=gh/shunting314/158/head&lCommit=c62c55e29c65497d495217b6574bb36b0c4da7d4&rBranch=main&rCommit=0d25f096c1beaf8749932a3d6083ad653405ed71).

TL;DR: for the Hugging Face benchmark suite, we get a **6%** geomean perf improvement and a **10%** geomean memory footprint improvement.

cc jansel Chillee eellison 


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

shunting314 added further commits that referenced this issue Jun 24–25, 2024 (same commit message as above).
@masnesral added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and removed the triage review label Jun 25, 2024
@bdhirsh
Contributor

bdhirsh commented Jun 25, 2024

One thing I'm worried about is that there are some reasons it might be a good thing to delay backward compilation until runtime.

The big example is when we have tangents that are subclasses, and we need to wait until runtime backward to know their proper metadata so we can get a correct backward graph. This doesn't exist today, but I was hoping to find some time to work on it in H2.

I'd be curious what @Chillee thinks. Just to throw out alternatives: what do you think of something like cloning the kernel's inputs onto CPU, and directly copying them from CPU back into the inputs (without needing extra CUDA memory)? I'm not sure if there are other cases where Inductor compilation needs extra memory, though.
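
A minimal sketch of that clone-to-CPU idea, assuming a hypothetical `preserve_on_cpu` helper (not an existing Inductor or PyTorch API): back up the mutated buffers on the host, benchmark against the real GPU buffers, then restore them, so no extra CUDA memory is needed.

```python
import torch
from contextlib import contextmanager

@contextmanager
def preserve_on_cpu(tensors):
    # Stash the original contents on the host so benchmarking can freely
    # mutate the GPU buffers; restore them afterwards. Peak CUDA memory is
    # unchanged, at the cost of two host<->device copies per tensor.
    backups = [t.detach().to("cpu", copy=True) for t in tensors]
    try:
        yield
    finally:
        for t, backup in zip(tensors, backups):
            t.copy_(backup)

# Hypothetical usage while timing a mutating kernel:
# with preserve_on_cpu([out_buf]):
#     benchmark(kernel, out_buf, *other_args)
```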

pytorchmergebot pushed a commit that referenced this issue Jun 25, 2024
Pull Request resolved: #129043
Approved by: https://github.com/jansel, https://github.com/Chillee
@eellison
Contributor

Maybe the simplest short-term fix is to copy the input to CPU if cloning it would increase the high-water memory mark.

Note: we would need #129496 to calculate this.
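
The memory calculation referenced here would live in the scheduler at compile time (see the later comment). Purely to illustrate the decision itself, a runtime approximation using the CUDA caching allocator's statistics might look like the sketch below (`should_clone_on_gpu` is hypothetical):

```python
import torch

def should_clone_on_gpu(arg: torch.Tensor) -> bool:
    # Hypothetical heuristic: clone on the GPU only if the clone fits under
    # the allocator's current high-water mark, i.e. it would not raise peak
    # memory; otherwise fall back to staging the argument via the CPU.
    clone_bytes = arg.untyped_storage().nbytes()
    allocated = torch.cuda.memory_allocated(arg.device)
    peak = torch.cuda.max_memory_allocated(arg.device)
    return allocated + clone_bytes <= peak
```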

@ezyang
Contributor

ezyang commented Jun 28, 2024

Also, it's easy to end up in situations where you already have a lot of memory allocated and then you are compiling a new region (e.g., compiled autograd), so it's good to figure out how to make sure the compiler in general doesn't use more CUDA memory.

@eellison
Contributor

Also, this particular mutating kernel (scatter_) is idempotent, so we don't really even need to clone.
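
A small illustration of that point: since the kernel writes a constant at fixed index positions, re-running it on the already-mutated buffer produces the same result, so repeated benchmark runs would not need a cloned copy.

```python
import torch

buf = torch.zeros(4, 8)
idx = torch.tensor([[2], [5], [0], [7]])

buf.scatter_(1, idx, -1.0)         # first (real) run
snapshot = buf.clone()
buf.scatter_(1, idx, -1.0)         # re-run with the same inputs, as a
                                   # benchmark iteration would
assert torch.equal(buf, snapshot)  # buffer state is unchanged
```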

@eellison
Contributor

eellison commented Jul 11, 2024

For some more pointers:

In codegen, we annotate the generated Triton kernel with "mutated_arg_names":

"mutated_arg_names": mutated_args,

Those are then read in the runtime autotuner: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/runtime/triton_heuristics.py#L677

We'd also want to annotate the generated Triton kernel with something like "clone_to_cpu" in cases where clones would increase peak memory.

Separately, we'd need to calculate memory in scheduler.py. We would need #129496 for it to be accurate, and we'd do it after self.compute_last_usage.
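
Putting those pointers together, a rough sketch of how the runtime autotuner could consume such annotations (the "mutated_arg_names" field is the one quoted above; the "clone_to_cpu" flag and the helper names are hypothetical):

```python
import torch

def clone_preserving_strides(x: torch.Tensor) -> torch.Tensor:
    # Benchmarking copy that keeps the original layout of the argument.
    out = torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
    return out.copy_(x)

def prepare_bench_args(inductor_meta: dict, arg_names, args):
    # Sketch only: decide, per mutated argument, whether to clone it on the
    # GPU or to keep a CPU backup and benchmark against the real buffer.
    mutated = set(inductor_meta.get("mutated_arg_names", []))
    stage_on_cpu = inductor_meta.get("clone_to_cpu", False)  # hypothetical flag
    bench_args, restore = [], []
    for name, arg in zip(arg_names, args):
        if isinstance(arg, torch.Tensor) and name in mutated:
            if stage_on_cpu:
                # Back up on the host and reuse the real buffer, so the
                # benchmark does not raise peak GPU memory.
                restore.append((arg, arg.detach().to("cpu", copy=True)))
                bench_args.append(arg)
            else:
                bench_args.append(clone_preserving_strides(arg))
        else:
            bench_args.append(arg)
    return bench_args, restore  # caller copies the backups back after timing
```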

@masnesral masnesral self-assigned this Aug 13, 2024
@eellison
Contributor

eellison commented Oct 7, 2024

Fixed with #136701

@eellison eellison closed this as completed Oct 7, 2024