Source code for homura.vision.models.cifar_resnet

# ResNet variants
from functools import partial
from typing import Callable, Optional, Type, Union

import torch
from torch import nn
from torchvision import models

from homura.modules.attention import AttentionPool2d
from homura.vision.models import MODEL_REGISTRY
from homura.vision.models._utils import SELayer, conv1x1, conv3x3, init_parameters


def initialization(module: nn.Module, use_zero_init: bool):
    init_parameters(module)
    if use_zero_init:
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros and each residual block
        # behaves like an identity. This improves the model by 0.2~0.3%
        # according to https://arxiv.org/abs/1706.02677
        for m in module.modules():
            if isinstance(m, BasicBlock):
                nn.init.constant_(m.norm2.weight, 0)
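
# Usage sketch: zero-init is opt-in via ResNet's `initializer` argument
# (the `resnet` factory below forwards it through **kwargs).
# >>> model = resnet(num_classes=10, depth=20,
# ...                initializer=partial(initialization, use_zero_init=True))
# >>> all(m.norm2.weight.abs().sum() == 0
# ...     for m in model.modules() if isinstance(m, BasicBlock))
# True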


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 in_planes: int,
                 planes: int,
                 stride: int,
                 groups: int,
                 width_per_group: int,
                 norm: Optional[Type[nn.BatchNorm2d]],
                 act: Callable[[torch.Tensor], torch.Tensor]
                 ):
        super().__init__()
        planes = int(planes * (width_per_group / 16)) * groups
        self.conv1 = conv3x3(in_planes, planes, stride, bias=norm is None)
        self.conv2 = conv3x3(planes, planes, bias=norm is None)
        self.act = act
        self.norm1 = nn.Identity() if norm is None else norm(num_features=planes)
        self.norm2 = nn.Identity() if norm is None else norm(num_features=planes)
        self.downsample = nn.Identity()
        if in_planes != planes:
            self.downsample = nn.Sequential(
                conv1x1(in_planes, planes, stride=stride, bias=norm is None),
                nn.Identity() if norm is None else norm(num_features=planes))

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out += self.downsample(x)
        out = self.act(out)
        return out
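

# Shape sketch (assuming conv3x3/conv1x1 from ._utils are the standard padded convolutions):
# >>> block = BasicBlock(16, 32, stride=2, groups=1, width_per_group=16,
# ...                    norm=nn.BatchNorm2d, act=nn.ReLU())
# >>> block(torch.randn(2, 16, 32, 32)).shape
# torch.Size([2, 32, 16, 16])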


class PreactBasicBlock(BasicBlock):
    def __init__(self,
                 in_planes: int,
                 planes: int,
                 stride: int,
                 groups: int,
                 width_per_group: int,
                 norm: Optional[Type[nn.BatchNorm2d]],
                 act: Callable[[torch.Tensor], torch.Tensor]
                 ):
        super().__init__(in_planes, planes, stride, groups, width_per_group, norm, act)
        # Pre-activation: norm1 is applied to the block input,
        # so it normalizes in_planes channels instead of planes.
        self.norm1 = nn.Identity() if norm is None else norm(num_features=in_planes)
        if in_planes != planes:
            self.downsample = conv1x1(in_planes, planes, stride=stride, bias=norm is None)

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        out = self.norm1(x)
        out = self.act(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.act(out)
        out = self.conv2(out)
        out += self.downsample(x)
        return out


# for ResNeXt
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 in_planes: int,
                 planes: int,
                 stride: int,
                 groups: int,
                 width_per_group: int,
                 norm: Optional[Type[nn.BatchNorm2d]],
                 act: Callable[[torch.Tensor], torch.Tensor]
                 ):
        super().__init__()
        width = int(planes * (width_per_group / 64)) * groups
        self.conv1 = conv1x1(in_planes, width, bias=norm is None)
        self.conv2 = conv3x3(width, width, stride, groups=groups, bias=norm is None)
        self.conv3 = conv1x1(width, planes * self.expansion, bias=norm is None)
        self.act = act
        self.norm1 = nn.Identity() if norm is None else norm(width)
        self.norm2 = nn.Identity() if norm is None else norm(width)
        self.norm3 = nn.Identity() if norm is None else norm(planes * self.expansion)
        self.downsample = nn.Identity()
        if stride != 1 or in_planes != planes * self.expansion:
            self.downsample = nn.Sequential(
                conv1x1(in_planes, planes * self.expansion, stride=stride, bias=norm is None),
                nn.Identity() if norm is None else norm(num_features=planes * self.expansion)
            )

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.act(out)
        out = self.conv3(out)
        out = self.norm3(out)
        out += self.downsample(x)
        return self.act(out)
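

# Shape sketch for the ResNeXt-style bottleneck (32 groups of width 4):
# the inner width is int(64 * 4 / 64) * 32 == 128, expanded to 64 * 4 == 256 channels.
# >>> block = Bottleneck(64, 64, stride=1, groups=32, width_per_group=4,
# ...                    norm=nn.BatchNorm2d, act=nn.ReLU())
# >>> block(torch.randn(2, 64, 32, 32)).shape
# torch.Size([2, 256, 32, 32])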


# for SENet
class SEBasicBlock(BasicBlock):
    def __init__(self, *args, **kwargs):
        # Pop `reduction` before delegating: BasicBlock.__init__ does not accept it.
        reduction = kwargs.pop('reduction')
        super().__init__(*args, **kwargs)
        self.norm2 = nn.Sequential(self.norm2, SELayer(self.conv2.out_channels, reduction))


class SEBottleneck(Bottleneck):
    def __init__(self, *args, **kwargs):
        reduction = kwargs.pop('reduction')
        super().__init__(*args, **kwargs)
        self.norm3 = nn.Sequential(self.norm3, SELayer(self.conv3.out_channels, reduction))


class ResNet(nn.Module):
    """ResNet for CIFAR data. For ImageNet classification, use `torchvision`'s.
    """

    def __init__(self,
                 block: Type[Union[BasicBlock, Bottleneck]],
                 num_classes: int,
                 layer_depth: int,
                 width: int = 16,
                 widen_factor: int = 1,
                 in_channels: int = 3,
                 groups: int = 1,
                 width_per_group: int = 16,
                 norm: Optional[Type[nn.BatchNorm2d]] = nn.BatchNorm2d,
                 act: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
                 preact: bool = False,
                 final_pool: Callable[[torch.Tensor], torch.Tensor] = nn.AdaptiveAvgPool2d(1),
                 initializer: Optional[Callable[[nn.Module], None]] = None
                 ):
        super(ResNet, self).__init__()
        self.inplane = width
        self.groups = groups
        self.norm = norm
        self.width_per_group = width_per_group
        self.preact = preact

        self.conv1 = conv3x3(in_channels, width, stride=1, bias=norm is None)
        # In the pre-activation setting, norm1 is applied after layer3 (see forward),
        # so it normalizes the final feature width rather than the stem width.
        self.norm1 = nn.Identity() if norm is None else norm(
            4 * width * block.expansion * widen_factor if self.preact else width)
        self.act = act
        self.layer1 = self._make_layer(block, width * widen_factor, layer_depth=layer_depth, stride=1)
        self.layer2 = self._make_layer(block, width * 2 * widen_factor, layer_depth=layer_depth, stride=2)
        self.layer3 = self._make_layer(block, width * 4 * widen_factor, layer_depth=layer_depth, stride=2)
        self.final_pool = final_pool
        self.fc = nn.Linear(4 * width * block.expansion * widen_factor, num_classes)

        if initializer is None:
            initialization(self, False)
        else:
            initializer(self)

    def _make_layer(self,
                    block: Type[Union[BasicBlock, Bottleneck]],
                    planes: int,
                    layer_depth: int,
                    stride: int,
                    ) -> nn.Sequential:
        layers = []
        for i in range(layer_depth):
            layers.append(
                block(self.inplane, planes, stride if i == 0 else 1,
                      self.groups, self.width_per_group, self.norm, self.act)
            )
            if i == 0:
                self.inplane = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        if not self.preact:
            x = self.norm1(x)
            x = self.act(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.preact:
            x = self.norm1(x)
            x = self.act(x)
        x = self.final_pool(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x
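

# Usage sketch: a CIFAR-sized forward pass through a directly constructed ResNet
# (layer_depth=3 is the resnet20 configuration defined below).
# >>> model = ResNet(BasicBlock, num_classes=10, layer_depth=3)
# >>> model(torch.randn(4, 3, 32, 32)).shape
# torch.Size([4, 10])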


def resnet(num_classes: int,
           depth: int,
           in_channels: int = 3,
           norm: Optional[Type[nn.BatchNorm2d]] = nn.BatchNorm2d,
           act: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
           block: Type[BasicBlock] = BasicBlock,
           **kwargs
           ) -> ResNet:
    f"resnet-{depth}"
    assert (depth - 2) % 6 == 0
    layer_depth = (depth - 2) // 6
    return ResNet(block, num_classes, layer_depth, in_channels=in_channels, norm=norm, act=act, **kwargs)
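
# Depth arithmetic: total depth == 6 * layer_depth + 2 (two 3x3 convs per BasicBlock
# across three stages, plus the stem conv and the final fc),
# e.g. depth=20 -> (20 - 2) // 6 == 3 blocks per stage.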


def wide_resnet(num_classes: int,
                depth: int,
                widen_factor: int,
                in_channels: int = 3,
                norm: Optional[Type[nn.BatchNorm2d]] = nn.BatchNorm2d,
                act: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
                block: Type[BasicBlock] = PreactBasicBlock,
                **kwargs
                ) -> ResNet:
    f"wideresnet-{depth}-{widen_factor}"
    assert (depth - 4) % 6 == 0
    layer_depth = (depth - 4) // 6
    return ResNet(block, num_classes, layer_depth, in_channels=in_channels,
                  widen_factor=widen_factor, norm=norm, act=act, preact=True, **kwargs)
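
# e.g. wrn28_10: layer_depth == (28 - 4) // 6 == 4 blocks per stage,
# with stage widths 160/320/640 (16/32/64 scaled by widen_factor=10).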


def resnext(num_classes: int,
            depth: int,
            width_per_group: int,
            groups: int,
            in_channels: int,
            norm: Optional[Type[nn.BatchNorm2d]] = nn.BatchNorm2d,
            act: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
            block: Type[Bottleneck] = Bottleneck,
            **kwargs
            ) -> ResNet:
    f"resnext-{depth}_{groups}x{width_per_group}d"
    assert (depth - 2) % 9 == 0
    layer_depth = (depth - 2) // 9
    return ResNet(block, num_classes, layer_depth, width=64, in_channels=in_channels,
                  groups=groups, width_per_group=width_per_group, norm=norm, act=act, **kwargs)
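
# e.g. resnext29_32x4d: layer_depth == (29 - 2) // 9 == 3 blocks per stage
# (the factor of 9 comes from the three convs per Bottleneck across three stages).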


@MODEL_REGISTRY.register
def resnet20(num_classes: int = 10,
             in_channels: int = 3
             ) -> ResNet:
    """ ResNet by He+16
    """
    return resnet(num_classes, 20, in_channels)


@MODEL_REGISTRY.register
def resnet32(num_classes: int = 10,
             in_channels: int = 3
             ) -> ResNet:
    """ ResNet by He+16
    """
    return resnet(num_classes, 32, in_channels)


@MODEL_REGISTRY.register
def resnet56(num_classes: int = 10,
             in_channels: int = 3
             ) -> ResNet:
    """ ResNet by He+16
    """
    return resnet(num_classes, 56, in_channels)


@MODEL_REGISTRY.register
def resnet110(num_classes: int = 10,
              in_channels: int = 3
              ) -> ResNet:
    """ ResNet by He+16
    """
    return resnet(num_classes, 110, in_channels)


@MODEL_REGISTRY.register
def se_resnet20(num_classes: int = 10,
                in_channels: int = 3
                ) -> ResNet:
    """ SEResNet by Hu+18
    """
    return resnet(num_classes, 20, in_channels, block=partial(SEBasicBlock, reduction=16))


@MODEL_REGISTRY.register
def se_resnet56(num_classes: int = 10,
                in_channels: int = 3
                ) -> ResNet:
    """ SEResNet by Hu+18
    """
    return resnet(num_classes, 56, in_channels, block=partial(SEBasicBlock, reduction=16))


@MODEL_REGISTRY.register
def wrn16_8(num_classes: int = 10,
            in_channels: int = 3
            ) -> ResNet:
    """ WideResNet by Zagoruyko & Komodakis 17
    """
    return wide_resnet(num_classes, 16, 8, in_channels)


@MODEL_REGISTRY.register
def wrn28_2(num_classes: int = 10,
            in_channels: int = 3
            ) -> ResNet:
    """ WideResNet by Zagoruyko & Komodakis 17
    """
    return wide_resnet(num_classes, 28, 2, in_channels)


@MODEL_REGISTRY.register
def wrn28_10(num_classes: int = 10,
             in_channels: int = 3
             ) -> ResNet:
    """ WideResNet by Zagoruyko & Komodakis 17
    """
    return wide_resnet(num_classes, 28, 10, in_channels)


@MODEL_REGISTRY.register
def wrn40_2(num_classes: int = 10,
            in_channels: int = 3
            ) -> ResNet:
    """ WideResNet by Zagoruyko & Komodakis 17
    """
    return wide_resnet(num_classes, 40, 2, in_channels)


@MODEL_REGISTRY.register
def resnext29_32x4d(num_classes: int = 10,
                    in_channels: int = 3
                    ) -> ResNet:
    """ ResNeXt by Xie+17
    """
    # width_per_group=4, groups=32
    return resnext(num_classes, 29, 4, 32, in_channels)


@MODEL_REGISTRY.register
def resnext29_8x64d(num_classes: int = 10,
                    in_channels: int = 3
                    ) -> ResNet:
    """ ResNeXt by Xie+17
    """
    # width_per_group=64, groups=8
    return resnext(num_classes, 29, 64, 8, in_channels)


@MODEL_REGISTRY.register
def wrn28_2_attention_pool(num_classes: int = 10,
                           in_channels: int = 3,
                           num_heads: int = 2
                           ) -> ResNet:
    # 2 * 64 == 128 matches wrn28_2's final feature width (4 * 16 * widen_factor).
    return wide_resnet(num_classes, 28, 2, in_channels,
                       final_pool=AttentionPool2d(2 * 64, num_heads))


@MODEL_REGISTRY.register
def wrn28_10_attention_pool(num_classes: int = 10,
                            in_channels: int = 3,
                            num_heads: int = 10
                            ) -> ResNet:
    # 10 * 64 == 640 matches wrn28_10's final feature width.
    return wide_resnet(num_classes, 28, 10, in_channels,
                       final_pool=AttentionPool2d(10 * 64, num_heads))


class TVResNet(models.ResNet):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Drop the ImageNet stem's max pooling so 32x32 inputs are not downsampled too early.
        self.maxpool = nn.Identity()


@MODEL_REGISTRY.register
def cifar_resnet18(num_classes: int = 10,
                   ) -> TVResNet:
    return TVResNet(models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes)


@MODEL_REGISTRY.register
def cifar_resnet50(num_classes: int = 10,
                   ) -> TVResNet:
    return TVResNet(models.resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
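

# Minimal smoke test, a sketch only: it assumes the homura imports above resolve and
# that @MODEL_REGISTRY.register returns the decorated function unchanged.
if __name__ == "__main__":
    for factory in (resnet20, wrn28_2, resnext29_32x4d, cifar_resnet18):
        model = factory(num_classes=10)
        out = model(torch.randn(2, 3, 32, 32))
        assert out.shape == (2, 10), f"{factory.__name__}: unexpected shape {out.shape}"
        print(f"{factory.__name__}: OK")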