Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Detect() anchors for ONNX <> OpenCV DNN compatibility #4833

Merged
merged 19 commits into from
Oct 11, 2021
Merged

Refactor Detect() anchors for ONNX <> OpenCV DNN compatibility #4833

merged 19 commits into from
Oct 11, 2021

Conversation

jebastin-nadar
Copy link
Contributor

@jebastin-nadar jebastin-nadar commented Sep 16, 2021

Continuation of #4811.

self.anchor_grid was causing a problem as it is a pytorch tensor, so the memory is already allocated and shape defined. Due to this, expanding anchor_grid to perform element-wise multiplication resulted in added nodes in exported onnx model. So, I have changed self.anchor_grid to a list of tensors instead (similar to self.grid) and refactored some other code related to anchor_grids to make this change workable.

New onnx models (notice the new shapes in Add and Mul node):
Screenshot (119)

The exported onnx model now works with opencv dnn module.
(with --simplify). For onnx model without --simplify option, the opset version must be set to 11 to import in opencv (ONNX changed behaviour of unsqueeze node in opset 13, need a small fix in opencv to import with opset 13 as well). Fixed by opencv/opencv#20713

Note:

This change breaks inference for models trained before this PR Added compatibility with previously trained models

🛠️ PR Summary

Made with ❤️ by Ultralytics Actions

🌟 Summary

Enhanced anchor handling in YOLOv5's detection layers for better compatibility and maintainability.

📊 Key Changes

  • Introduced conditions to accommodate the anchor_grid as a list within the Detect() layer, ensuring proper application of functions to the anchor_grid.
  • Updated the Detect() layer code to improve compatibility with changes in anchor handling.
  • Modified TensorFlow model definition to correctly reshape the anchor grid based on stride changes.
  • Improved the anchor grid initialization in the YOLO layer to work with dynamic grid sizes during inference.
  • Ensured the order of anchors matches the order of strides to prevent potential inconsistencies.
  • Simplified anchor checks and updates in the autoanchor module, aligning with the new anchor handling approach.

🎯 Purpose & Impact

  • These changes aim to fix compatibility issues and simplify the handling of anchor grids in the various neural network layers.
  • For users and developers, these adjustments lead to a more robust and flexible object detection model, especially when customizing anchor configurations or using dynamic inputs.
  • Improvements could lead to more accurate object detection and easier maintenance and updates to the YOLOv5 architecture. 🎯🛠️

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👋 Hello @SamFC10, thank you for submitting a 🚀 PR! To allow your work to be integrated as seamlessly as possible, we advise you to:

  • ✅ Verify your PR is up-to-date with origin/master. If your PR is behind origin/master an automatic GitHub actions rebase may be attempted by including the /rebase command in a comment body, or by running the following code, replacing 'feature' with the name of your local branch:
git remote add upstream https://github.com/ultralytics/yolov5.git
git fetch upstream
git checkout feature  # <----- replace 'feature' with local branch name
git rebase upstream/master
git push -u origin -f
  • ✅ Verify all Continuous Integration (CI) checks are passing.
  • ✅ Reduce changes to the absolute minimum required for your bug fix or feature addition. "It is not daily increase but daily decrease, hack away the unessential. The closer to the source, the less wastage there is." -Bruce Lee

@glenn-jocher
Copy link
Member

@SamFC10 thanks for the PR! CI checks are failing, can you take a look at this please?

@jebastin-nadar
Copy link
Contributor Author

@glenn-jocher Fixed all CI build failures.

@glenn-jocher
Copy link
Member

glenn-jocher commented Sep 17, 2021

@SamFC10 thanks for the CI fixes! I ran CI on GPU in Colab and saw the error below. To reproduce, run Setup cell and then CI Checks cell in Appendix:
https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb#scrollTo=FGH0ZjkGjejy

1. Setup

!git clone https://github.com/ultralytics/yolov5  # clone repo
%cd yolov5
%pip install -qr requirements.txt  # install dependencies

import torch
from IPython.display import Image, clear_output  # to display images

clear_output()
print(f"Setup complete. Using torch {torch.__version__} ({torch.cuda.get_device_properties(0).name if torch.cuda.is_available() else 'CPU'})")

2. CI Checks

