diff --git a/README.md b/README.md index fe38315c..aa91c375 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,7 @@ Supported Heads/Methods: * [UPerNet][upernet] * [SFNet][sfnet] * [SegFormer][segformer] +* [CondNet][condnet] Supported Standalone Models: * [DDRNet][ddrnet] diff --git a/models/heads/__init__.py b/models/heads/__init__.py index 8ec562c2..aac7208c 100644 --- a/models/heads/__init__.py +++ b/models/heads/__init__.py @@ -4,5 +4,6 @@ from .fpn import FPNHead from .fapn import FaPNHead from .fcn import FCNHead +from .condnet import CondHead -__all__ = ['UPerHead', 'SegFormerHead', 'SFHead', 'FPNHead', 'FaPNHead', 'FCNHead'] \ No newline at end of file +__all__ = ['UPerHead', 'SegFormerHead', 'SFHead', 'FPNHead', 'FaPNHead', 'FCNHead', 'CondHead'] \ No newline at end of file diff --git a/models/heads/condnet.py b/models/heads/condnet.py new file mode 100644 index 00000000..3b21354c --- /dev/null +++ b/models/heads/condnet.py @@ -0,0 +1,67 @@ +import torch +from torch import nn, Tensor +from torch.nn import functional as F + + +class ConvModule(nn.Sequential): + def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1): + super().__init__( + nn.Conv2d(c1, c2, k, s, p, d, g, bias=False), + nn.BatchNorm2d(c2), + nn.ReLU(True) + ) + + +class CondHead(nn.Module): + def __init__(self, in_channel: int = 2048, channel: int = 512, num_classes: int = 19): + super().__init__() + self.num_classes = num_classes + self.weight_num = channel * num_classes + self.bias_num = num_classes + + self.conv = ConvModule(in_channel, channel, 1) + self.dropout = nn.Dropout2d(0.1) + + self.guidance_project = nn.Conv2d(channel, num_classes, 1) + self.filter_project = nn.Conv2d(channel*num_classes, self.weight_num + self.bias_num, 1, groups=num_classes) + + def forward(self, features) -> Tensor: + x = self.dropout(self.conv(features[-1])) + B, C, H, W = x.shape + guidance_mask = self.guidance_project(x) + cond_logit = guidance_mask + + key = x + value = x + guidance_mask = guidance_mask.softmax(dim=1).view(*guidance_mask.shape[:2], -1) + key = key.view(B, C, -1).permute(0, 2, 1) + + cond_filters = torch.matmul(guidance_mask, key) + cond_filters /= H * W + cond_filters = cond_filters.view(B, -1, 1, 1) + cond_filters = self.filter_project(cond_filters) + cond_filters = cond_filters.view(B, -1) + + weight, bias = torch.split(cond_filters, [self.weight_num, self.bias_num], dim=1) + weight = weight.reshape(B * self.num_classes, -1, 1, 1) + bias = bias.reshape(B * self.num_classes) + + value = value.view(-1, H, W).unsqueeze(0) + seg_logit = F.conv2d(value, weight, bias, 1, 0, groups=B).view(B, self.num_classes, H, W) + + if self.training: + return cond_logit, seg_logit + return seg_logit + + +if __name__ == '__main__': + import sys + sys.path.insert(0, '.') + from models.backbones.resnetd import ResNetD + backbone = ResNetD('50') + head = CondHead() + x = torch.randn(2, 3, 224, 224) + features = backbone(x) + outs = head(features) + for out in outs: + print(out.shape) \ No newline at end of file