ZeRO-Inference refresh #4197

Merged
113 commits merged on Sep 11, 2023
Changes from 1 commit

Commits (113)
ade9096
INT4 weight only quantization (#479)
donglinz May 5, 2023
2461449
Moving quantization into post_init_method and add int4 dequantization…
donglinz May 17, 2023
8751edf
Refactor: move int4 code to deepspeed/inference (#528)
donglinz Jun 5, 2023
df1859d
zero++ tutorial PR (#3783)
HeyangQin Jun 21, 2023
d81a6ad
[Fix] _conv_flops_compute when padding is a str and stride=1 (#3169)
zhiruiluo Jun 21, 2023
a8c182a
fix interpolate flops compute (#3782)
cli99 Jun 22, 2023
c4c442f
use `Flops Profiler` to test `model.generate()` (#2515)
CaffreyR Jun 22, 2023
fc9e1ee
revert PR #3611 (#3786)
jeffra Jun 22, 2023
40045dc
bump to 0.9.6
jeffra Jun 22, 2023
49a0a1b
ZeRO++ chinese blog (#3793)
HeyangQin Jun 23, 2023
2c62cb4
remove staging trigger (#3792)
jeffra Jun 23, 2023
4dc65f7
DeepSpeed-Triton for Inference (#3748)
stephen-youn Jun 23, 2023
e1119d8
ZeRO++ (#3784)
HeyangQin Jun 23, 2023
01b843a
adding zero++ to navigation panel of deepspeed.ai (#3796)
HeyangQin Jun 23, 2023
319b64e
Add ZeRO++ Japanese blog (#3797)
tohtana Jun 23, 2023
b4a2c0a
Bug Fixes for autotuner and flops profiler (#1880)
cli99 Jun 23, 2023
b7e1010
Missing strided copy for gated MLP (#3788)
cmikeh2 Jun 23, 2023
e5b1ead
Requires grad checking. (#3789)
jomayeri Jun 23, 2023
9c756cf
bump to 0.10.0
jeffra Jun 23, 2023
a204edc
Fix Bug in transform.cu (#3534)
rraminen Jun 23, 2023
f6e2e38
bug fix: triton importing error (#3799)
stephen-youn Jun 23, 2023
c1a7d3c
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 23, 2023
65ed548
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 24, 2023
d7ac329
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 26, 2023
83f1102
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 27, 2023
16555b2
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 27, 2023
9d7b654
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 28, 2023
c121f90
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 29, 2023
f6b2962
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 29, 2023
dd6bb04
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 29, 2023
6e5a1f1
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 29, 2023
1fbbbbf
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 30, 2023
e44eb86
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jun 30, 2023
26d8823
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 1, 2023
83e1752
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 3, 2023
b5446a2
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 4, 2023
7f9b2fa
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 5, 2023
9fb79a3
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 5, 2023
9643da2
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 6, 2023
464f99a
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 6, 2023
fbf5068
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 6, 2023
208870a
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 6, 2023
62b47f3
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 6, 2023
c5f62c3
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 7, 2023
78528ae
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 7, 2023
b52c407
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 8, 2023
dfe3b82
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 10, 2023
f086c39
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 10, 2023
04718e4
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 12, 2023
63db286
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 12, 2023
f32c947
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 13, 2023
3b7c583
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 13, 2023
7963bc7
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 13, 2023
441fffe
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 14, 2023
72f37ab
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 14, 2023
f5eb5df
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 17, 2023
e595621
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 18, 2023
427d9eb
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 19, 2023
e9708d6
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 19, 2023
6da3e48
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 19, 2023
c031179
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 19, 2023
0de04b3
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 20, 2023
a733676
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 20, 2023
fd2ca3a
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 21, 2023
9665c46
Rebase
tjruwase Jul 21, 2023
af181e5
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 21, 2023
68935e8
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 22, 2023
37b7743
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 22, 2023
b8153eb
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 24, 2023
8b9815f
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 25, 2023
79781ef
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 25, 2023
2b3664c
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 25, 2023
f6ded65
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 26, 2023
cea2dd9
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 26, 2023
30626f0
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 27, 2023
ccb6817
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 27, 2023
a20815b
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 28, 2023
2d8c49a
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 28, 2023
762a1bb
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Jul 31, 2023
8ac993f
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 1, 2023
9d054da
Workaround qaunt bug
tjruwase Aug 2, 2023
fabf2c0
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 3, 2023
1bc1af2
Fix dequant bug
tjruwase Aug 3, 2023
cc0b0f1
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 3, 2023
fd1cede
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 5, 2023
f694a93
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 7, 2023
4d29d5e
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 8, 2023
636e5e4
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 9, 2023
a39b6e2
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 9, 2023
de64c54
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 9, 2023
cf479bf
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 9, 2023
df4f25c
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 9, 2023
38bb552
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 10, 2023
e86e1c5
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 11, 2023
ec59340
Address PR feedback
tjruwase Aug 11, 2023
8e19577
Use super() __exit__
tjruwase Aug 14, 2023
704e0f0
Fix unit tests
tjruwase Aug 14, 2023
95643e3
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 15, 2023
19018b6
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 16, 2023
b886682
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 16, 2023
7948971
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 17, 2023
9242d36
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 17, 2023
c51a072
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 18, 2023
c816d50
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 19, 2023
d9b1672
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 21, 2023
a746aca
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 22, 2023
f0afcf3
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 22, 2023
956ed2f
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 22, 2023
e1276ab
Merge branch 'master' of github.com:microsoft/DeepSpeed
jeffra Aug 23, 2023
f940e1e
Rebase
tjruwase Aug 23, 2023
0129db2
Fix rebase conflict
tjruwase Aug 24, 2023
a63e92b
Merge branch 'master' into staging-zero-inference-v1
tjruwase Aug 30, 2023
9723d93
Merge branch 'master' into staging-zero-inference-v1
awan-10 Sep 8, 2023
DeepSpeed-Triton for Inference (#3748)
Co-authored-by: Stephen Youn <[email protected]>
Co-authored-by: Arash Bakhtiari <[email protected]>
Co-authored-by: Cheng Li <[email protected]>
Co-authored-by: Ethan Doe <[email protected]>
Co-authored-by: yidoe <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
7 people committed Jun 23, 2023
commit 4dc65f7b99f57c3d6e93221011ef5e36dd48a966
2 changes: 1 addition & 1 deletion .github/workflows/formatting.yml
@@ -27,7 +27,7 @@ jobs:

- name: Install deepspeed
run: |
pip install .[dev,autotuning]
pip install .[dev,autotuning,triton]
ds_report

- name: Formatting checks
Binary file added blogs/assets/images/triton-bert-base-latency.png
Binary file added blogs/assets/images/triton-bert-large-latency.png
95 changes: 95 additions & 0 deletions blogs/deepspeed-triton/README.md
@@ -0,0 +1,95 @@
# DeepSpeed with Triton compiler

# 1. Overview

We have integrated [Triton](https://github.com/openai/triton), an open source compiler for GPU programming, into DeepSpeed, which further boosts the inference speed of BERT-like models in float16 precision.
By replacing some CUDA kernels or torch operators with Triton kernels, we achieved 1.14\~1.68x speedup (or 12\~41% latency reduction) for different models and GPUs, as shown in Table 1.

<div align="center">

| Hardware | Bert-base | Bert-large | Roberta-base | Roberta-large |
|----------|:------:|:------:|:------:|:------:|
| A100 |1.65x | 1.68x | 1.53x | 1.61x |
| V100 | 1.29x | 1.14x | 1.23x | 1.21x |

Table 1. The average speedup (see NOTE below for more detail)


</div>

For transformer operators in float16, we have implemented kernels in the Triton language that replace the ordinary CUDA kernels or torch operators.
The Triton kernels we implemented include softmax, layer-normalization, residual-addition, and all the matrix multiplications except the MLP layers (see NOTE below for details).
In our experiments, the Triton kernels reduce the average latency (over different sequence lengths) by 6\~24% (depending on model and hardware) compared to the latency with CUDA-only kernels.
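For a flavor of what a kernel in the Triton language looks like, here is a minimal residual-addition kernel. It is only an illustrative sketch of the Triton programming model (the kernel name and block size are ours), not one of the kernels shipped in DeepSpeed.

```python
# Illustrative sketch only: a minimal Triton residual-add kernel, not the DeepSpeed implementation.
import torch
import triton
import triton.language as tl


@triton.jit
def residual_add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)                   # one program instance per block of elements
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements                   # guard the tail of the tensor
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def residual_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    residual_add_kernel[grid](x, y, out, x.numel(), BLOCK=1024)
    return out
```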


The figures below show the latency reduction in more detail.
Figure 1 visualizes the latency reduction across different sequence lengths on an A100 GPU for the Bert-base model.
The baseline (blue) is Huggingface Transformers without any kernel injection, the orange is DeepSpeed with CUDA-only kernels, and the gray is DeepSpeed with Triton kernels.
Figure 2 shows the same plot for the Bert-large model on an A100 GPU.

<div align="center">

<img src="../assets/images/triton-bert-base-latency.png" width="500px" alt="triton-bert-base-latency"/>

*Figure 1: Normalized P90 latency for the Bert-base model on an A100 GPU across different sequence lengths*

<img src="../assets/images/triton-bert-large-latency.png" width="500px" alt="triton-bert-large-latency"/>

*Figure 2: Normalized P90 latency for the Bert-large model on an A100 GPU across different sequence lengths*

</div>


Next, we dive deeper into this new feature in DeepSpeed.

# 2. How to use Triton in DeepSpeed

You can enable the Triton kernels by setting the `use_triton` flag in the DeepSpeed inference config, for example when calling `deepspeed.init_inference`:

```python
import torch
import deepspeed
from transformers import pipeline

pipe = pipeline('fill-mask', model='bert-base-cased', framework='pt', device=0)
pipe.model = deepspeed.init_inference(pipe.model,
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=True,
                                      enable_cuda_graph=True,
                                      use_triton=True,
                                      triton_autotune=True,
                                      max_out_tokens=pipe.tokenizer.model_max_length)
```
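
After initialization, the pipeline is used as usual. The query below is only an illustration and is not part of the original snippet:

```python
# Illustrative fill-mask query; any prompt with a [MASK] token works the same way.
outputs = pipe("Paris is the [MASK] of France.")
print(outputs[0]["token_str"], outputs[0]["score"])
```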


## Running BERT inference with Triton kernels

We use Bert-base as an example here.

```bash
pip install deepspeed[triton]

git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/inference/huggingface/fill-mask

deepspeed --num_gpus 1 test-bert.py --triton
```

To run a performance benchmark, you can use the following command:

```bash
pip install deepspeed[triton]

git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/benchmarks/inference

deepspeed --num_gpus 1 triton-bert-benchmark.py --model bert-base-cased --dtype fp16 --kernel-inject --deepspeed --graphs --triton
```

# NOTE
<!-- **_NOTE:_** -->
* For more information on how to use DeepSpeed, please visit our [GitHub Page](https://github.com/microsoft/DeepSpeedExamples) and our [website](https://www.deepspeed.ai/), where you can find blog posts, tutorials, and documentation.

* This feature is currently supported only for BERT, Roberta, and other BERT-like models; text-generation models are not yet supported.

* To achieve the best performance with the Triton optimization, you need to activate CUDA graphs and `triton_autotune` in the DeepSpeed config. CUDA graphs avoid the overhead of Triton's JIT compilation and deep call stack. `triton_autotune` runs an initial pass to find the most suitable parameters for the Triton kernels, which may take some time.

* We used [Triton 2.0.0.post1 release](https://pypi.org/project/triton/2.0.0.post1/) in our experiments.

* In our experiments, we used a batch size of 1, a sequence length range of 8 to 512, and a `fill-mask` task. Table 1 shows the average P90 latency over the entire sequence length range, while Figures 1 and 2 show the P90 latency for specific sub-ranges. The baseline is Huggingface Transformers without any optimization. The speedup is calculated as (baseline P90 latency)/(DeepSpeed-Triton P90 latency); the sketch after this list illustrates the calculation. We found that the CUDA kernel in the MLP performed better than the Triton kernel in our experiments, so we use a hybrid approach that combines both kernels when Triton is enabled in the DeepSpeed config.
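
To make the metric concrete, here is a small sketch of that calculation with made-up timing samples (this is not the benchmark script from DeepSpeedExamples):

```python
import numpy as np

def p90_ms(samples_ms):
    """P90 latency over per-request latencies in milliseconds."""
    return float(np.percentile(samples_ms, 90))

# Hypothetical timing samples for one sequence length.
baseline_ms = [4.1, 4.3, 4.2, 4.8, 4.5, 4.4]   # Huggingface baseline
triton_ms   = [2.6, 2.7, 2.5, 2.9, 2.8, 2.6]   # DeepSpeed with Triton kernels

speedup = p90_ms(baseline_ms) / p90_ms(triton_ms)
print(f"speedup = {speedup:.2f}x")
```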
6 changes: 6 additions & 0 deletions deepspeed/__init__.py
@@ -12,6 +12,12 @@
from torch.optim.lr_scheduler import _LRScheduler
from packaging import version as pkg_version

try:
import triton # noqa: F401
HAS_TRITON = True
except ImportError:
HAS_TRITON = False

from . import ops
from . import module_inject

19 changes: 19 additions & 0 deletions deepspeed/inference/config.py
@@ -4,6 +4,7 @@
# DeepSpeed Team

import torch
import deepspeed
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from pydantic import Field
@@ -152,6 +153,18 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
can run faster using the graph replay method.
"""

use_triton: bool = False
"""
Use this flag to use triton kernels for inference ops.
"""

triton_autotune: bool = False
"""
Use this flag to enable triton autotuning.
Turning it on is better for performance but increase the 1st runtime for
autotuning.
"""

zero: DeepSpeedZeroConfig = {}
"""
ZeRO configuration to use with the Inference Engine. Expects a dictionary
@@ -279,6 +292,12 @@ def moe_backward_compat(cls, field_value, values):
return DeepSpeedMoEConfig(moe=field_value)
return field_value

@validator("use_triton")
def has_triton(cls, field_value, values):
if field_value and not deepspeed.HAS_TRITON:
raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
return field_value

class Config:
# Get the str representation of the datatype for serialization
json_encoders = {torch.dtype: lambda x: str(x)}
1 change: 1 addition & 0 deletions deepspeed/inference/engine.py
@@ -607,6 +607,7 @@ def forward(self, *inputs, **kwargs):
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)

else:
outputs = self.module(*inputs, **kwargs)

22 changes: 18 additions & 4 deletions deepspeed/model_implementations/transformers/ds_transformer.py
@@ -12,6 +12,10 @@
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
import deepspeed
if deepspeed.HAS_TRITON:
from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention

inference_module = None

@@ -55,14 +59,24 @@ def __init__(self,

if DeepSpeedTransformerInference.layer_id == 1:
log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
if deepspeed.HAS_TRITON and self.config.use_triton:
log_dist(f"Injecting Triton kernels ...", [0])

if self.config.bigscience_bloom:
self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
assert not self.config.use_triton
else:
self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
merge_count)
self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
mlp_extra_grouping)
if deepspeed.HAS_TRITON and self.config.use_triton:
self.attention = TritonSelfAttention(self.config)
else:
self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
merge_count)

if deepspeed.HAS_TRITON and self.config.use_triton:
self.mlp = TritonMLP(self.config)
else:
self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
mlp_extra_grouping)

device = get_accelerator().current_device_name() # if config.bigscience_bloom else 'cpu'
if self.config.set_empty_params:
14 changes: 13 additions & 1 deletion deepspeed/module_inject/containers/base.py
@@ -8,6 +8,7 @@

import torch

import deepspeed
from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
from deepspeed.accelerator import get_accelerator

@@ -79,6 +80,10 @@ def __init__(self, policy, config, model_config, layer_id, child):
self.input_nb = None

self.mp_group = None
self.use_triton = False

# Triton
self.use_triton = config.use_triton and deepspeed.HAS_TRITON

def create_ds_model_config(self):
self.set_hidden_heads(*self.policy.get_hidden_heads())
@@ -110,7 +115,14 @@ def create_ds_model_config(self):
use_mup=self.use_mup,
return_single_tuple=self.return_single_tuple,
set_empty_params=self.config.set_empty_params,
transposed_mode=self.config.transposed_mode)
transposed_mode=self.config.transposed_mode,
use_triton=self.use_triton,
triton_autotune=self.config.triton_autotune)

if self.use_triton and deepspeed.HAS_TRITON:
if not self.config.triton_autotune:
from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
fp16_matmul.skip_autotune()

return self.ds_model_config

1 change: 1 addition & 0 deletions deepspeed/module_inject/containers/bert.py
@@ -18,6 +18,7 @@ def __init__(self, **kwargs):
# All model specific things should be defined here instead of the base class.
self.return_tuple = True
self.triangular_masking = False
self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON

def create_module(self, config=None):
_config = config if config is not None else self.ds_model_config
1 change: 1 addition & 0 deletions deepspeed/module_inject/containers/distil_bert.py
@@ -18,6 +18,7 @@ def __init__(self, **kwargs):
# All model specific things should be defined here instead of the base class.
self.triangular_masking = False
self.return_single_tuple = True
self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON

def create_module(self, config=None):
_config = config if config is not None else self.ds_model_config
7 changes: 6 additions & 1 deletion deepspeed/ops/transformer/inference/config.py
@@ -44,6 +44,7 @@ class DeepSpeedInferenceConfig(TransformerConfig):
scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation.
return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture.
use_triton: This flag is to enable triton kernels in inference or not.
"""

def __init__(self,
@@ -77,7 +78,9 @@ def __init__(self,
scale_attn_by_inverse_layer_idx=False,
return_single_tuple=False,
set_empty_params=False,
transposed_mode=False):
transposed_mode=False,
use_triton=False,
triton_autotune=False):
super(DeepSpeedInferenceConfig,
self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
num_hidden_layers)
@@ -109,6 +112,8 @@ def __init__(self,
self.return_single_tuple = return_single_tuple
self.set_empty_params = set_empty_params
self.transposed_mode = transposed_mode
self.use_triton = use_triton
self.triton_autotune = triton_autotune

@classmethod
def from_dict(cls, json_object):
9 changes: 7 additions & 2 deletions deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py
@@ -6,6 +6,7 @@
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed


class GELUGemmOp(BaseOp):
@@ -14,9 +15,13 @@ def __init__(self, config: DeepSpeedInferenceConfig):
super(GELUGemmOp, self).__init__(config)
try:
if self.config.dtype in [torch.float16, torch.int8]:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore
if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
from deepspeed.ops.transformer.inference.triton.ops import fused_gemm_gelu as _triton_fused_gemm_gelu
self.fused_gemm_gelu = _triton_fused_gemm_gelu # type: ignore
else:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16 # type: ignore
else:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp32 # type: ignore
except AttributeError:
21 changes: 21 additions & 0 deletions deepspeed/ops/transformer/inference/op_binding/linear.py
@@ -6,6 +6,7 @@
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed


class LinearOp(BaseOp):
@@ -14,6 +15,14 @@ def __init__(self, config: DeepSpeedInferenceConfig):
super(LinearOp, self).__init__(config)
try:
if self.config.dtype in [torch.float16, torch.int8]:
if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
from deepspeed.ops.transformer.inference.triton.ops import linear_func as _triton_linear_func
self.linear_func = _triton_linear_func
triton_autotune = config.triton_autotune and config.layer_id == 0
if triton_autotune:
__class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size)
else:
self.linear_func = self.inference_module.linear_layer_fp16
self.linear_func = self.inference_module.linear_layer_fp16
elif self.config.dtype == torch.bfloat16:
self.linear_func = self.inference_module.linear_layer_bf16
@@ -37,3 +46,15 @@ def forward(self,
qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads,
self.config.transposed_mode)
return qkv_out

@staticmethod
def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16):
from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul
seqlen = [(min_seqlen + i)
for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
Fp16Matmul._read_autotune_table()
for N in seqlen:
A = torch.randn((N, hidden_size), dtype=dtype, device='cuda')
B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda')
matmul(A, B)
Fp16Matmul._update_autotune_table()
4 changes: 3 additions & 1 deletion deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py
@@ -19,7 +19,9 @@ def __init__(self, config: DeepSpeedInferenceConfig):
super(MLPGemmOp, self).__init__(config)
try:
if self.config.norm_type == NormType.LayerNorm:
if self.config.dtype in [torch.float16, torch.int8]:
if self.config.dtype in [
torch.float16, torch.int8
]: # non-triton cuda kernel has a higher performance in MLP than mlp_gemm_func in triton.ops
self.mlp_gemm_func = self.inference_module.mlp_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.mlp_gemm_func = self.inference_module.mlp_gemm_bf16