This repository is a deployment project for BEVFormer on TensorRT, supporting FP32/FP16/INT8 inference. To improve the inference speed of BEVFormer on TensorRT, this project implements custom TensorRT ops that support nv_half, nv_half2 and INT8. With almost no loss of accuracy, the inference speed of BEVFormer base increases by nearly four times, the engine size shrinks by more than 90%, and GPU memory usage drops by more than 80%. In addition, the project supports common 2D object detection models from MMDetection, which can be quantized to INT8 and deployed with TensorRT with only a small number of code changes.
Model | Data | Batch Size | NDS/mAP | FPS | Size (MB) | Memory (MB) | Device |
---|---|---|---|---|---|---|---|
BEVFormer tiny download | NuScenes | 1 | NDS: 0.354 mAP: 0.252 | 15.9 | 383 | 2167 | RTX 3090 |
BEVFormer small download | NuScenes | 1 | NDS: 0.478 mAP: 0.370 | 5.1 | 680 | 3147 | RTX 3090 |
BEVFormer base download | NuScenes | 1 | NDS: 0.517 mAP: 0.416 | 2.4 | 265 | 5435 | RTX 3090 |
Model | Data | Batch Size | Float/Int | Quantization Method | NDS/mAP | FPS | Size (MB) | Memory (MB) | Device |
---|---|---|---|---|---|---|---|---|---|
BEVFormer tiny | NuScenes | 1 | FP32 | - | NDS: 0.354 mAP: 0.252 | 37.9 | 136 | 2159 | RTX 3090 |
BEVFormer tiny | NuScenes | 1 | FP16 | - | NDS: 0.354 mAP: 0.252 | 69.2 | 74 | 1729 | RTX 3090 |
BEVFormer tiny | NuScenes | 1 | FP32/INT8 | PTQ max/per-tensor | NDS: 0.353 mAP: 0.249 | 65.1 | 58 | 1737 | RTX 3090 |
BEVFormer tiny | NuScenes | 1 | FP16/INT8 | PTQ max/per-tensor | NDS: 0.353 mAP: 0.249 | 70.7 | 54 | 1665 | RTX 3090 |
BEVFormer small | NuScenes | 1 | FP32 | - | NDS: 0.478 mAP: 0.370 | 6.6 | 245 | 4663 | RTX 3090 |
BEVFormer small | NuScenes | 1 | FP16 | - | NDS: 0.478 mAP: 0.370 | 12.8 | 126 | 3719 | RTX 3090 |
BEVFormer small | NuScenes | 1 | FP32/INT8 | PTQ max/per-tensor | NDS: 0.476 mAP: 0.367 | 8.7 | 158 | 4079 | RTX 3090 |
BEVFormer small | NuScenes | 1 | FP16/INT8 | PTQ max/per-tensor | NDS: 0.477 mAP: 0.368 | 13.3 | 106 | 3441 | RTX 3090 |
BEVFormer base * | NuScenes | 1 | FP32 | - | NDS: 0.517 mAP: 0.416 | 1.5 | 1689 | 13893 | RTX 3090 |
BEVFormer base | NuScenes | 1 | FP16 | - | NDS: 0.517 mAP: 0.416 | 1.8 | 849 | 11865 | RTX 3090 |
BEVFormer base * | NuScenes | 1 | FP32/INT8 | PTQ max/per-tensor | NDS: 0.516 mAP: 0.414 | 1.8 | 426 | 12429 | RTX 3090 |
BEVFormer base * | NuScenes | 1 | FP16/INT8 | PTQ max/per-tensor | NDS: 0.515 mAP: 0.414 | 2.2 | 244 | 11011 | RTX 3090 |
\* Out of memory when running onnx2trt with TensorRT-8.5.1.7; these models convert successfully with TensorRT-8.4.3.1, so the corresponding engines were built with TensorRT-8.4.3.1.
FP16 Plugins with nv_half
Model | Data | Batch Size | Float/Int | Quantization Method | NDS/mAP | FPS | Size (MB) | Memory (MB) | Device |
---|---|---|---|---|---|---|---|---|---|
BEVFormer tiny | NuScenes | 1 | FP32 | - | NDS: 0.354 mAP: 0.252 | 41.4 | 135 | 1699 | RTX 3090 |
BEVFormer tiny | NuScenes | 1 | FP16 | - | NDS: 0.354 mAP: 0.252 | 76.8 | 73 | 1203 | RTX 3090 |
BEVFormer tiny | NuScenes | 1 | FP32/INT8 | PTQ max/per-tensor | NDS: 0.352 mAP: 0.249 | 84.0 | 57 | 1077 | RTX 3090 |
BEVFormer tiny | NuScenes | 1 | FP16/INT8 | PTQ max/per-tensor | NDS: 0.353 mAP: 0.250 | 96.1 | 54 | 1109 | RTX 3090 |
BEVFormer small | NuScenes | 1 | FP32 | - | NDS: 0.478 mAP: 0.370 | 7.0 | 246 | 2645 | RTX 3090 |
BEVFormer small | NuScenes | 1 | FP16 | - | NDS: 0.479 mAP: 0.370 | 16.3 | 124 | 1789 | RTX 3090 |
BEVFormer small | NuScenes | 1 | FP32/INT8 | PTQ max/per-tensor | NDS: 0.477 mAP: 0.368 | 10.4 | 157 | 1925 | RTX 3090 |
BEVFormer small | NuScenes | 1 | FP16/INT8 | PTQ max/per-tensor | NDS: 0.477 mAP: 0.368 | 17.8 | 103 | 1627 | RTX 3090 |
BEVFormer base | NuScenes | 1 | FP32 | - | NDS: 0.516 mAP: 0.416 | 3.2 | 283 | 5175 | RTX 3090 |
BEVFormer base | NuScenes | 1 | FP16 | - | NDS: 0.515 mAP: 0.415 | 6.5 | 144 | 3323 | RTX 3090 |
BEVFormer base | NuScenes | 1 | FP32/INT8 | PTQ max/per-tensor | NDS: 0.516 mAP: 0.414 | 4.2 | 188 | 3139 | RTX 3090 |
BEVFormer base | NuScenes | 1 | FP16/INT8 | PTQ max/per-tensor | NDS: 0.516 mAP: 0.414 | 5.8 | 125 | 3073 | RTX 3090 |
FP16 Plugins with nv_half2
Model | Data | Batch Size | Float/Int | Quantization Method | NDS/mAP | FPS | Size (MB) | Memory (MB) | Device |
---|---|---|---|---|---|---|---|---|---|
BEVFormer tiny | NuScenes | 1 | FP16 | - | NDS: 0.354 mAP: 0.251 | 90.7 | 73 | 1211 | RTX 3090 |
BEVFormer tiny | NuScenes | 1 | FP16/INT8 | PTQ max/per-tensor | NDS: 0.353 mAP: 0.250 | 98.4 | 54 | 1109 | RTX 3090 |
BEVFormer small | NuScenes | 1 | FP16 | - | NDS: 0.478 mAP: 0.370 | 18.2 | 124 | 1843 | RTX 3090 |
BEVFormer small | NuScenes | 1 | FP16/INT8 | PTQ max/per-tensor | NDS: 0.477 mAP: 0.368 | 18.4 | 105 | 1629 | RTX 3090 |
BEVFormer base | NuScenes | 1 | FP16 | - | NDS: 0.515 mAP: 0.415 | 7.3 | 144 | 3323 | RTX 3090 |
BEVFormer base | NuScenes | 1 | FP16/INT8 | PTQ max/per-tensor | NDS: 0.516 mAP: 0.414 | 6.7 | 124 | 2437 | RTX 3090 |
This project also supports common 2D object detection models from MMDetection with only minor modifications. The following are deployment results for YOLOx and CenterNet.
Model | Data | Framework | Batch Size | Float/Int | Quantization Method | mAP | FPS | Size (MB) | Memory (MB) | Device |
---|---|---|---|---|---|---|---|---|---|---|
YOLOx download | COCO | PyTorch | 32 | FP32 | - | mAP: 0.506 mAP_50: 0.685 mAP_75: 0.55 mAP_s: 0.32 mAP_m: 0.557 mAP_l: 0.667 | 63.1 | 379 | 7617 | RTX 3090 |
YOLOx | COCO | TensorRT | 32 | FP32 | - | mAP: 0.506 mAP_50: 0.685 mAP_75: 0.55 mAP_s: 0.32 mAP_m: 0.556 mAP_l: 0.667 | 71.3 | 546 | 9943 | RTX 3090 |
YOLOx | COCO | TensorRT | 32 | FP16 | - | mAP: 0.506 mAP_50: 0.685 mAP_75: 0.55 mAP_s: 0.32 mAP_m: 0.556 mAP_l: 0.668 | 296.8 | 192 | 4567 | RTX 3090 |
YOLOx | COCO | TensorRT | 32 | FP32/INT8 | PTQ max/per-tensor | mAP: 0.488 mAP_50: 0.671 mAP_75: 0.538 mAP_s: 0.311 mAP_m: 0.538 mAP_l: 0.649 | 556.4 | 99 | 5225 | RTX 3090 |
YOLOx | COCO | TensorRT | 32 | FP16/INT8 | PTQ max/per-tensor | mAP: 0.479 mAP_50: 0.662 mAP_75: 0.53 mAP_s: 0.307 mAP_m: 0.533 mAP_l: 0.634 | 550.6 | 99 | 5119 | RTX 3090 |
Model | Data | Framework | Batch Size | Float/Int | Quantization Method | mAP | FPS | Size (MB) | Memory (MB) | Device |
---|---|---|---|---|---|---|---|---|---|---|
CenterNet download | COCO | PyTorch | 32 | FP32 | - | mAP: 0.299 mAP_50: 0.466 mAP_75: 0.319 mAP_s: 0.106 mAP_m: 0.337 mAP_l: 0.463 | 337.4 | 56 | 5171 | RTX 3090 |
CenterNet | COCO | TensorRT | 32 | FP32 | - | mAP: 0.299 mAP_50: 0.466 mAP_75: 0.319 mAP_s: 0.106 mAP_m: 0.337 mAP_l: 0.463 | 475.6 | 58 | 8241 | RTX 3090 |
CenterNet | COCO | TensorRT | 32 | FP16 | - | mAP: 0.297 mAP_50: 0.463 mAP_75: 0.316 mAP_s: 0.106 mAP_m: 0.336 mAP_l: 0.46 | 1247.1 | 29 | 5183 | RTX 3090 |
CenterNet | COCO | TensorRT | 32 | FP32/INT8 | PTQ max/per-tensor | mAP: 0.27 mAP_50: 0.426 mAP_75: 0.285 mAP_s: 0.086 mAP_m: 0.299 mAP_l: 0.438 | 1534.0 | 20 | 6549 | RTX 3090 |
CenterNet | COCO | TensorRT | 32 | FP16/INT8 | PTQ max/per-tensor | mAP: 0.285 mAP_50: 0.448 mAP_75: 0.303 mAP_s: 0.096 mAP_m: 0.319 mAP_l: 0.451 | 1889.0 | 17 | 6453 | RTX 3090 |
git clone [email protected]:DerryHub/BEVFormer_tensorrt.git
cd BEVFormer_tensorrt
PROJECT_DIR=$(pwd)
Download the COCO 2017 datasets to /path/to/coco and unzip them.
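If the COCO zips are not already on disk, a minimal download sketch (assuming the official cocodataset.org mirrors) looks like this:

```shell
# Sketch only: fetch the COCO 2017 images and annotations from the official
# mirrors, then unzip everything into /path/to/coco.
mkdir -p /path/to/coco && cd /path/to/coco
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/zips/test2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
unzip '*.zip'
```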
cd ${PROJECT_DIR}/data
ln -s /path/to/coco coco
Download the nuScenes V1.0 full dataset and the CAN bus expansion data HERE as /path/to/nuscenes and /path/to/can_bus.
Prepare the nuScenes data as in BEVFormer.
cd ${PROJECT_DIR}/data
ln -s /path/to/nuscenes nuscenes
ln -s /path/to/can_bus can_bus
cd ${PROJECT_DIR}
sh samples/bevformer/create_data.sh
${PROJECT_DIR}/data/.
├── can_bus
│ ├── scene-0001_meta.json
│ ├── scene-0001_ms_imu.json
│ ├── scene-0001_pose.json
│ └── ...
├── coco
│ ├── annotations
│ ├── test2017
│ ├── train2017
│ └── val2017
└── nuscenes
├── maps
├── samples
├── sweeps
└── v1.0-trainval
Download and install CUDA-11.6/cuDNN-8.6.0/TensorRT-8.5.1.7 following the official NVIDIA instructions.
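A quick sanity check that the toolchain is visible (this assumes the TensorRT Python wheel shipped in the tarball was installed as well):

```shell
# Verify the CUDA compiler and the TensorRT Python bindings.
nvcc --version
python -c "import tensorrt; print(tensorrt.__version__)"  # expect 8.5.1.7
```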
Install PyTorch and TorchVision following the official instructions.
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
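You can then confirm that the CUDA build of PyTorch is the one being used:

```shell
# Should print 1.12.1+cu116 and True on a machine with a working GPU setup.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```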
git clone https://github.com/open-mmlab/mmcv.git
cd mmcv
git checkout v1.5.0
pip install -r requirements/optional.txt
MMCV_WITH_OPS=1 pip install -e .
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
git checkout v2.25.1
pip install -v -e .
# "-v" means verbose, or more output
# "-e" means installing a project in editable mode,
# thus any local modifications made to the code will take effect without reinstallation.
git clone [email protected]:open-mmlab/mmdeploy.git
cd mmdeploy
git checkout v0.10.0
git clone [email protected]:NVIDIA/cub.git third_party/cub
cd third_party/cub
git checkout c3cceac115
# go back to third_party directory and git clone pybind11
cd ..
git clone [email protected]:pybind/pybind11.git pybind11
cd pybind11
git checkout 70a58c5
Make sure cmake version >= 3.14.0 and gcc version >= 7.
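Both can be checked before configuring the build:

```shell
cmake --version   # expect >= 3.14.0
g++ --version     # expect >= 7 (the cmake command below selects g++-7 explicitly)
```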
export MMDEPLOY_DIR=/the/root/path/of/MMDeploy
export TENSORRT_DIR=/the/path/of/tensorrt
export CUDNN_DIR=/the/path/of/cuda
export LD_LIBRARY_PATH=$TENSORRT_DIR/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$CUDNN_DIR/lib64:$LD_LIBRARY_PATH
cd ${MMDEPLOY_DIR}
mkdir -p build
cd build
cmake -DCMAKE_CXX_COMPILER=g++-7 -DMMDEPLOY_TARGET_BACKENDS=trt -DTENSORRT_DIR=${TENSORRT_DIR} -DCUDNN_DIR=${CUDNN_DIR} ..
make -j$(nproc)
make install
cd ${MMDEPLOY_DIR}
pip install -v -e .
# "-v" means verbose, or more output
# "-e" means installing a project in editable mode,
# thus any local modifications made to the code will take effect without reinstallation.
cd ${PROJECT_DIR}/TensorRT/build
cmake .. -DCMAKE_TENSORRT_PATH=/path/to/TensorRT
make -j$(nproc)
make install
Run Unit Test of Custom TensorRT Plugins
cd ${PROJECT_DIR}
sh samples/test_trt_ops.sh
cd ${PROJECT_DIR}/third_party/bevformer
python setup.py build develop
Download the above PyTorch checkpoints to ${PROJECT_DIR}/checkpoints/pytorch/. The ONNX files and TensorRT engines will be saved in ${PROJECT_DIR}/checkpoints/onnx/ and ${PROJECT_DIR}/checkpoints/tensorrt/.
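The conversion scripts below read from and write to these directories; creating them up front is a safe default:

```shell
cd ${PROJECT_DIR}
mkdir -p checkpoints/pytorch checkpoints/onnx checkpoints/tensorrt
```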
This project supports the common TensorRT ops used in BEVFormer: Grid Sampler, Multi-scale Deformable Attention, Modulated Deformable Conv2d and Rotate. Each op is implemented in two versions: FP32/FP16 (nv_half)/INT8 and FP32/FP16 (nv_half2)/INT8. For a detailed speed comparison, see Custom TensorRT Plugins.
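Before converting any engines, you can check that the custom plugins register with TensorRT. The snippet below is only a sketch: the plugin library name and path are assumptions, so point it at whatever shared library `make install` actually produced in the build step above.

```shell
# PLUGIN_LIB is an assumed path/name; adjust it to the built artifact.
PLUGIN_LIB=${PROJECT_DIR}/TensorRT/lib/libtensorrt_ops.so
python -c "
import ctypes, tensorrt as trt
ctypes.CDLL('${PLUGIN_LIB}')            # loading the library registers the plugins
trt.init_libnvinfer_plugins(None, '')
print([c.name for c in trt.get_plugin_registry().plugin_creator_list])
"
```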
The following tutorial uses BEVFormer base as an example.
- Evaluate with PyTorch
cd ${PROJECT_DIR}
# default gpu_id is 0
sh samples/bevformer/base/pth_evaluate.sh -d ${gpu_id}
- Evaluate with TensorRT and MMDeploy Plugins
# convert .pth to .onnx
sh samples/bevformer/base/pth2onnx.sh -d ${gpu_id}
# convert .onnx to TensorRT engine (FP32)
sh samples/bevformer/base/onnx2trt.sh -d ${gpu_id}
# convert .onnx to TensorRT engine (FP16)
sh samples/bevformer/base/onnx2trt_fp16.sh -d ${gpu_id}
# evaluate with TensorRT engine (FP32)
sh samples/bevformer/base/trt_evaluate.sh -d ${gpu_id}
# evaluate with TensorRT engine (FP16)
sh samples/bevformer/base/trt_evaluate_fp16.sh -d ${gpu_id}
# Quantization
# calibration and convert .onnx to TensorRT engine (FP32/INT8)
sh samples/bevformer/base/onnx2trt_int8.sh -d ${gpu_id}
# calibration and convert .onnx to TensorRT engine (FP16/INT8)
sh samples/bevformer/base/onnx2trt_int8_fp16.sh -d ${gpu_id}
# evaluate with TensorRT engine (FP32/INT8)
sh samples/bevformer/base/trt_evaluate_int8.sh -d ${gpu_id}
# evaluate with TensorRT engine (FP16/INT8)
sh samples/bevformer/base/trt_evaluate_int8_fp16.sh -d ${gpu_id}
# quantization-aware training
# default gpu_ids is 0,1,2,3,4,5,6,7
sh samples/bevformer/base/quant_aware_train.sh -d ${gpu_ids}
# then follow the post-training quantization process above
- Evaluate with TensorRT and Custom Plugins
# nv_half
# convert .pth to .onnx
sh samples/bevformer/plugin/base/pth2onnx.sh -d ${gpu_id}
# convert .onnx to TensorRT engine (FP32)
sh samples/bevformer/plugin/base/onnx2trt.sh -d ${gpu_id}
# convert .onnx to TensorRT engine (FP16-nv_half)
sh samples/bevformer/plugin/base/onnx2trt_fp16.sh -d ${gpu_id}
# evaluate with TensorRT engine (FP32)
sh samples/bevformer/plugin/base/trt_evaluate.sh -d ${gpu_id}
# evaluate with TensorRT engine (FP16-nv_half)
sh samples/bevformer/plugin/base/trt_evaluate_fp16.sh -d ${gpu_id}
# nv_half2
# convert .pth to .onnx
sh samples/bevformer/plugin/base/pth2onnx_2.sh -d ${gpu_id}
# convert .onnx to TensorRT engine (FP16-nv_half2)
sh samples/bevformer/plugin/base/onnx2trt_fp16_2.sh -d ${gpu_id}
# evaluate with TensorRT engine (FP16-nv_half2)
sh samples/bevformer/plugin/base/trt_evaluate_fp16_2.sh -d ${gpu_id}
# Quantization
# nv_half
# calibration and convert .onnx to TensorRT engine (FP32/INT8)
sh samples/bevformer/plugin/base/onnx2trt_int8.sh -d ${gpu_id}
# calibration and convert .onnx to TensorRT engine (FP16-nv_half/INT8)
sh samples/bevformer/plugin/base/onnx2trt_int8_fp16.sh -d ${gpu_id}
# evaluate with TensorRT engine (FP32/INT8)
sh samples/bevformer/plugin/base/trt_evaluate_int8.sh -d ${gpu_id}
# evaluate with TensorRT engine (FP16-nv_half/INT8)
sh samples/bevformer/plugin/base/trt_evaluate_int8_fp16.sh -d ${gpu_id}
# nv_half2
# calibration and convert .onnx to TensorRT engine (FP16-nv_half2/INT8)
sh samples/bevformer/plugin/base/onnx2trt_int8_fp16_2.sh -d ${gpu_id}
# evaluate with TensorRT engine (FP16-nv_half2/INT8)
sh samples/bevformer/plugin/base/trt_evaluate_int8_fp16_2.sh -d ${gpu_id}
This project is mainly based on these excellent open source projects: