Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed sync_batchnorm export error and add detect_onnx.py #98

Open
wants to merge 2 commits into
base: paper
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions detect_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
import os
import cv2
import torch
import numpy as np
import onnxruntime as ort
from utils.general import non_max_suppression

names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush']


def infer_yolor(onnx_path="./weights/yolor-d6-640-640.onnx"):
ort.set_default_logger_severity(4)
ort_session = ort.InferenceSession(onnx_path)

outputs_info = ort_session.get_outputs()
print("num outputs: ", len(outputs_info))
print(outputs_info)

test_path = "./inference/images/horses.jpg"
save_path = f"./inference/images/horses_{os.path.basename(onnx_path)}.jpg"

img_bgr = cv2.imread(test_path)
height, width, _ = img_bgr.shape

img_rgb = img_bgr[:, :, ::-1]
img_rgb = cv2.resize(img_rgb, (640, 640))
img = img_rgb.transpose(2, 0, 1).astype(np.float32) # (3,640,640) RGB

img /= 255.0

img = np.expand_dims(img, 0)
# [1,num_anchors,num_outputs=2+2+1+nc=cxcy+wh+conf+cls_prob]
pred = ort_session.run(["output"], input_feed={"images": img})[0]

print(pred.shape)
print(pred[0, :4].min())
print(pred[0, :4].max())
pred_tensor = torch.from_numpy(pred).float()

boxes_tensor = non_max_suppression(pred_tensor)[0] # [n,6] [x1,y1,x2,y2,conf,cls]

boxes = boxes_tensor.cpu().numpy().astype(np.float32)

if boxes.shape[0] == 0:
print("no bounding boxes detected.")
return
scale_w = width / 640.
scale_h = height / 640.

print(boxes[:2, :])

boxes[:, 0] *= scale_w
boxes[:, 1] *= scale_h
boxes[:, 2] *= scale_w
boxes[:, 3] *= scale_h

print(f"detect {boxes.shape[0]} bounding boxes.")

for i in range(boxes.shape[0]):
x1, y1, x2, y2, conf, label = boxes[i]
print(boxes[i])
x1, y1, x2, y2, label = int(x1), int(y1), int(x2), int(y2), int(label)
cv2.rectangle(img_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2, 2)
cv2.putText(img_bgr, names[label] + ":{:.2f}".format(conf), (x1, y1),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2, 2)

cv2.imwrite(save_path, img_bgr)

print("detect done.")


if __name__ == "__main__":
np.set_printoptions(suppress=True)

infer_yolor(onnx_path="./weights/yolor-p6-640-640.onnx")
infer_yolor(onnx_path="./weights/yolor-d6-640-640.onnx")
infer_yolor(onnx_path="./weights/yolor-e6-640-640.onnx")
infer_yolor(onnx_path="./weights/yolor-w6-640-640.onnx")
infer_yolor(onnx_path="./weights/yolor-ssss-s2d-640-640.onnx")

"""
PYTHONPATH=. python3 ./detect_onnx.py
"""
62 changes: 59 additions & 3 deletions models/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,36 @@

import torch
import torch.nn as nn

import models
from models.experimental import attempt_load
from utils.activations import Hardswish
from utils.general import set_logging, check_img_size


# need convert SyncBatchNorm to BatchNorm2d
def convert_sync_batchnorm_to_batchnorm(module):
module_output = module
if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
module_output = torch.nn.BatchNorm2d(module.num_features,
module.eps, module.momentum,
module.affine,
module.track_running_stats)

if module.affine:
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
if hasattr(module, "qconfig"):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
module_output.add_module(name, convert_sync_batchnorm_to_batchnorm(child))
del module
return module_output


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='./yolor-p6.pt', help='weights path')
Expand All @@ -26,6 +50,12 @@
# Load PyTorch model
model = attempt_load(opt.weights, map_location=torch.device('cpu')) # load FP32 model
labels = model.names
model.eval()
model = model.to("cpu")

model = convert_sync_batchnorm_to_batchnorm(model)

print(model)

# Checks
gs = int(max(model.stride)) # grid size (max stride)
Expand All @@ -44,6 +74,8 @@
model.model[-1].export = True # set Detect() layer export=True
y = model(img) # dry run

print(y[0].shape)

# TorchScript export
try:
print('\nStarting TorchScript export with torch %s...' % torch.__version__)
Expand All @@ -57,16 +89,34 @@
# ONNX export
try:
import onnx
import onnxruntime as ort

print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
f = opt.weights.replace('.pt', '.onnx') # filename
f = opt.weights.replace('.pt', f'-{opt.img_size[0]}-{opt.img_size[1]}.onnx') # filename
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
output_names=['classes', 'boxes'] if y is None else ['output'])

# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model
# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model

do_simplify = True
if do_simplify:
from onnxsim import simplify

onnx_model, check = simplify(onnx_model, check_n=3)
assert check, 'assert simplify check failed'
onnx.save(onnx_model, f)

session = ort.InferenceSession(f)

for ii in session.get_inputs():
print("input: ", ii)

for oo in session.get_outputs():
print("output: ", oo)

print('ONNX export success, saved as %s' % f)
except Exception as e:
print('ONNX export failure: %s' % e)
Expand All @@ -86,3 +136,9 @@

# Finish
print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))

"""
PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 640
PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 320
PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 1280
"""