Skip to content

Commit

Permalink
Export, detect and validation with TensorRT engine file (ultralytics#…
Browse files Browse the repository at this point in the history
…5699)

* Export and detect with TensorRT engine file

* Resolve `isort`

* Make validation works with TensorRT engine

* feat: update export docstring

* feat: change suffix from *.trt to *.engine

* feat: get rid of pycuda

* feat: make compatiable with val.py

* feat: support detect with fp16 engine

* Add Lite to Edge TPU string

* Remove *.trt comment

* Revert to standard success logger.info string

* Fix Deprecation Warning

```
export.py:310: DeprecationWarning: Use build_serialized_network instead.
  with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
```

* Revert deprecation warning fix

@imyhxy it seems we can't apply the deprecation warning fix because then export fails, so I'm reverting my previous change here.

* Update export.py

* Update export.py

* Update common.py

* export onnx to file before building TensorRT engine file

* feat: triger ONNX export failed early

* feat: load ONNX model from file

Co-authored-by: Glenn Jocher <[email protected]>
  • Loading branch information
imyhxy and glenn-jocher committed Nov 22, 2021
1 parent cfc32b7 commit 873145e
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 11 deletions.
4 changes: 2 additions & 2 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
# Load model
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn)
stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
imgsz = check_img_size(imgsz, s=stride) # check image size

# Half
half &= pt and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
if pt:
model.model.half() if half else model.model.float()

Expand Down
55 changes: 54 additions & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
TensorFlow GraphDef | yolov5s.pb | 'pb'
TensorFlow Lite | yolov5s.tflite | 'tflite'
TensorFlow.js | yolov5s_web_model/ | 'tfjs'
TensorRT | yolov5s.engine | 'engine'
Usage:
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite tfjs
Expand All @@ -24,6 +25,7 @@
yolov5s_saved_model
yolov5s.pb
yolov5s.tflite
yolov5s.engine
TensorFlow.js:
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
Expand Down Expand Up @@ -263,6 +265,51 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
LOGGER.info(f'\n{prefix} export failure: {e}')


def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
try:
check_requirements(('tensorrt',))
import tensorrt as trt

opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x
export_onnx(model, im, file, opset, train, False, simplify)
onnx = file.with_suffix('.onnx')
assert onnx.exists(), f'failed to export ONNX file: {onnx}'

LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
f = str(file).replace('.pt', '.engine') # TensorRT engine file
logger = trt.Logger(trt.Logger.INFO)
if verbose:
logger.min_severity = trt.Logger.Severity.VERBOSE

builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = workspace * 1 << 30

flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(onnx)):
raise RuntimeError(f'failed to load ONNX file: {onnx}')

inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
LOGGER.info(f'{prefix} Network Description:')
for inp in inputs:
LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
for out in outputs:
LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

half &= builder.platform_has_fast_fp16
LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
if half:
config.set_flag(trt.BuilderFlag.FP16)
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
t.write(engine.serialize())
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')

except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')

@torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
weights=ROOT / 'yolov5s.pt', # weights path
Expand All @@ -278,6 +325,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
dynamic=False, # ONNX/TF: dynamic axes
simplify=False, # ONNX: simplify model
opset=12, # ONNX: opset version
verbose=False, # TensorRT: verbose log
workspace=4, # TensorRT: workspace size (GB)
topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep
iou_thres=0.45, # TF.js NMS: IoU threshold
Expand Down Expand Up @@ -322,6 +371,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
export_torchscript(model, im, file, optimize)
if 'onnx' in include:
export_onnx(model, im, file, opset, train, dynamic, simplify)
if 'engine' in include:
export_engine(model, im, file, train, half, simplify, workspace, verbose)
if 'coreml' in include:
export_coreml(model, im, file)

Expand Down Expand Up @@ -360,13 +411,15 @@ def parse_opt():
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument('--include', nargs='+',
default=['torchscript', 'onnx'],
help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
help='available formats are (torchscript, onnx, engine, coreml, saved_model, pb, tflite, tfjs)')
opt = parser.parse_args()
print_args(FILE.stem, opt)
return opt
Expand Down
32 changes: 28 additions & 4 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math
import platform
import warnings
from collections import namedtuple
from copy import copy
from pathlib import Path

Expand Down Expand Up @@ -285,11 +286,12 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
# TensorFlow Lite: *.tflite
# ONNX Runtime: *.onnx
# OpenCV DNN: *.onnx with dnn=True
# TensorRT: *.engine
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel']
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
pt, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
jit = pt and 'torchscript' in w.lower()
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults

Expand Down Expand Up @@ -317,6 +319,23 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
elif engine: # TensorRT
LOGGER.info(f'Loading {w} for TensorRT inference...')
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.INFO)
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(f.read())
bindings = dict()
for index in range(model.num_bindings):
name = model.get_binding_name(index)
dtype = trt.nptype(model.get_binding_dtype(index))
shape = tuple(model.get_binding_shape(index))
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
binding_addrs = {n: d.ptr for n, d in bindings.items()}
context = model.create_execution_context()
batch_size = bindings['images'].shape[0]
else: # TensorFlow model (TFLite, pb, saved_model)
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
Expand All @@ -334,7 +353,7 @@ def wrap_frozen_graph(gd, inputs, outputs):
model = tf.keras.models.load_model(w)
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
if 'edgetpu' in w.lower():
LOGGER.info(f'Loading {w} for TensorFlow Edge TPU inference...')
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
import tflite_runtime.interpreter as tfli
delegate = {'Linux': 'libedgetpu.so.1', # install https://coral.ai/software/#edgetpu-runtime
'Darwin': 'libedgetpu.1.dylib',
Expand Down Expand Up @@ -369,6 +388,11 @@ def forward(self, im, augment=False, visualize=False, val=False):
y = self.net.forward()
else: # ONNX Runtime
y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
elif self.engine: # TensorRT
assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
self.binding_addrs['images'] = int(im.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
y = self.bindings['output'].data
else: # TensorFlow model (TFLite, pb, saved_model)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.pb:
Expand All @@ -391,7 +415,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
y[..., 1] *= h # y
y[..., 2] *= w # w
y[..., 3] *= h # h
y = torch.tensor(y)
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
return (y, []) if val else y


Expand Down
10 changes: 6 additions & 4 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def run(data,
# Initialize/load model and set device
training = model is not None
if training: # called by train.py
device, pt = next(model.parameters()).device, True # get model device, PyTorch model
device, pt, engine = next(model.parameters()).device, True, False # get model device, PyTorch model

half &= device.type != 'cpu' # half precision only supported on CUDA
model.half() if half else model.float()
Expand All @@ -124,11 +124,13 @@ def run(data,

# Load model
model = DetectMultiBackend(weights, device=device, dnn=dnn)
stride, pt = model.stride, model.pt
stride, pt, engine = model.stride, model.pt, model.engine
imgsz = check_img_size(imgsz, s=stride) # check image size
half &= pt and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
if pt:
model.model.half() if half else model.model.float()
elif engine:
batch_size = model.batch_size
else:
half = False
batch_size = 1 # export.py models default to batch-size 1
Expand Down Expand Up @@ -165,7 +167,7 @@ def run(data,
pbar = tqdm(dataloader, desc=s, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
t1 = time_sync()
if pt:
if pt or engine:
im = im.to(device, non_blocking=True)
targets = targets.to(device)
im = im.half() if half else im.float() # uint8 to fp16/32
Expand Down

0 comments on commit 873145e

Please sign in to comment.