Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support timm backbones. #399

Merged
merged 3 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
python -m pip install -U pip
python -m pip install ninja opencv-python-headless onnx pytest-xdist codecov
python -m pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install Cython termcolor numpy tensorboard pycocotools matplotlib pyaml opencv-python tqdm pytorch-lightning torchmetrics codecov flake8 pytest
python -m pip install Cython termcolor numpy tensorboard pycocotools matplotlib pyaml opencv-python tqdm pytorch-lightning torchmetrics codecov flake8 pytest timm
python -m pip install -r requirements.txt
- name: Setup
run: rm -rf .eggs && python setup.py develop
Expand Down
3 changes: 3 additions & 0 deletions nanodet/model/backbone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .repvgg import RepVGG
from .resnet import ResNet
from .shufflenetv2 import ShuffleNetV2
from .timm_wrapper import TIMMWrapper


def build_backbone(cfg):
Expand All @@ -40,5 +41,7 @@ def build_backbone(cfg):
return CustomCspNet(**backbone_cfg)
elif name == "RepVGG":
return RepVGG(**backbone_cfg)
elif name == "TIMMWrapper":
return TIMMWrapper(**backbone_cfg)
else:
raise NotImplementedError
66 changes: 66 additions & 0 deletions nanodet/model/backbone/timm_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2022 RangiLyu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch.nn as nn

logger = logging.getLogger("NanoDet")


class TIMMWrapper(nn.Module):
"""Wrapper to use backbones in timm
https://github.com/rwightman/pytorch-image-models."""

def __init__(
self,
model_name,
features_only=True,
pretrained=True,
checkpoint_path="",
in_channels=3,
**kwargs,
):
try:
import timm
except ImportError as exc:
raise RuntimeError(
"timm is not installed, please install it first"
) from exc
super(TIMMWrapper, self).__init__()
self.timm = timm.create_model(
model_name=model_name,
features_only=features_only,
pretrained=pretrained,
in_chans=in_channels,
checkpoint_path=checkpoint_path,
**kwargs,
)

# Remove unused layers
self.timm.global_pool = None
self.timm.fc = None
self.timm.classifier = None

feature_info = getattr(self.timm, "feature_info", None)
if feature_info:
logger.info(f"TIMM backbone feature channels: {feature_info.channels()}")

def forward(self, x):
outs = self.timm(x)
if isinstance(outs, (list, tuple)):
features = tuple(outs)
else:
features = (outs,)
return features
39 changes: 39 additions & 0 deletions tests/test_models/test_backbone/test_timm_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch

from nanodet.model.backbone import build_backbone
from nanodet.model.backbone.timm_wrapper import TIMMWrapper


def test_timm_wrapper():
cfg = dict(
name="TIMMWrapper",
model_name="resnet18",
features_only=True,
pretrained=False,
output_stride=32,
out_indices=(1, 2, 3, 4),
)
model = build_backbone(cfg)

input = torch.rand(1, 3, 64, 64)
output = model(input)
assert len(output) == 4
assert output[0].shape == (1, 64, 16, 16)
assert output[1].shape == (1, 128, 8, 8)
assert output[2].shape == (1, 256, 4, 4)
assert output[3].shape == (1, 512, 2, 2)

model = TIMMWrapper(
model_name="mobilenetv3_large_100",
features_only=True,
pretrained=False,
output_stride=32,
out_indices=(1, 2, 3, 4),
)
output = model(input)

assert len(output) == 4
assert output[0].shape == (1, 24, 16, 16)
assert output[1].shape == (1, 40, 8, 8)
assert output[2].shape == (1, 112, 4, 4)
assert output[3].shape == (1, 960, 2, 2)