Commit

init lightllm repo

XHPlus committed Jul 22, 2023
1 parent d2ccc02 commit 360eb45
Showing 107 changed files with 7,898 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
__pycache__/
.pyc
build
dist
*.egg-info
53 changes: 53 additions & 0 deletions Dockerfile
@@ -0,0 +1,53 @@
FROM debian:bullseye-slim as pytorch-install
ARG PYTORCH_VERSION=2.0.0
ARG PYTHON_VERSION=3.9
ARG CUDA_VERSION=11.8
ARG MAMBA_VERSION=23.1.0-1
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
ARG TARGETPLATFORM

ENV PATH /opt/conda/bin:$PATH

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    ca-certificates \
    ccache \
    curl \
    git && \
    rm -rf /var/lib/apt/lists/*


RUN case ${TARGETPLATFORM} in \
        "linux/arm64") MAMBA_ARCH=aarch64 ;; \
        *) MAMBA_ARCH=x86_64 ;; \
    esac && \
    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh

RUN case ${TARGETPLATFORM} in \
        "linux/arm64") exit 1 ;; \
        *) /opt/conda/bin/conda update -y conda && \
           /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch==$PYTORCH_VERSION "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya

FROM nvidia/cuda:11.8.0-devel-ubuntu20.04 as base

ENV PATH=/opt/conda/bin:$PATH \
    CONDA_PREFIX=/opt/conda

WORKDIR /usr/src

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    libssl-dev \
    ca-certificates \
    make \
    && rm -rf /var/lib/apt/lists/*

COPY --from=pytorch-install /opt/conda /opt/conda
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt && rm -rf requirements.txt
RUN apt update -y && apt install -y vim wget curl git
151 changes: 151 additions & 0 deletions README.md
@@ -0,0 +1,151 @@
<div align="center">
<picture>
<img alt="LightLLM" src="assets/lightllm.drawio.png" width=90%>
</picture>
</div>

---
LightLLM is a Python-based LLM (Large Language Model) inference and serving framework, notable for its lightweight design, easy scalability, and high-speed performance. LightLLM harnesses the strengths of numerous well-regarded open-source implementations, including but not limited to FasterTransformer, TGI, vLLM, and FlashAttention.

## Features

- Tri-process asynchronous collaboration: tokenization, model inference, and detokenization are performed asynchronously, leading to a considerable improvement in GPU utilization.
- Nopad (Unpad): offers support for nopad attention operations across multiple models to efficiently handle requests with large length disparities.
- Dynamic Batch: enables dynamic batch scheduling of requests.
- [FlashAttention](https://github.com/Dao-AILab/flash-attention): incorporates FlashAttention to improve speed and reduce GPU memory footprint during inference.
- Tensor Parallelism: utilizes tensor parallelism over multiple GPUs for faster inference.
- [Token Attention](./docs/TokenAttention.md): implements a token-wise KV cache memory management mechanism, allowing for zero memory waste during inference.
- High-performance Router: collaborates with Token Attention to meticulously manage the GPU memory of each token, thereby optimizing system throughput.

## Supported Model List

- [BLOOM](https://huggingface.co/bigscience/bloom)
- [LLaMA](https://github.com/facebookresearch/llama)
- [LLaMA V2](https://huggingface.co/meta-llama)

## Get started

### Requirements

The code has been tested with PyTorch>=1.3, CUDA 11.8, and Python 3.9. To install the necessary dependencies, refer to the provided **requirements.txt** and run:

~~~shell
pip install -r requirements.txt
~~~

A more straightforward approach is to use the official Docker container:

~~~shell
docker build -t image_name .
docker run -it --gpus all -p 8080:80 -v your_local_path:/data/ image_name /bin/bash
~~~

### Installation

- Install from the source code by running:

~~~shell
python setup.py install
~~~

The code has been tested on a range of GPUs including V100, A100, A800, 4090, and H800. If you are running the code on V100, A100, A800, etc., we recommend using triton==2.0.0.dev20221202. If you are running the code on 4090, H800, etc., you need to compile and install [triton==2.1.0](https://github.com/openai/triton/tree/main) from source. If the code doesn't work on other GPUs, try modifying the triton kernel used in model inference.
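For reference, a typical install sketch might look like the following (the exact steps for building Triton from source can differ depending on the commit you check out):

~~~shell
# V100 / A100 / A800: install the pinned pre-release wheel
pip install triton==2.0.0.dev20221202

# 4090 / H800: build and install triton==2.1.0 from source
git clone https://github.com/openai/triton.git
cd triton/python
pip install cmake
pip install -e .
~~~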

### Run LLaMA
With its efficient Router and TokenAttention, LightLLM can be deployed as a service and achieves state-of-the-art throughput.

Launch the server:

~~~shell
python -m lightllm.server.api_server --model_dir /path/llama-7B --tp 1 --max_total_token_num 120000
~~~

The parameter `max_total_token_num` depends on the GPU memory of the deployment environment. A larger value allows more requests to be processed concurrently, increasing system throughput. For more startup parameters, please refer to [api_server.py](lightllm/server/api_server.py).
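As a rough, hypothetical sanity check (the helper below is not part of LightLLM), `max_total_token_num` can be estimated from the GPU memory left over after loading the weights and the per-token KV cache size:

~~~python
# Hypothetical helper for picking max_total_token_num; all numbers are illustrative.
def estimate_max_total_token_num(gpu_mem_gb, weight_mem_gb, layer_num, head_num, head_dim,
                                 dtype_bytes=2, reserve_ratio=0.1):
    """Estimate how many KV-cache token slots fit in the leftover GPU memory."""
    kv_bytes_per_token = 2 * layer_num * head_num * head_dim * dtype_bytes  # key + value across all layers
    free_bytes = (gpu_mem_gb * (1 - reserve_ratio) - weight_mem_gb) * 1024 ** 3
    return int(free_bytes // kv_bytes_per_token)

# LLaMA-7B (32 layers, 32 heads, head_dim 128, fp16 weights ~13 GB) on an 80 GB A800:
print(estimate_max_total_token_num(80, 13, 32, 32, 128))  # on the order of 1.2e5 tokens
~~~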

To initiate a query in the shell:

~~~shell
curl 127.0.0.1:8000/generate \
-X POST \
-d '{"inputs":"What is AI?","parameters":{"max_new_tokens":17, "frequency_penalty":1}}' \
-H 'Content-Type: application/json'
~~~

To query from Python:

~~~python
import time
import requests
import json

url = 'http://localhost:8000/generate'
headers = {'Content-Type': 'application/json'}
data = {
    'inputs': 'What is AI?',
    "parameters": {
        'do_sample': False,
        'ignore_eos': False,
        'max_new_tokens': 1024,
    }
}
response = requests.post(url, headers=headers, data=json.dumps(data))
if response.status_code == 200:
    print(response.json())
else:
    print('Error:', response.status_code, response.text)
~~~

## Performance

### Service Performance

We compared the service performance of LightLLM and vLLM==0.1.2 on LLaMA-7B using an A800 with 80 GB of GPU memory.

To begin, prepare the data as follows:

~~~shell
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
~~~

Launch the service:

~~~shell
python -m lightllm.server.api_server --model_dir /path/llama-7b --tp 1 --max_total_token_num 121060 --tokenizer_mode auto
~~~

Evaluation:

~~~shell
cd test
python benchmark_serving.py --tokenizer /path/llama-7b --dataset /path/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000 --request-rate 200
~~~

The performance comparison results are presented below:

| vLLM | LightLLM |
| ---------------------------------------------------- | ----------------------------------------------------- |
| Total time: 361.79 s<br/>Throughput: 5.53 requests/s | Total time: 188.85 s<br/>Throughput: 10.59 requests/s |

### Static inference performance

For debugging, we offer static performance testing scripts for various models. For instance, you can evaluate the inference performance of the LLaMA model by running:

~~~shell
cd test/lightllama
python test_model_infer.py
~~~

### FAQ

- If the LLaMA tokenizer fails to load, try running `pip install protobuf==3.20.0`.

## License

This repository is released under the [Apache-2.0](LICENSE) license.

## Acknowledgement

We learned a lot from the following projects when developing LightLLM.
- [Faster Transformer](https://github.com/NVIDIA/FasterTransformer)
- [Text Generation Inference](https://github.com/huggingface/text-generation-inference)
- [vLLM](https://github.com/vllm-project/vllm)
- [Flash Attention 1&2](https://github.com/Dao-AILab/flash-attention)
Binary file added assets/att.gif
Binary file added assets/lightllm.drawio.png
Binary file added assets/logo.png
26 changes: 26 additions & 0 deletions benchmark.md
@@ -0,0 +1,26 @@
#### lightllm

#### Launch service

~~~shell
python -m lightllm.server.api_server --model_dir /path/llama-7b --tp 1 --max_total_token_num 121060 --tokenizer_mode auto
~~~

#### Evaluation

~~~shell
python benchmark_serving.py --tokenizer /path/llama-7b --dataset /path/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000 --request-rate 200
~~~

#### vllm

#### Launch service
~~~shell
python -m vllm.entrypoints.api_server --model /path/llama-7b --swap-space 16 --disable-log-requests --port 9009
~~~

#### Evaluation

~~~shell
python benchmark_serving_vllm.py --backend vllm --tokenizer /path/llama-7b --dataset /path/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000 --request-rate 200 --host 127.0.0.1 --port 9009
~~~
31 changes: 31 additions & 0 deletions docs/TokenAttention.md
@@ -0,0 +1,31 @@
# TokenAttention

Transformers form the basis of modern large language models. During autoregressive decoding, these models cache key-value tensors of context tokens into GPU memory to facilitate fast generation of the next token. However, these caches occupy significant GPU memory. The unpredictable nature of cache size, due to the variability in the length of each request, exacerbates the issue, resulting in significant memory fragmentation in the absence of a suitable memory management mechanism.

To alleviate this issue, PagedAttention was proposed to store the KV cache in non-contiguous memory spaces. It partitions the KV cache of each sequence into multiple blocks, with each block containing the keys and values for a fixed number of tokens. This approach effectively controls memory waste within the last block during attention computation. While PagedAttention alleviates memory fragmentation to some extent, it still leaves room for memory waste. Additionally, when handling many concurrent requests, allocating and deallocating memory blocks is less efficient, leading to suboptimal memory utilization.

To address the above challenges, we introduce TokenAttention, an attention mechanism that manages key and value caching at the token level. Compared to PagedAttention, our TokenAttention not only minimizes memory fragmentation and enables efficient memory sharing but also facilitates efficient memory allocation and deallocation. It allows for more precise and fine-grained memory management, thus optimizing memory utilization.

<div align="center">

| Features | PagedAttention | TokenAttention |
| -------------------------------------------- | :------------: | :------------: |
| Low memory fragmentation | &#x2713; | &#x2713; |
| Efficient memory sharing | &#x2713; | &#x2713; |
| Efficient memory allocation and deallocation | &#x2717; | &#x2713; |
| Fine-grained memory management | &#x2717; | &#x2713; |
</div>

The operation mechanism of TokenAttention is illustrated in the figure below:

<div align="center">
<img alt="TokenAtt" src="../assets/att.gif" width=60%>
</div>

During model initialization, the KV cache is pre-allocated based on the user-set **max_total_token_num** and a Token Table is created to record the actual storage locations of input tokens.

When handling new requests, the system first checks for available contiguous space in the pre-allocated token cache for storing the key-value (KV) cache. TokenAttention favors assigning contiguous GPU memory for requests to minimize memory access during inference. Only when contiguous space is insufficient does it allocate non-contiguous GPU memory for the requests. Since memory management is conducted on a token-by-token basis, TokenAttention achieves nearly zero waste, yielding higher throughput compared to vLLM.

We have implemented an efficient TokenAttention operator using OpenAI Triton. When provided with a query vector, this operator can efficiently retrieve the corresponding KV cache based on the Token Table and conduct the attention computation.

Upon completion of requests, the corresponding GPU memory can be quickly freed by deleting their records in the Token Table, which makes way for scheduling new requests. Given that TokenAttention pre-allocates all KV cache space during model initialization, it can efficiently release memory for completed requests and merge different batches of requests during dynamic scheduling, thereby effectively maximizing GPU utilization.
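The bookkeeping described above can be illustrated with a small toy sketch (hypothetical names only, not LightLLM's actual classes):

~~~python
# Toy illustration of token-level KV cache bookkeeping; not LightLLM code.
import torch

max_total_token_num = 8
mem_state = torch.ones(max_total_token_num, dtype=torch.bool)  # True = free token slot
token_table = {}  # request_id -> tensor of slot indices holding that request's KV cache

def alloc(request_id, need_size):
    free_slots = torch.nonzero(mem_state).flatten()
    assert need_size <= free_slots.numel(), "out of KV cache slots"
    chosen = free_slots[:need_size]      # allocated token by token, no block rounding
    mem_state[chosen] = False
    token_table[request_id] = chosen

def free(request_id):
    mem_state[token_table.pop(request_id)] = True  # slots are reusable immediately

alloc("req-0", 5)
alloc("req-1", 3)
free("req-0")        # a finished request releases its slots right away
alloc("req-2", 4)    # may receive non-contiguous slots, with no per-block waste
~~~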
6 changes: 6 additions & 0 deletions format.py
@@ -0,0 +1,6 @@
import os
import glob

for filename in glob.glob('./**/*.py', recursive=True):
    print(filename)
    os.system(f"autopep8 --max-line-length 140 --in-place --aggressive --aggressive {filename}")
Empty file added lightllm/__init__.py
Empty file.
Empty file added lightllm/common/__init__.py
Empty file.
Empty file.
6 changes: 6 additions & 0 deletions lightllm/common/configs/config.py
@@ -0,0 +1,6 @@

_DEFAULT_MAX_INPUT_ADD_OUTPUT_LEN = 1024 * 5

setting = {
    "max_req_total_len": _DEFAULT_MAX_INPUT_ADD_OUTPUT_LEN
}
5 changes: 5 additions & 0 deletions lightllm/common/gqa_mem_manager.py
@@ -0,0 +1,5 @@
from .mem_manager import MemoryManager

class GQAMemoryManager(MemoryManager):
    def __init__(self, size, dtype, key_value_head_num, head_dim, layer_num):
        super().__init__(size, dtype, key_value_head_num, head_dim, layer_num)
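For grouped-query attention models the KV cache only needs one buffer entry per key/value head rather than per query head, which is why the manager takes `key_value_head_num`. A hypothetical instantiation (shapes chosen purely for illustration) could be:

~~~python
# Hypothetical example: a GQA model with 8 KV heads gets per-layer KV buffers
# of shape (size, 8, head_dim), regardless of how many query heads it has.
import torch
from lightllm.common.gqa_mem_manager import GQAMemoryManager

mem_manager = GQAMemoryManager(size=120000, dtype=torch.float16,
                               key_value_head_num=8, head_dim=128, layer_num=32)
~~~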
8 changes: 8 additions & 0 deletions lightllm/common/infer_utils.py
@@ -0,0 +1,8 @@
def init_bloc(b_loc, b_seq_len, max_len_in_batch, alloc_mem_index):
    start_index = 0
    b_seq_len_numpy = b_seq_len.cpu().numpy()
    for i in range(len(b_seq_len)):
        cur_seq_len = b_seq_len_numpy[i]
        b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index + cur_seq_len]
        start_index += cur_seq_len
    return
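As a quick illustration (hypothetical values, run on CPU here for clarity; in LightLLM these tensors live on the GPU), `init_bloc` writes each request's allocated slot indices right-aligned into its row of `b_loc`:

~~~python
# Hypothetical example: a batch of two requests with lengths 2 and 3, max_len_in_batch = 4.
import torch
from lightllm.common.infer_utils import init_bloc

b_loc = torch.zeros((2, 4), dtype=torch.long)          # per-request token location table
b_seq_len = torch.tensor([2, 3])
alloc_mem_index = torch.tensor([10, 11, 20, 21, 22])   # slot indices handed out by the memory manager

init_bloc(b_loc, b_seq_len, max_len_in_batch=4, alloc_mem_index=alloc_mem_index)
# b_loc is now filled right-aligned per row:
# tensor([[ 0,  0, 10, 11],
#         [ 0, 20, 21, 22]])
~~~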
67 changes: 67 additions & 0 deletions lightllm/common/mem_manager.py
@@ -0,0 +1,67 @@
import torch


class MemoryManager:
    def __init__(self, size, dtype, head_num, head_dim, layer_num):
        # One slot per token: mem_state marks which slots are free, and the
        # key/value buffers hold the KV cache for every layer.
        self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
        self._mem_cum_sum = torch.empty((size,), dtype=torch.int32, device="cuda")
        self.indexes = torch.arange(0, size, dtype=torch.long, device="cuda")
        self.can_use_mem_size = size
        self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)]
        self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)]

    @torch.no_grad()
    def alloc(self, need_size):
        if need_size > self.can_use_mem_size:
            print(f'warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}')
            return None

        # Select the first `need_size` free slots (not necessarily contiguous).
        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum)
        select_index = torch.logical_and(self._mem_cum_sum <= need_size, self.mem_state == 1)
        select_index = self.indexes[select_index]
        self.mem_state[select_index] = 0
        self.can_use_mem_size -= len(select_index)
        return select_index

    @torch.no_grad()
    def alloc_contiguous(self, need_size):
        if need_size > self.can_use_mem_size:
            print(f'warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}')
            return None

        # Find a window of `need_size` consecutive free slots, if one exists.
        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum)
        sum_size = len(self._mem_cum_sum)
        loc_sums = self._mem_cum_sum[need_size - 1:] - self._mem_cum_sum[0:sum_size - need_size + 1] + self.mem_state[0:sum_size - need_size + 1]
        can_used_loc = self.indexes[0:sum_size - need_size + 1][loc_sums == need_size]
        if can_used_loc.shape[0] == 0:
            # print(f'warn no enough cache to contiguous need_size {need_size} left_size {self.can_use_mem_size}')
            return None
        start_loc = can_used_loc[0]
        select_index = self.indexes[start_loc : start_loc + need_size]

        self.mem_state[select_index] = 0
        self.can_use_mem_size -= len(select_index)
        start = start_loc.item()
        end = start + need_size
        return select_index, start, end

    @torch.no_grad()
    def free(self, free_index):
        """Return token slots to the free pool.

        Args:
            free_index (torch.Tensor): indices of the token slots to release.
        """
        self.can_use_mem_size += free_index.shape[0]
        self.mem_state[free_index] = 1
        if self.can_use_mem_size == len(self.mem_state):
            print(f"freed all gpu mem size {self.can_use_mem_size}")
        # print(f"free state {self.can_use_mem_size} all {len(self.mem_state)}")
        return

    @torch.no_grad()
    def free_all(self):
        self.can_use_mem_size = len(self.mem_state)
        self.mem_state[:] = 1
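A hypothetical usage sketch of this manager (small sizes for readability; a real deployment sizes the pool with `max_total_token_num` and runs on a CUDA device) might look like:

~~~python
# Hypothetical usage of MemoryManager; sizes are illustrative only.
import torch
from lightllm.common.mem_manager import MemoryManager

manager = MemoryManager(size=16, dtype=torch.float16, head_num=2, head_dim=4, layer_num=1)

# Prefill: prefer a contiguous run of token slots, fall back to scattered slots.
result = manager.alloc_contiguous(8)
if result is not None:
    prompt_index, start, end = result
else:
    prompt_index = manager.alloc(8)

# Decode: one extra slot per newly generated token.
decode_index = manager.alloc(1)

# When the request finishes, all of its slots return to the pool at once.
manager.free(torch.cat([prompt_index, decode_index]))
~~~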

