Apply FReLU to YOLOv5 #2

glenn-jocher opened this issue Jul 30, 2020 · 4 comments

glenn-jocher commented Jul 30, 2020

Thank you for your great contributions! In YOLOv5 we skipped Swish and Mish due to their expensive nature, especially during training, and retained the same LeakyReLU(0.1) as YOLOv3. We have a PR, ultralytics/yolov5#556, that adds FReLU as shown below. Is this a correct implementation? Should we use torch.max() or torch.maximum()?

# FReLU https://arxiv.org/abs/2007.11824 --------------------------------------
import torch
import torch.nn as nn


class FReLU(nn.Module):
    def __init__(self, c1, k=3):  # ch_in, kernel
        super().__init__()
        # depthwise k x k convolution implementing the spatial funnel condition T(x)
        self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1)
        self.bn = nn.BatchNorm2d(c1)

    def forward(self, x):
        # element-wise max(x, T(x))
        return torch.maximum(x, self.bn(self.conv(x)))
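
For reference, a minimal shape check of the module above (dummy input, not part of the PR):

x = torch.randn(2, 64, 32, 32)   # NCHW dummy batch
m = FReLU(64)
print(m(x).shape)                # expected: torch.Size([2, 64, 32, 32])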

Secondly, I saw you recommended using FReLU only in the backbone of a detection network. So would you recommend replacing all of our LeakyReLU(0.1) activations in the backbone with FReLU() to start, and then doing the same with the head activations if we see positive results? Our main model structure is here:

https://github.com/ultralytics/yolov5/blob/master/models/yolov5s.yaml

# YOLOv5 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Focus, [64, 3]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, BottleneckCSP, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 9, BottleneckCSP, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, BottleneckCSP, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 1, SPP, [1024, [5, 9, 13]]],
   [-1, 3, BottleneckCSP, [1024, False]],  # 9
  ]

# YOLOv5 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, BottleneckCSP, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, BottleneckCSP, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, BottleneckCSP, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, BottleneckCSP, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

Thank you in advance for your time and any recommendations for our situation!

nmaac (Collaborator) commented Jul 30, 2020

Hi @glenn-jocher,
(1) You should use torch.max(), which is different from the MegEngine API.
(2) Yes, you could try FReLU only in the backbone first, then try it in the head. You could also replace all of the LeakyReLU activations with FReLU, since there is only one activation in each BottleneckCSP, so there are not many to begin with.
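
For illustration, a minimal sketch of the torch.max() variant and of swapping the activation inside a Conv-style block; the Conv class below is a simplified, hypothetical stand-in for the block in YOLOv5's models/common.py, not its exact code (torch.max(a, b) computes the element-wise maximum of two tensors; torch.maximum() does the same but only exists in newer PyTorch versions):

import torch
import torch.nn as nn


class FReLU(nn.Module):  # torch.max() variant, as recommended above
    def __init__(self, c1, k=3):  # ch_in, kernel
        super().__init__()
        self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1)
        self.bn = nn.BatchNorm2d(c1)

    def forward(self, x):
        return torch.max(x, self.bn(self.conv(x)))  # element-wise max(x, T(x))


class Conv(nn.Module):  # simplified stand-in for the YOLOv5 Conv block (hypothetical)
    def __init__(self, c1, c2, k=1, s=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, k // 2, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = FReLU(c2) if act else nn.Identity()  # was nn.LeakyReLU(0.1)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))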

glenn-jocher (Author)

@nmaac great, thanks for the feedback! We'll try to experiment with this and see if it helps our results!


clw5180 commented Aug 1, 2020

> We'll try to experiment with this and see if it helps our results!

Hi glenn, is there any improvement from FReLU on the COCO or VOC datasets? Or how does it perform on the validation set? Thanks!

glenn-jocher (Author)

@clw5180 I've run some experiments with YOLOv5s, which seem to show some early improvement. I have not tried larger models, and improvements sometimes do not extrapolate well from one model size to another, so I can't draw conclusions across the entire v5 lineup, but for small models it appears to help.

One warning is that FReLU greatly increases training requirements, perhaps similar to Swish and not quite as bad as Mish, so it's probably not a viable replacement for every activation function in a model, but it might be suitable for partial replacement.
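
To quantify that overhead on a given setup, a rough micro-benchmark sketch (hypothetical; it reuses the FReLU class from the snippets above, and real training cost also depends on memory use):

import time
import torch
import torch.nn as nn

x = torch.randn(16, 256, 64, 64, requires_grad=True)
acts = {'LeakyReLU': nn.LeakyReLU(0.1), 'FReLU': FReLU(256)}

for name, act in acts.items():
    t0 = time.time()
    for _ in range(50):
        act(x).sum().backward()
        x.grad = None
    # on GPU, call torch.cuda.synchronize() before reading the clock
    print(f'{name}: {time.time() - t0:.2f}s for 50 forward+backward passes')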
