add uniformer backbone
sithu31296 committed Feb 2, 2022
1 parent 8b4ffbe commit e69675f
Showing 3 changed files with 195 additions and 3 deletions.
12 changes: 11 additions & 1 deletion README.md
@@ -15,7 +15,6 @@
* Human Parsing
* Face Parsing
* Medical Image Segmentation (Coming Soon)
* Background Removal (Coming Soon)
* 20+ Datasets
* 10+ SOTA Backbones
* 10+ SOTA Semantic Segmentation Models
@@ -31,12 +30,14 @@ Supported Backbones:
* [ResNetD](https://arxiv.org/abs/1812.01187) (ArXiv 2018)
* [MobileNetV2](https://arxiv.org/abs/1801.04381) (CVPR 2018)
* [MobileNetV3](https://arxiv.org/abs/1905.02244) (ICCV 2019)
* [MiT](https://arxiv.org/abs/2105.15203v2) (ArXiv 2021)
* [PVTv2](https://arxiv.org/abs/2106.13797) (ArXiv 2021)
* [ResT](https://arxiv.org/abs/2105.13677v3) (ArXiv 2021)
* [MicroNet](https://arxiv.org/abs/2108.05894) (ICCV 2021)
* [ResNet+](https://arxiv.org/abs/2110.00476) (ArXiv 2021)
* [PoolFormer](https://arxiv.org/abs/2111.11418) (ArXiv 2021)
* [ConvNeXt](https://arxiv.org/abs/2201.03545) (ArXiv 2022)
* [UniFormer](https://arxiv.org/abs/2201.09450) (ArXiv 2022)

Supported Heads/Methods:
* [FCN](https://arxiv.org/abs/1411.4038) (CVPR 2015)
@@ -369,6 +370,15 @@ $ python scripts/tflite_infer.py --model <TFLite_MODEL_PATH> --img-path <TEST_IM
primaryClass={cs.CV}
}
@misc{li2022uniformer,
title={UniFormer: Unifying Convolution and Self-attention for Visual Recognition},
author={Kunchang Li and Yali Wang and Junhao Zhang and Peng Gao and Guanglu Song and Yu Liu and Hongsheng Li and Yu Qiao},
year={2022},
eprint={2201.09450},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```

</details>
6 changes: 4 additions & 2 deletions docs/BACKBONES.md
@@ -2,7 +2,7 @@

Backbone | Variants | ImageNet-1k Top-1 Acc (%) | Params (M) | GFLOPs | Weights
--- | --- | --- | --- | --- | ---
MicroNet | M1\|M2\|M3 | 51.4`\|`59.4`\|`62.5 | 0.3`\|`0.6`\|`0.7 | 7M`\|`16M`\|`28M | [download][micronetw]
MicroNet | M1\|M2\|M3 | 51.4`\|`59.4`\|`62.5 | 1`\|`2`\|`3 | 7M`\|`14M`\|`23M | [download][micronetw]
MobileNetV2 | 1.0 | 71.9 | 3 | 300M | [download][mobilenetv2w]
MobileNetV3 | S\|L | 67.7`\|`74.0 | 3`\|`5 | 56M`\|`219M | [S][mobilenetv3s]\|[L][mobilenetv3l]
||
@@ -13,6 +13,7 @@ PVTv2 | B1\|B2\|B4 | 78.7`\|`82.0`\|`83.6 | 14`\|`25`\|`63 | 2`\|`4`\|`10 | [dow
ResT | S\|B\|L | 79.6`\|`81.6`\|`83.6 | 14`\|`30`\|`52 | 2`\|`4`\|`8 | [download][restw]
PoolFormer | S24\|S36\|M36 | 80.3`\|`81.4`\|`82.1 | 21`\|`31`\|`56 | 4`\|`5`\|`9 | [download][poolformerw]
ConvNeXt | T\|S\|B | 82.1`\|`83.1`\|`83.8 | 28`\|`50`\|`89 | 5`\|`9`\|`15 | [download][convnextw]
UniFormer | S\|B | 82.9`\|`83.8 | 22`\|`50 | 4`\|`8 | [download][uniformerw]

> Note: Download the backbone weights for [HarDNet-70][hardnetw] and [DDRNet-23slim][ddrnetw] from these links.
@@ -29,4 +30,5 @@ ConvNeXt | T\|S\|B | 82.1`\|`83.1`\|`83.8 | 28`\|`50`\|`89 | 5`\|`9`\|`15 | [dow
[hardnetw]: https://drive.google.com/file/d/1HAFHvtodAPL_eb4LX_rb0FJZyKTOo4mK/view?usp=sharing
[ddrnetw]: https://drive.google.com/file/d/1TaDJ3yG8ojjcsbQZwkn5LlFMNEcr8vu2/view?usp=sharing
[poolformerw]: https://drive.google.com/drive/folders/18OyxHHpVq-9pMMG2eu1jot7n-po4dUpD?usp=sharing
[convnextw]: https://drive.google.com/drive/folders/1Oe50_zY4QKFZ0_22mSHKuNav0GiRcgWA?usp=sharing
[uniformerw]:
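
As a quick sanity check, a checkpoint from the table can be loaded into its backbone class before a segmentation head is attached. This is a minimal sketch, assuming the checkpoint stores its weights under a `'model'` key (as the `__main__` snippet in `uniformer.py` below does) and using a placeholder file path:

```python
import torch
from semseg.models.backbones.uniformer import UniFormer

backbone = UniFormer('S')   # 'S' or 'B', matching the variants in the table
ckpt = torch.load('uniformer_small_in1k.pth', map_location='cpu')   # placeholder local path
backbone.load_state_dict(ckpt['model'], strict=False)   # strict=False: classification-head weights are unused here
```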
180 changes: 180 additions & 0 deletions semseg/models/backbones/uniformer.py
@@ -0,0 +1,180 @@
import torch
from torch import nn, Tensor
from semseg.models.layers import DropPath


class MLP(nn.Module):
def __init__(self, dim, hidden_dim, out_dim=None) -> None:
super().__init__()
out_dim = out_dim or dim
self.fc1 = nn.Linear(dim, hidden_dim)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_dim, out_dim)

def forward(self, x: Tensor) -> Tensor:
return self.fc2(self.act(self.fc1(x)))


class CMLP(nn.Module):
def __init__(self, dim, hidden_dim, out_dim=None) -> None:
super().__init__()
out_dim = out_dim or dim
self.fc1 = nn.Conv2d(dim, hidden_dim, 1)
self.act = nn.GELU()
self.fc2 = nn.Conv2d(hidden_dim, out_dim, 1)

def forward(self, x: Tensor) -> Tensor:
return self.fc2(self.act(self.fc1(x)))
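
# Note: CMLP mirrors MLP with 1x1 convolutions, so the convolutional stages
# can operate directly on (B, C, H, W) feature maps instead of token sequences.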


class Attention(nn.Module):
def __init__(self, dim, num_heads=8) -> None:
super().__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
self.qkv = nn.Linear(dim, dim*3)
self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        # project to qkv and split heads: (B, N, 3C) -> 3 x (B, num_heads, N, head_dim)
        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, C//self.num_heads).permute(2, 0, 3, 1, 4)
        attn = (q @ k.transpose(-2, -1)) * self.scale   # (B, num_heads, N, N)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge heads back into channels
        x = self.proj(x)
        return x


class CBlock(nn.Module):
    def __init__(self, dim, dpr=0.):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)   # depthwise conv as positional encoding
        self.norm1 = nn.BatchNorm2d(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        self.attn = nn.Conv2d(dim, dim, 5, 1, 2, groups=dim)        # local "attention" as a 5x5 depthwise conv
        self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
        self.norm2 = nn.BatchNorm2d(dim)
        self.mlp = CMLP(dim, int(dim*4))

def forward(self, x: Tensor) -> Tensor:
x = x + self.pos_embed(x)
x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x


class SABlock(nn.Module):
def __init__(self, dim, num_heads, dpr=0.) -> None:
super().__init__()
self.pos_embed = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(dim, num_heads)
self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
self.norm2 = nn.LayerNorm(dim)
self.mlp = MLP(dim, int(dim*4))

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)            # (B, C, H, W) -> (B, N, C) tokens
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        x = x.transpose(1, 2).reshape(B, C, H, W)   # back to feature-map layout
        return x


class PatchEmbed(nn.Module):
def __init__(self, patch_size=16, in_ch=3, embed_dim=768) -> None:
super().__init__()
self.norm = nn.LayerNorm(embed_dim)
self.proj = nn.Conv2d(in_ch, embed_dim, patch_size, patch_size)

    def forward(self, x: Tensor) -> Tensor:
        x = self.proj(x)                    # non-overlapping patches: kernel == stride == patch_size
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)    # to token layout for LayerNorm
        x = self.norm(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        return x


uniformer_settings = {
    'S': [3, 4, 8, 3],      # depths for stages 1-4
    'B': [5, 8, 20, 7]      # both variants use embed dims [64, 128, 320, 512]
}


class UniFormer(nn.Module):
def __init__(self, model_name: str = 'S') -> None:
super().__init__()
assert model_name in uniformer_settings.keys(), f"UniFormer model name should be in {list(uniformer_settings.keys())}"
depth = uniformer_settings[model_name]

head_dim = 64
drop_path_rate = 0.
embed_dims = [64, 128, 320, 512]

        # four-stage pyramid: stage 1 downsamples by 4, each later stage by 2
        for i in range(4):
            self.add_module(f"patch_embed{i+1}", PatchEmbed(4 if i == 0 else 2, 3 if i == 0 else embed_dims[i-1], embed_dims[i]))
            self.add_module(f"norm{i+1}", nn.LayerNorm(embed_dims[i]))

        self.pos_drop = nn.Dropout(0.)
        # stochastic depth: drop rates rise linearly from 0 to drop_path_rate across all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]
        num_heads = [dim // head_dim for dim in embed_dims]

        # stages 1-2 use convolutional blocks; stages 3-4 use global self-attention
        self.blocks1 = nn.ModuleList([
            CBlock(embed_dims[0], dpr[i])
            for i in range(depth[0])])

self.blocks2 = nn.ModuleList([
CBlock(embed_dims[1], dpr[i+depth[0]])
for i in range(depth[1])])

self.blocks3 = nn.ModuleList([
SABlock(embed_dims[2], num_heads[2], dpr[i+depth[0]+depth[1]])
for i in range(depth[2])])

self.blocks4 = nn.ModuleList([
SABlock(embed_dims[3], num_heads[3], dpr[i+depth[0]+depth[1]+depth[2]])
            for i in range(depth[3])])

    def forward(self, x: torch.Tensor):
        outs = []

        # stage 1 (stride 4): conv blocks
        x = self.patch_embed1(x)
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        x_out = self.norm1(x.permute(0, 2, 3, 1))   # LayerNorm over channels-last
        outs.append(x_out.permute(0, 3, 1, 2))

        # stage 2 (stride 8): conv blocks
        x = self.patch_embed2(x)
        for blk in self.blocks2:
            x = blk(x)
        x_out = self.norm2(x.permute(0, 2, 3, 1))
        outs.append(x_out.permute(0, 3, 1, 2))

        # stage 3 (stride 16): self-attention blocks
        x = self.patch_embed3(x)
        for blk in self.blocks3:
            x = blk(x)
        x_out = self.norm3(x.permute(0, 2, 3, 1))
        outs.append(x_out.permute(0, 3, 1, 2))

        # stage 4 (stride 32): self-attention blocks
        x = self.patch_embed4(x)
        for blk in self.blocks4:
            x = blk(x)
        x_out = self.norm4(x.permute(0, 2, 3, 1))
        outs.append(x_out.permute(0, 3, 1, 2))

        return outs
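
# The four pyramid features are the multi-scale inputs an FPN- or UPerNet-style
# decoder head would consume.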

if __name__ == '__main__':
    model = UniFormer('S')
    # author's local checkpoint path; point this at your own copy of uniformer_small_in1k.pth
    model.load_state_dict(torch.load('C:\\Users\\sithu\\Documents\\weights\\backbones\\uniformer\\uniformer_small_in1k.pth', map_location='cpu')['model'], strict=False)
    x = torch.randn(1, 3, 224, 224)
    feats = model(x)
    for y in feats:
        print(y.shape)
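
# Expected output for the 224x224 input above -- a four-level feature pyramid
# at strides 4/8/16/32:
#   torch.Size([1, 64, 56, 56])
#   torch.Size([1, 128, 28, 28])
#   torch.Size([1, 320, 14, 14])
#   torch.Size([1, 512, 7, 7])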
