remove the support for triton==2.0.0 (ModelTC#395)
Co-authored-by: wangzaijun <[email protected]>
hiworldwzj and wangzaijun committed Apr 12, 2024
1 parent 390ac96 commit bcb3212
Showing 15 changed files with 1,040 additions and 1,516 deletions.
16 changes: 9 additions & 7 deletions README.md
@@ -112,17 +112,19 @@ You can use the official Docker container to run the model more easily. To do th
python setup.py install
~~~

The code has been tested on a range of GPUs including A100, A800, 4090, and H800. If you are running the code on A100, A800, etc., we recommend using triton==2.1.0 or triton==2.0.0.dev20221202. If you are running the code on H800, etc., it is necessary to compile and install the source code of [triton==2.1.0](https://github.com/openai/triton/tree/main) from the GitHub repository. If the code doesn't work on other GPUs, try modifying the triton kernel used in model inference.
- Install Triton Package

use triton==2.1.0 (Better performance, but the code is under continuous development and may be unstable.)

- Install Triton Package

The code has been tested on a range of GPUs including V100, A100, A800, 4090, and H800. If you are running the code on A100, A800, etc., we recommend using triton==2.1.0.

~~~shell
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
pip install triton==2.1.0 --no-deps
~~~

use triton==2.0.0.dev20221202 (This version has a memory leak bug. Refer to [issue #209](https://github.com/ModelTC/lightllm/issues/209) for the fix.)
If you are running the code on H800 or V100, we recommend using triton-nightly. Note that triton-nightly currently has a significant CPU bottleneck that leads to high decode latency at low concurrency levels; see [this issue](https://github.com/openai/triton/issues/3619) and the [fix PR](https://github.com/openai/triton/pull/3638). You can try modifying and compiling the triton source code yourself to resolve it. A quick way to verify the installed triton build is sketched after the install command below.
~~~shell
pip install triton==2.0.0.dev20221202
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps
~~~
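
Before launching the server, it can help to confirm that whichever triton build you installed actually JIT-compiles and runs on your GPU. The following smoke test is an illustrative sketch and not part of lightllm; the kernel name, block size, and tensor sizes are arbitrary assumptions.

~~~python
# Illustrative smoke test (not part of lightllm): print the installed triton
# version and check that it can JIT-compile and run a trivial copy kernel.
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance copies one BLOCK-sized chunk of the input tensor.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)


if __name__ == "__main__":
    n = 1024
    src = torch.arange(n, device="cuda", dtype=torch.float32)
    dst = torch.empty_like(src)
    copy_kernel[(triton.cdiv(n, 256),)](src, dst, n, BLOCK=256)
    assert torch.equal(src, dst)
    print("triton", triton.__version__, "works on", torch.cuda.get_device_name())
~~~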

### RUN LLaMA
46 changes: 15 additions & 31 deletions lightllm/models/bloom/layer_infer/transformer_layer_infer.py
@@ -63,37 +63,21 @@ def _context_attention_kernel(
        self, q, kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None
    ) -> torch.Tensor:
        o_tensor = torch.empty_like(q) if out is None else out
-       import triton
-       if triton.__version__ >= "2.1.0":
-           context_attention_fwd(
-               q.view(-1, self.tp_q_head_num_, self.head_dim_),
-               infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :],
-               infer_state.mem_manager.kv_buffer[self.layer_num_][
-                   :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
-               ],
-               o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
-               infer_state.b_req_idx,
-               layer_weight.tp_alibi,
-               infer_state.b_start_loc,
-               infer_state.b_seq_len,
-               infer_state.b_ready_cache_len,
-               infer_state.max_len_in_batch,
-               infer_state.req_manager.req_to_token_indexs,
-           )
-       elif triton.__version__ == "2.0.0":
-           context_attention_fwd(
-               q.view(-1, self.tp_q_head_num_, self.head_dim_),
-               kv[:, 0 : self.tp_k_head_num_, :],
-               kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
-               o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
-               layer_weight.tp_alibi,
-               infer_state.b_start_loc,
-               infer_state.b_seq_len,
-               infer_state.max_len_in_batch,
-           )
-       else:
-           assert False

+       context_attention_fwd(
+           q.view(-1, self.tp_q_head_num_, self.head_dim_),
+           infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :],
+           infer_state.mem_manager.kv_buffer[self.layer_num_][
+               :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
+           ],
+           o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
+           infer_state.b_req_idx,
+           layer_weight.tp_alibi,
+           infer_state.b_start_loc,
+           infer_state.b_seq_len,
+           infer_state.b_ready_cache_len,
+           infer_state.max_len_in_batch,
+           infer_state.req_manager.req_to_token_indexs,
+       )
        return o_tensor

    def _token_attention_kernel(
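With the triton==2.0.0 branch removed, the call site above always takes the triton>=2.1.0 path. As a hedged illustration of what that simplification relies on, the sketch below shows how a project could enforce the minimum triton version once at startup instead of branching on triton.__version__ at every kernel call; the deleted code's raw string comparison of version numbers can also misorder releases, so the sketch compares parsed versions. The constant, helper name, and error message are assumptions for illustration, not code from this commit.

~~~python
# Illustrative sketch (not from this commit): fail fast when the installed
# triton is older than the single supported version, rather than branching on
# triton.__version__ inside every attention call.
from packaging.version import Version

import triton

MIN_TRITON = Version("2.1.0")  # assumed minimum version the kernels target


def assert_supported_triton() -> None:
    installed = Version(triton.__version__)
    if installed < MIN_TRITON:
        raise RuntimeError(
            f"triton>={MIN_TRITON} is required, found {installed}; "
            "triton==2.0.0 is no longer supported."
        )
~~~

Calling a check like `assert_supported_triton()` once at process start would serve the same purpose as the per-call if/elif/else that this commit deletes.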