Skip to content

Commit

Permalink
[Feature] Support timm backbones. (#399)
Browse files Browse the repository at this point in the history
* [Feature] Support timm backbones.

* update ci

* fix lint
  • Loading branch information
RangiLyu committed Aug 26, 2022
1 parent 837d65b commit 9795110
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
python -m pip install -U pip
python -m pip install ninja opencv-python-headless onnx pytest-xdist codecov
python -m pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install Cython termcolor numpy tensorboard pycocotools matplotlib pyaml opencv-python tqdm pytorch-lightning torchmetrics codecov flake8 pytest
python -m pip install Cython termcolor numpy tensorboard pycocotools matplotlib pyaml opencv-python tqdm pytorch-lightning torchmetrics codecov flake8 pytest timm
python -m pip install -r requirements.txt
- name: Setup
run: rm -rf .eggs && python setup.py develop
Expand Down
3 changes: 3 additions & 0 deletions nanodet/model/backbone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .repvgg import RepVGG
from .resnet import ResNet
from .shufflenetv2 import ShuffleNetV2
from .timm_wrapper import TIMMWrapper


def build_backbone(cfg):
Expand All @@ -40,5 +41,7 @@ def build_backbone(cfg):
return CustomCspNet(**backbone_cfg)
elif name == "RepVGG":
return RepVGG(**backbone_cfg)
elif name == "TIMMWrapper":
return TIMMWrapper(**backbone_cfg)
else:
raise NotImplementedError
66 changes: 66 additions & 0 deletions nanodet/model/backbone/timm_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2022 RangiLyu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch.nn as nn

logger = logging.getLogger("NanoDet")


class TIMMWrapper(nn.Module):
"""Wrapper to use backbones in timm
https://github.com/rwightman/pytorch-image-models."""

def __init__(
self,
model_name,
features_only=True,
pretrained=True,
checkpoint_path="",
in_channels=3,
**kwargs,
):
try:
import timm
except ImportError as exc:
raise RuntimeError(
"timm is not installed, please install it first"
) from exc
super(TIMMWrapper, self).__init__()
self.timm = timm.create_model(
model_name=model_name,
features_only=features_only,
pretrained=pretrained,
in_chans=in_channels,
checkpoint_path=checkpoint_path,
**kwargs,
)

# Remove unused layers
self.timm.global_pool = None
self.timm.fc = None
self.timm.classifier = None

feature_info = getattr(self.timm, "feature_info", None)
if feature_info:
logger.info(f"TIMM backbone feature channels: {feature_info.channels()}")

def forward(self, x):
outs = self.timm(x)
if isinstance(outs, (list, tuple)):
features = tuple(outs)
else:
features = (outs,)
return features
39 changes: 39 additions & 0 deletions tests/test_models/test_backbone/test_timm_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch

from nanodet.model.backbone import build_backbone
from nanodet.model.backbone.timm_wrapper import TIMMWrapper


def test_timm_wrapper():
cfg = dict(
name="TIMMWrapper",
model_name="resnet18",
features_only=True,
pretrained=False,
output_stride=32,
out_indices=(1, 2, 3, 4),
)
model = build_backbone(cfg)

input = torch.rand(1, 3, 64, 64)
output = model(input)
assert len(output) == 4
assert output[0].shape == (1, 64, 16, 16)
assert output[1].shape == (1, 128, 8, 8)
assert output[2].shape == (1, 256, 4, 4)
assert output[3].shape == (1, 512, 2, 2)

model = TIMMWrapper(
model_name="mobilenetv3_large_100",
features_only=True,
pretrained=False,
output_stride=32,
out_indices=(1, 2, 3, 4),
)
output = model(input)

assert len(output) == 4
assert output[0].shape == (1, 24, 16, 16)
assert output[1].shape == (1, 40, 8, 8)
assert output[2].shape == (1, 112, 4, 4)
assert output[3].shape == (1, 960, 2, 2)

0 comments on commit 9795110

Please sign in to comment.