
Add channel mapper neck #25

Merged
merged 3 commits into main from add_neck on Aug 26, 2022

Conversation

rentainhe (Collaborator) commented on Aug 25, 2022:

TODO

  • Add ChannelMapper neck for deformable-detr
  • Test DAB-Deformable-DETR inference mAP: 48.74

rentainhe (Collaborator, Author) commented:

Simple Test

from typing import Dict, List, Optional
import torch
import torch.nn as nn
from detectron2.modeling import ShapeSpec

from ideadet.layers import ConvNormAct

class ChannelMapper(nn.Module):
    def __init__(self,
                 input_shapes: Dict[str, ShapeSpec],
                 in_features: List[str],
                 out_channels: int,
                 kernel_size: int = 1,
                 stride: int = 1,
                 bias: bool = True,
                 groups: int = 1,
                 dilation: int = 1,
                 norm_layer: Optional[nn.Module] = None,
                 activation: Optional[nn.Module] = None,
                 num_outs: Optional[int] = None,
                 **kwargs,
                ):
        super(ChannelMapper, self).__init__()
        self.extra_convs = None
        
        in_channels_per_feature = [input_shapes[f].channels for f in in_features]

        # By default, produce one output feature map per input feature.
        if num_outs is None:
            num_outs = len(input_shapes)
        
        # Per-level convs (1x1 by default) that project each input feature to out_channels.
        self.convs = nn.ModuleList()
        for in_channel in in_channels_per_feature:
            self.convs.append(
                ConvNormAct(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=(kernel_size - 1) // 2,
                    bias=bias,
                    groups=groups,
                    dilation=dilation,
                    norm_layer=norm_layer,
                    activation=activation,
                )
            )

        # Extra 3x3 stride-2 convs produce additional downsampled output levels
        # on top of the last input feature, as in Deformable-DETR.
        if num_outs > len(in_channels_per_feature):
            self.extra_convs = nn.ModuleList()
            for i in range(len(in_channels_per_feature), num_outs):
                if i == len(in_channels_per_feature):
                    in_channel = in_channels_per_feature[-1]
                else:
                    in_channel = out_channels
                self.extra_convs.append(
                    ConvNormAct(
                        in_channels=in_channel,
                        out_channels=out_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        bias=bias,
                        groups=groups,
                        dilation=dilation,
                        norm_layer=norm_layer,
                        activation=activation,
                    )
                )
    
        self.input_shapes = input_shapes
        self.in_features = in_features
        self.out_channels = out_channels
    
    def forward(self, inputs):
        # inputs: dict mapping feature name (e.g. "res3") to its feature tensor.
        assert len(inputs) == len(self.convs)
        # Project each selected input feature to out_channels.
        outs = [self.convs[i](inputs[self.in_features[i]]) for i in range(len(inputs))]
        if self.extra_convs:
            # The first extra conv takes the last input feature; later ones chain on the previous output.
            for i in range(len(self.extra_convs)):
                if i == 0:
                    outs.append(self.extra_convs[0](inputs[self.in_features[-1]]))
                else:
                    outs.append(self.extra_convs[i](outs[-1]))
        return tuple(outs)


channel_mapper = ChannelMapper(
    input_shapes={
        "res3": ShapeSpec(channels=512),
        "res4": ShapeSpec(channels=1024),
        "res5": ShapeSpec(channels=2048),
    },
    in_features=["res3", "res4", "res5"],
    out_channels=256,
    num_outs=4,
)

x = {
    "res3": torch.randn(1, 512, 48, 48),
    "res4": torch.randn(1, 1024, 24, 24),
    "res5": torch.randn(1, 2048, 12, 12)
}

out = channel_mapper(x)
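
A quick sanity check of the output shapes, assuming the snippet above runs as written: every level is projected to 256 channels, and the single extra 3x3 stride-2 conv halves the res5 resolution.

# Four output levels: the three projected inputs plus one extra stride-2 level.
for i, o in enumerate(out):
    print(i, tuple(o.shape))
# 0 (1, 256, 48, 48)
# 1 (1, 256, 24, 24)
# 2 (1, 256, 12, 12)
# 3 (1, 256, 6, 6)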

rentainhe merged commit d0e7124 into main on Aug 26, 2022.
rentainhe deleted the add_neck branch on Aug 26, 2022 at 03:27.
Lontoone pushed a commit to Lontoone/detrex that referenced this pull request on Jan 8, 2024:
* add channel mapper neck

* add channel mapper and test inference mAP

* refine args

Co-authored-by: ntianhe ren <[email protected]>