%%shell
export PYTHONPATH="$PWD"  # to run *.py. files in subdirectories
rm -rf runs  # remove runs/
for m in yolov5s; do  # models
  python train.py --weights $m.pt --epochs 3 --img 320 --device 0  # train pretrained
  python train.py --weights '' --cfg $m.yaml --epochs 3 --img 320 --device 0  # train scratch
  for d in 0 cpu; do  # devices
    python detect.py --weights $m.pt --device $d  # detect official
    python detect.py --weights runs/train/exp/weights/best.pt --device $d  # detect custom
    python val.py --weights $m.pt --device $d # val official
    python val.py --weights runs/train/exp/weights/best.pt --device $d # val custom
  done
python hubconf.py  # hub
python models/yolo.py --cfg $m.yaml  # build PyTorch model
python models/tf.py --weights $m.pt  # build TensorFlow model
python export.py --img 128 --batch 1 --weights $m.pt --include torchscript onnx  # export
done

Error

Traceback (most recent call last):
  File "train.py", line 611, in <module>
    main(opt)
  File "train.py", line 509, in main
    train(opt.hyp, opt, device, callbacks)
  File "train.py", line 261, in train
    compute_loss = ComputeLoss(model)  # init loss class
  File "/content/yolov5/utils/loss.py", line 114, in __init__
    self.anchors = det.anchors / det.stride.view(-1, 1, 1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@jebastin-nadar
Copy link
Contributor Author

I ran CI on GPU in Colab and saw the error below

I was able to reproduce the errors. Latest commit fixes them and CI checks are now passed on Colab GPU runtime

@glenn-jocher
Copy link
Member

@SamFC10 awesome thank you! Unfortunately we don't have a seamless CPU-GPU CI in place, so our current workflow involves manually running Colab CI before merges.

@glenn-jocher glenn-jocher changed the title feature : refactor anchors and anchor_grid to make onnx export compatible with opencv dnn module Refactor Detect() anchors for ONNX <> OpenCV DNN compatibility Sep 18, 2021
@glenn-jocher glenn-jocher added the enhancement New feature or request label Sep 18, 2021
@glenn-jocher
Copy link
Member

glenn-jocher commented Sep 18, 2021

@SamFC10 a problem came up in testing. Detect, Val, Export operate correctly. Train does not throw an error, but produces near zero results in comparison to master in the Colab notebook Train section. I think when --weights are being loaded for training the anchors are being overwritten or reordered incorrectly.
https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb#scrollTo=1NcFxRcFdJ_O&line=1&uniqifier=1

# Train YOLOv5s on COCO128 for 3 epochs
!python train.py --img 640 --batch 16 --epochs 3 --data coco128.yaml --weights yolov5s.pt --cache

This PR

     Epoch   gpu_mem       box       obj       cls    labels  img_size
       0/2     3.55G   0.09968   0.06903   0.02277       166       640: 100% 8/8 [00:04<00:00,  1.62it/s]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100% 4/4 [00:01<00:00,  3.97it/s]
                 all        128        929       0.24    0.00101   9.33e-05   1.87e-05

     Epoch   gpu_mem       box       obj       cls    labels  img_size
       1/2      4.4G   0.09012   0.06828   0.02285       205       640: 100% 8/8 [00:02<00:00,  3.90it/s]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100% 4/4 [00:00<00:00,  4.48it/s]
                 all        128        929       0.24    0.00101   9.67e-05   1.19e-05

     Epoch   gpu_mem       box       obj       cls    labels  img_size
       2/2      4.4G   0.07513   0.06287   0.02503       163       640: 100% 8/8 [00:02<00:00,  3.78it/s]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100% 4/4 [00:05<00:00,  1.25s/it]
                 all        128        929      0.127    0.00101   7.51e-05   3.07e-05

Master

     Epoch   gpu_mem       box       obj       cls    labels  img_size
       0/2     3.55G   0.04438   0.07064   0.01989       166       640: 100% 8/8 [00:04<00:00,  1.72it/s]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100% 4/4 [00:01<00:00,  3.79it/s]
                 all        128        929      0.682      0.578       0.65      0.426

     Epoch   gpu_mem       box       obj       cls    labels  img_size
       1/2      4.4G   0.04641   0.06915   0.01959       205       640: 100% 8/8 [00:02<00:00,  3.83it/s]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100% 4/4 [00:00<00:00,  4.26it/s]
                 all        128        929      0.683      0.583      0.657      0.426

     Epoch   gpu_mem       box       obj       cls    labels  img_size
       2/2      4.4G   0.04365   0.06366   0.02201       163       640: 100% 8/8 [00:02<00:00,  3.85it/s]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100% 4/4 [00:03<00:00,  1.15it/s]
                 all        128        929      0.694      0.575      0.662      0.432

@jebastin-nadar
Copy link
Contributor Author

@glenn-jocher Was able to reproduce this behaviour

I think when --weights are being loaded for training the anchors are being overwritten

Quite likely. In train.py when pre-trained weights are used, model.load_state_dict() probably overwrites the new anchors with the old ones (which is actually anchors/strides). Investigating and will add a fix soon.

@jebastin-nadar
Copy link
Contributor Author

Latest commit solves the issue of poor training accuracy with `--weights option

This PR

     Epoch   gpu_mem       box       obj       cls    labels  img_size
       0/2     3.45G   0.04438   0.07067   0.01988       166       640: 100% 8/8 [00:11<00:00,  1.42s/it]
               Class     Images     Labels          P          R     [email protected] [email protected]:.95: 100% 4/4 [00:02<00:00,  1.55it/s]
                 all        128        929       0.68      0.579       0.65      0.427

     Epoch   gpu_mem       box       obj       cls    labels  img_size
       1/2     4.58G   0.04641   0.06921   0.01958       205       640: 100% 8/8 [00:08<00:00,  1.05s/it]
               Class     Images     Labels          P          R     [email protected] [email protected]:.95: 100% 4/4 [00:02<00:00,  1.54it/s]
                 all        128        929      0.684      0.582      0.659      0.426

     Epoch   gpu_mem       box       obj       cls    labels  img_size
       2/2     4.58G   0.04365   0.06366     0.022       163       640: 100% 8/8 [00:08<00:00,  1.06s/it]
               Class     Images     Labels          P          R     [email protected] [email protected]:.95: 100% 4/4 [00:05<00:00,  1.40s/it]
                 all        128        929      0.696      0.575      0.663      0.433

train.py Outdated Show resolved Hide resolved
@glenn-jocher glenn-jocher self-assigned this Sep 29, 2021
@glenn-jocher
Copy link
Member

@SamFC10 I looked at this again a couple days ago. It seems like the PR works well as a standalone branch (i.e. to train and detect), but for master branch we need to ensure that any updates are backwards compatible with all of the existing YOLOv5 models out in the wild from earlier releases.

It seemed like this might cause problems in those cases, but I wasn't sure. It's hard to know without a more comprehensive set of tests.

If I have some more time this week I will try to revisit.

@jebastin-nadar
Copy link
Contributor Author

we need to ensure that any updates are backwards compatible with all of the existing YOLOv5 models

Backwards compatibility was needed because this PR changed the definition of anchors (they are stored as anchors/strides in all previous models). But the original problem was changing the shape of anchor_grid to make the model compatible with opencv. So it makes sense to not modify the definition of anchors.

Looking at the changes made, I think the current solution is not elegant/ideal. I have another branch locally which fixes the problem and does not modify anchors, so no issue with backwards compatibility. I'll add a new commit with this solution after testing with your colab notebook.

Marking this PR as draft for now.

@jebastin-nadar jebastin-nadar marked this pull request as draft October 5, 2021 06:03
@glenn-jocher
Copy link
Member

@SamFC10 I did a suite of tests yesterday and everything seemed to work well:

train new from --weights
train new from --weights --cfg
train new from --weights '' --cfg
detect from new and trained above
val from new and trained above
train, break, resume

I have not tested export yet. But based on the above tests backwards compatibility seemed to work correctly so far.

@glenn-jocher
Copy link
Member

OpenCV DNN Inference Instructions

# Export to ONNX
python export.py --weights yolov5s.pt --include onnx --simplify

# Inference
python detect.py --weights yolov5s.onnx  # ONNX Runtime inference
# -- or --
python detect.py --weights yolov5s.onnx --dnn  # OpenCV DNN inference

@ghost
Copy link

ghost commented Oct 21, 2021

Hello @glenn-jocher @SamFC10 and anyone who may have the same issue !

After pulling the repo today I could not use an old weight file (trained about a year ago) to perform inference.

I was getting this error :

detect: weights=['/tmp/model/last_95.pt'], source=yolov5/data/images, imgsz=[640, 640], conf_thres=0.25, iou_thres=0.45, max_det=1000, device=, view_img=False, save_txt=False, save_conf=False, save_crop=False, nosave=False, classes=None, agnostic_nms=False, augment=False, visualize=False, update=False, project=yolov5/runs/detect, name=exp, exist_ok=False, line_thickness=3, hide_labels=False, hide_conf=False, half=False, dnn=False
YOLOv5 🚀 v6.0-25-g15e8c4c torch 1.9.0 CUDA:0 (Tesla K80, 11441.25MB)

Fusing layers... 
Model Summary: 484 layers, 88410801 parameters, 0 gradients
Traceback (most recent call last):
  File "yolov5/detect.py", line 336, in <module>
    main(opt)
  File "yolov5/detect.py", line 331, in main
    run(**vars(opt))
  File "/anaconda/envs/py38_pytorch/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "yolov5/detect.py", line 137, in run
    model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters())))  # run once
  File "/anaconda/envs/py38_pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/yolov5/models/yolo.py", line 126, in forward
    return self._forward_once(x, profile, visualize)  # single-scale inference, train
  File "/mnt/yolov5/models/yolo.py", line 149, in _forward_once
    x = m(x)  # run
  File "/anaconda/envs/py38_pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/yolov5/models/yolo.py", line 66, in forward
    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Using git bisect, I tracked down the appearance of the error to this commit.

I do not have enough data to post a new issue. Nevertheless, I wanted to report it and to propose a temporary solution for whoever ecounters the same error.

So, I solved the issue by loading the weights in training script (as if I were continuing training) and saving them with new ckpt before the epochs loop. After this, no more error during detection.

Thank you very much for your work !
Hope this helps.

@glenn-jocher
Copy link
Member

glenn-jocher commented Oct 21, 2021

@adrienanton thanks for the bug report! I think there may be built-in functionality to do what you need. If you run detect.py --update with your weights it will process them into the latest release format and resave them under the same name, i.e.:

python detect.py --weights /tmp/model/last_95.pt --update

@ghost
Copy link

ghost commented Oct 22, 2021

@glenn-jocher Thanks for the response !

I do not think detect.py --update is what I was looking for since the error occurs before the update.
Even if I move up the update before the inference I still have the same issue.

@glenn-jocher
Copy link
Member

@adrienanton ok got it. The workaround then is just to retrain your model. You'll get better results and the error will be resolved.

@Tyler-D
Copy link

Tyler-D commented May 13, 2022

@glenn-jocher @SamFC10 I'm reading the YOLOV5 code, just curious why do we need to do anchors = anchors / stride ?

@glenn-jocher
Copy link
Member

@Tyler-D loss is computed in grid space (i.e. 0-19) rather than pixel space (0-639)

@Tyler-D
Copy link

Tyler-D commented May 13, 2022

@Tyler-D loss is computed in grid space (i.e. 0-19) rather than pixel space (0-639)

Got it. Thanks !

BjarneKuehl pushed a commit to fhkiel-mlaip/yolov5 that referenced this pull request Aug 26, 2022
…ralytics#4833)

* refactor anchors and anchor_grid in Detect Layer

* fix CI failures by adding compatibility

* fix tf failure

* fix different devices errors

* Cleanup

* fix anchors overwriting issue

* better refactoring

* Remove self.anchor_grid shape check (redundant with self.grid check)

Also PEP8 / 120 line width

* Convert _make_grid() from static to dynamic method

* Remove anchor_grid.to(device)

clone() should already clone to same device as self.anchors

* fix different devices error

Co-authored-by: Glenn Jocher <[email protected]>
@alkhalisy
Copy link

Dear, I try to remove P3 and P5 detection and still tp P4, and I do require a change in the Neck, and everything becomes well and works. When I try to delete C5 from the feature Extraction Backbone and do modifications in the neck, the error happened
"File "/content/yolov5/models/yolo.py", line 334, in
args.append([ch[x] for x in f])
IndexError: list index out of range"
why any I can not do any change to the backbone???

nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:

[30,61, 62,45, 59,119] # P4/16
backbone:

[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 3, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 3, C3, [512]],
[-1, 1, SPPF, [512, 5]], # 9
]

head:
[
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13

[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)

[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)

[[20], 1, Detect, [nc, anchors]], # Detect( P4)
]

@glenn-jocher
Copy link
Member

@alkhalisy it seems like the error is occurring because you are deleting C5 from the feature extraction backbone but not updating the head accordingly. Since the head is referencing the deleted C5, it causes an index out of range error when trying to access it. Make sure to update the head according to the changes made in the backbone to resolve this error.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
5 participants