
export onnx error #32

Open
williamhyin opened this issue Jun 22, 2021 · 10 comments

@williamhyin

Hi,

When exporting to ONNX, I ran into a problem. Do you know how to solve it? I used LeakyReLU to replace SiLU (which is not supported by ONNX).

python -m torch.distributed.launch --nproc_per_node 8 --master_port 9527 train.py --batch-size 256 --img 640 640 --data data/.yaml --cfg models/yolor-ssss-dwt.yaml --weights '' --sync-bn --device 0,1,2,3,4,5,6,7 --name yolor-ssss --hyp hyp.scratch.1280.yaml --epochs 2

Starting ONNX export with onnx 1.7.0...
ONNX export failure: Sizes of tensors must match except in dimension 1. Got 23 and 24 in dimension 2 (The offending index is 1)
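A minimal sketch of such an activation swap, for reference (where model is the loaded network; it assumes the Conv modules expose their activation as an .act attribute, as in this repo's models/common.py):

import torch.nn as nn

# Walk the model and replace SiLU activations with LeakyReLU on
# modules that expose an `.act` attribute (e.g. models.common.Conv).
for m in model.modules():
    if hasattr(m, 'act') and isinstance(m.act, nn.SiLU):
        m.act = nn.LeakyReLU(0.1, inplace=True)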
@WongKinYiu
Owner

Training won't show this error; I think your command and the error are mismatched.
Also, the --data argument in your command is wrong, and this hyp file is not for a small model.
By the way, I am not sure whether DWT can be exported to ONNX or not.

@williamhyin
Author

Training won't show this error; I think your command and the error are mismatched.
Also, the --data argument in your command is wrong, and this hyp file is not for a small model.
By the way, I am not sure whether DWT can be exported to ONNX or not.

Thanks for your quick reply. I deleted the name of the yaml file just for privacy. The error did not come from training; I only ran 2 epochs as a quick test. No matter whether I convert your trained yolor_p6.pt or yolor-ssss-dwt.pt, I get the same error.
Even when I trained from scratch and converted the resulting .pt to ONNX, the error is the same.

@likyoo

likyoo commented Jul 21, 2021

Training won't show this error; I think your command and the error are mismatched.
Also, the --data argument in your command is wrong, and this hyp file is not for a small model.
By the way, I am not sure whether DWT can be exported to ONNX or not.

Thanks for your quick reply. I deleted the name of the yaml file just for privacy. The error did not come from training; I only ran 2 epochs as a quick test. No matter whether I convert your trained yolor_p6.pt or yolor-ssss-dwt.pt, I get the same error.
Even when I trained from scratch and converted the resulting .pt to ONNX, the error is the same.

Is this problem solved?

@abhigoku10

@williamhyin were you able to convert it to ONNX? Can you share the script?

@DefTruth

DefTruth commented Aug 7, 2021

You must first convert the SyncBatchNorm layers saved in the pretrained yolor-xx.pt to BatchNorm if you reload it on CPU!

import argparse
import sys
import time

sys.path.append('./')  # to run '$ python *.py' files in subdirectories

import torch
import torch.nn as nn
import models
from models.experimental import attempt_load
from utils.activations import Hardswish
from utils.general import set_logging, check_img_size


# convert SyncBatchNorm to BatchNorm2d
def convert_sync_batchnorm_to_batchnorm(module):
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features,
                                             module.eps, module.momentum,
                                             module.affine,
                                             module.track_running_stats)

        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, convert_sync_batchnorm_to_batchnorm(child))
    del module
    return module_output


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='./yolor-p6.pt', help='weights path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[1280, 1280], help='image size')  # height, width
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    opt = parser.parse_args()
    opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand
    print(opt)
    set_logging()
    t = time.time()

    # Load PyTorch model
    model = attempt_load(opt.weights, map_location=torch.device('cpu'))  # load FP32 model
    labels = model.names
    model.eval()
    model = model.to("cpu")

    model = convert_sync_batchnorm_to_batchnorm(model)

    print(model)

    # Checks
    gs = int(max(model.stride))  # grid size (max stride)
    opt.img_size = [check_img_size(x, gs) for x in opt.img_size]  # verify img_size are gs-multiples

    # Input
    img = torch.zeros(opt.batch_size, 3, *opt.img_size)  # image size(1,3,320,192) iDetection

    # Update model
    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
        if isinstance(m, models.common.Conv) and isinstance(m.act, nn.Hardswish):
            m.act = Hardswish()  # assign activation
        # if isinstance(m, models.yolo.Detect):
        #     m.forward = m.forward_export  # assign forward (optional)
    model.model[-1].export = True  # set Detect() layer export=True
    y = model(img)  # dry run

    # TorchScript export
    # try:
    #     print('\nStarting TorchScript export with torch %s...' % torch.__version__)
    #     f = opt.weights.replace('.pt', '.torchscript.pt')  # filename
    #     ts = torch.jit.trace(model, img)
    #     ts.save(f)
    #     print('TorchScript export success, saved as %s' % f)
    # except Exception as e:
    #     print('TorchScript export failure: %s' % e)

    # ONNX export
    try:
        import onnx
        import onnxruntime as ort

        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        f = opt.weights.replace('.pt', f'-{opt.img_size[0]}-{opt.img_size[1]}.onnx')  # filename
        torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
                          output_names=['classes', 'boxes'] if y is None else ['output'])

        # Checks
        onnx_model = onnx.load(f)  # load onnx model
        onnx.checker.check_model(onnx_model)  # check onnx model
        print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model

        do_simplify = True
        if do_simplify:
            from onnxsim import simplify

            onnx_model, check = simplify(onnx_model, check_n=3)
            assert check, 'assert simplify check failed'
            onnx.save(onnx_model, f)

        session = ort.InferenceSession(f)

        for ii in session.get_inputs():
            print("input: ", ii)

        for oo in session.get_outputs():
            print("output: ", oo)

        print('ONNX export success, saved as %s' % f)
    except Exception as e:
        print('ONNX export failure: %s' % e)

    # CoreML export
    # try:
    #     import coremltools as ct
    #
    #     print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
    #     # convert model from torchscript and apply pixel scaling as per detect.py
    #     model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
    #     f = opt.weights.replace('.pt', '.mlmodel')  # filename
    #     model.save(f)
    #     print('CoreML export success, saved as %s' % f)
    # except Exception as e:
    #     print('CoreML export failure: %s' % e)

    # Finish
    print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))

    """
    PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 640 
    """

Also, there is a bug in the source code that will make the converted ONNX model produce the wrong output shape.

class IDetect(nn.Module):
    # ...
    def forward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        # self.training |= self.export  # problematic: export=True forces training mode, so the decode branch is skipped
        self.training = not self.export  # change to this line
        print("self.training: ", self.training, self.nl)

my log:

Checking 0/3...
Checking 1/3...
Checking 2/3...
input:  NodeArg(name='images', type='tensor(float)', shape=[1, 3, 320, 320])
output:  NodeArg(name='output', type='tensor(float)', shape=[1, 6375, 85])
output:  NodeArg(name='910', type='tensor(float)', shape=[1, 3, 40, 40, 85])
output:  NodeArg(name='944', type='tensor(float)', shape=[1, 3, 20, 20, 85])
output:  NodeArg(name='978', type='tensor(float)', shape=[1, 3, 10, 10, 85])
output:  NodeArg(name='1012', type='tensor(float)', shape=[1, 3, 5, 5, 85])
ONNX export success, saved as ./weights/yolor-p6-320-320.onnx

There is another piece of logic that misbehaves when exporting to ONNX; once I changed it to match YOLOv5, the exported ONNX model produced correct results. Perhaps PyTorch's automatic broadcasting is not carried out in ONNX.

            if not self.training:  # inference
                # if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                #     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
                y = x[i].sigmoid()
                # original yolor code:
                # y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                # y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                
                # changed to match yolov5:
                xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy (bs,na,ny,nx,2)
                wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh (bs,na,ny,nx,2)
                y = torch.cat((xy, wh, y[..., 4:]), -1)  # (bs,na,ny,nx,2+2+1+nc=xy+wh+conf+cls_prob)

                z.append(y.view(bs, -1, self.no))
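To check that the exported graph actually matches PyTorch, a quick numerical comparison can be run (a minimal sketch; model, img, and f are the model, dummy input, and ONNX filename from the export script above, and it assumes the model's first output is the concatenated prediction tensor, as in the log):

import numpy as np
import onnxruntime as ort
import torch

# Run the same dummy input through PyTorch and ONNX Runtime.
with torch.no_grad():
    torch_out = model(img)[0].cpu().numpy()

sess = ort.InferenceSession(f)
onnx_out = sess.run(None, {'images': img.numpy()})[0]

# The two outputs should agree to within floating-point tolerance.
print('max abs diff:', np.abs(torch_out - onnx_out).max())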

@DefTruth

DefTruth commented Aug 7, 2021

In-place memory operations in PyTorch cause problems after conversion to ONNX. It seems ONNX does not support such operations yet: the model can still be converted, but its outputs do not match PyTorch's. After I changed the in-place operations in the source code to ordinary out-of-place operations, the problem went away.
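A minimal illustration of the rewrite (simplified shapes, not the actual YOLOR code): the in-place slice assignment is replaced by building the slices out-of-place and concatenating:

import torch

y = torch.rand(1, 3, 40, 40, 85)     # (bs, na, ny, nx, no)
grid = torch.zeros(1, 3, 40, 40, 2)
stride = 8.0

# In-place slice assignment: exports, but the ONNX output may not match PyTorch.
# y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid) * stride

# Out-of-place equivalent: compute the pieces, then concatenate.
xy = (y[..., 0:2] * 2. - 0.5 + grid) * stride
y = torch.cat((xy, y[..., 2:]), -1)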

@hamedmh

hamedmh commented Oct 18, 2021

@DefTruth Hi,

I get this error when trying the code above: ModuleNotFoundError: No module named 'models.experimental'
Please advise me where I can find models/experimental.py.

Thanks!

@DefTruth

@DefTruth Hi,

I get this error when trying the code above: ModuleNotFoundError: No module named 'models.experimental' Please advise me where I can find models/experimental.py.

Thanks!

Just check out the 'paper' branch and you will see the models/experimental.py script, or you can download the exported ONNX files from my repo.

@hamedmh

hamedmh commented Oct 18, 2021

@DefTruth Thanks for the prompt answer! I successfully used: git clone https://github.com/WongKinYiu/yolor -b paper

@siriasadeddin

Has anyone had an issue exporting yolor-ssss-dwt.pt to ONNX? I can export all of the other models, but not this one.
I get the error "Couldn't export Python operator AFB2D". I am running "python ./models/export.py --weights ./yolor-ssss-dwt.pt --img-size 640", and similar commands worked for the rest of the models without failing. Any idea?
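For context, "Couldn't export Python operator" usually means the model calls a torch.autograd.Function that has no ONNX symbolic; the DWT layer's AFB2D (the 2-D analysis filter bank) appears to be such a function. A generic sketch of the mechanism with a hypothetical stand-in op (the real AFB2D would need its own symbolic built from ONNX Conv/Pad nodes, which is why it fails here):

import torch

class DoubleOp(torch.autograd.Function):
    # Hypothetical stand-in op, NOT the real AFB2D.
    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def symbolic(g, x):
        # Without a symbolic() like this, torch.onnx fails with
        # "Couldn't export Python operator DoubleOp".
        two = g.op('Constant', value_t=torch.tensor(2.0))
        return g.op('Mul', x, two)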
