Source code for homura.vision.models.cifar_resnet

# ResNet variants
from __future__ import annotations

from functools import partial
from typing import Callable, Type

import torch
from torch import nn
from torchvision import models

from homura.modules.attention import AttentionPool2d
from homura.vision.models import MODEL_REGISTRY
from homura.vision.models._utils import SELayer, conv1x1, conv3x3, init_parameters


def initialization(module: nn.Module, use_zero_init: bool):
    init_parameters(module)
    if use_zero_init:
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        for m in module.modules():
            if isinstance(m, BasicBlock):
                nn.init.constant_(m.norm2.weight, 0)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 in_planes: int,
                 planes: int,
                 stride: int,
                 groups: int,
                 width_per_group: int,
                 norm: Type[nn.BatchNorm2d],
                 act: Callable[[torch.Tensor], torch.Tensor]
                 ):
        super().__init__()
        planes = int(planes * (width_per_group / 16)) * groups
        self.conv1 = conv3x3(in_planes, planes, stride, bias=norm is None)
        self.conv2 = conv3x3(planes, planes, bias=norm is None)
        self.act = act
        self.norm1 = nn.Identity() if norm is None else norm(num_features=planes)
        self.norm2 = nn.Identity() if norm is None else norm(num_features=planes)
        self.downsample = nn.Identity()
        if in_planes != planes:
            self.downsample = nn.Sequential(conv1x1(in_planes, planes, stride=stride, bias=norm is None),
                                            nn.Identity() if norm is None else norm(num_features=planes))

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out += self.downsample(x)
        out = self.act(out)
        return out


class PreactBasicBlock(BasicBlock):
    def __init__(self,
                 in_planes: int,
                 planes: int,
                 stride: int,
                 groups: int,
                 width_per_group: int,
                 norm: Type[nn.BatchNorm2d],
                 act: Callable[[torch.Tensor], torch.Tensor]
                 ):
        super().__init__(in_planes, planes, stride, groups, width_per_group, norm, act)
        self.norm1 = nn.Identity() if norm is None else norm(num_features=in_planes)
        if in_planes != planes:
            self.downsample = conv1x1(in_planes, planes, stride=stride, bias=norm is None)

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        out = self.norm1(x)
        out = self.act(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.act(out)
        out = self.conv2(out)
        out += self.downsample(x)
        return out


# for resnext
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 in_planes: int,
                 planes: int,
                 stride: int,
                 groups: int,
                 width_per_group: int,
                 norm: Type[nn.BatchNorm2d],
                 act: Callable[[torch.Tensor], torch.Tensor]
                 ):
        super().__init__()
        width = int(planes * (width_per_group / 64)) * groups
        self.conv1 = conv1x1(in_planes, width, bias=norm is None)
        self.conv2 = conv3x3(width, width, stride, groups=groups, bias=norm is None)
        self.conv3 = conv1x1(width, planes * self.expansion, bias=norm is None)
        self.act = act
        self.norm1 = nn.Identity() if norm is None else norm(width)
        self.norm2 = nn.Identity() if norm is None else norm(width)
        self.norm3 = nn.Identity() if norm is None else norm(planes * self.expansion)
        self.downsample = nn.Identity()
        if stride != 1 or in_planes != planes * self.expansion:
            self.downsample = nn.Sequential(
                conv1x1(in_planes, planes * self.expansion, stride=stride, bias=norm is None),
                nn.Identity() if norm is None else norm(num_features=planes * self.expansion)
            )

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.act(out)
        out = self.conv3(out)
        out = self.norm3(out)
        out += self.downsample(x)
        return self.act(out)


# for SENet
class SEBasicBlock(BasicBlock):
    def __init__(self, *args, **kwargs):
        # pop the SE reduction ratio so the remaining arguments match BasicBlock's signature
        reduction = kwargs.pop('reduction')
        super().__init__(*args, **kwargs)
        self.norm2 = nn.Sequential(self.norm2, SELayer(self.conv2.out_channels, reduction))


class SEBottleneck(Bottleneck):
    def __init__(self, *args, **kwargs):
        # pop the SE reduction ratio so the remaining arguments match Bottleneck's signature
        reduction = kwargs.pop('reduction')
        super().__init__(*args, **kwargs)
        self.norm3 = nn.Sequential(self.norm3, SELayer(self.conv3.out_channels, reduction))


class ResNet(nn.Module):
    """ResNet for CIFAR data. For ImageNet classification, use `torchvision`'s.
    """

    def __init__(self,
                 block: Type[BasicBlock | Bottleneck],
                 num_classes: int,
                 layer_depth: int,
                 width: int = 16,
                 widen_factor: int = 1,
                 in_channels: int = 3,
                 groups: int = 1,
                 width_per_group: int = 16,
                 norm: Type[nn.BatchNorm2d] = nn.BatchNorm2d,
                 act: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
                 preact: bool = False,
                 final_pool: Callable[[torch.Tensor], torch.Tensor] = nn.AdaptiveAvgPool2d(1),
                 initializer: Callable[[nn.Module], None] = None
                 ):
        super(ResNet, self).__init__()
        self.inplane = width
        self.groups = groups
        self.norm = norm
        self.width_per_group = width_per_group
        self.preact = preact

        self.conv1 = conv3x3(in_channels, width, stride=1, bias=norm is None)
        self.norm1 = nn.Identity() if norm is None else norm(
            4 * width * block.expansion * widen_factor if self.preact else width)
        self.act = act
        self.layer1 = self._make_layer(block, width * widen_factor, layer_depth=layer_depth, stride=1)
        self.layer2 = self._make_layer(block, width * 2 * widen_factor, layer_depth=layer_depth, stride=2)
        self.layer3 = self._make_layer(block, width * 4 * widen_factor, layer_depth=layer_depth, stride=2)
        self.final_pool = final_pool
        self.fc = nn.Linear(4 * width * block.expansion * widen_factor, num_classes)

        if initializer is None:
            initialization(self, False)
        else:
            initializer(self)

    def _make_layer(self,
                    block: Type[BasicBlock | Bottleneck],
                    planes: int,
                    layer_depth: int,
                    stride: int,
                    ) -> nn.Sequential:
        layers = []
        for i in range(layer_depth):
            layers.append(
                block(self.inplane, planes, stride if i == 0 else 1,
                      self.groups, self.width_per_group, self.norm, self.act)
            )
            if i == 0:
                self.inplane = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        if not self.preact:
            x = self.norm1(x)
            x = self.act(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.preact:
            x = self.norm1(x)
            x = self.act(x)
        x = self.final_pool(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x


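# Usage sketch (not part of the original module): the three stages above halve the
# spatial resolution twice, so a 32x32 CIFAR batch reaches `final_pool` at 8x8 with
# 4 * width * block.expansion * widen_factor channels. The helper below is only a
# hedged smoke test under the default arguments; it is defined but never called.
def _resnet_smoke_test():
    model = ResNet(BasicBlock, num_classes=10, layer_depth=3)  # same layout as resnet20 below
    x = torch.randn(2, 3, 32, 32)  # dummy CIFAR-sized batch
    return model(x).shape  # expected: torch.Size([2, 10])

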
def resnet(num_classes: int,
           depth: int,
           in_channels: int = 3,
           norm: Type[nn.BatchNorm2d] = nn.BatchNorm2d,
           act: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
           block: Type[BasicBlock] = BasicBlock,
           **kwargs
           ) -> ResNet:
    f"resnet-{depth}"
    assert (depth - 2) % 6 == 0
    layer_depth = (depth - 2) // 6
    return ResNet(block, num_classes, layer_depth, in_channels=in_channels, norm=norm, act=act, **kwargs)


def wide_resnet(num_classes: int,
                depth: int,
                widen_factor: int,
                in_channels: int = 3,
                norm: Type[nn.BatchNorm2d] = nn.BatchNorm2d,
                act: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
                block: Type[BasicBlock] = PreactBasicBlock,
                **kwargs
                ) -> ResNet:
    f"wideresnet-{depth}-{widen_factor}"
    assert (depth - 4) % 6 == 0
    layer_depth = (depth - 4) // 6
    return ResNet(block, num_classes, layer_depth, in_channels=in_channels, widen_factor=widen_factor,
                  norm=norm, act=act, preact=True, **kwargs)


def resnext(num_classes: int,
            depth: int,
            width_per_group: int,
            groups: int,
            in_channels: int,
            norm: Type[nn.BatchNorm2d] = nn.BatchNorm2d,
            act: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
            block: Type[Bottleneck] = Bottleneck,
            **kwargs
            ) -> ResNet:
    f"resnext-{depth}_{groups}x{width_per_group}d"
    assert (depth - 2) % 9 == 0
    layer_depth = (depth - 2) // 9
    return ResNet(block, num_classes, layer_depth, width=64, in_channels=in_channels,
                  groups=groups, width_per_group=width_per_group, norm=norm, act=act, **kwargs)


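# Hedged sketch (not in the original source): the factories above derive the per-stage
# block count from the total depth, so only certain depths pass the assertions:
#   resnet / se_resnet (BasicBlock):  depth = 6 * n + 2   (20, 32, 56, 110, ...)
#   wide_resnet (PreactBasicBlock):   depth = 6 * n + 4   (16, 28, 40, ...)
#   resnext (Bottleneck):             depth = 9 * n + 2   (29, 47, ...)
# The helper below only illustrates valid calls; it is defined but never invoked.
def _factory_examples():
    m1 = resnet(num_classes=10, depth=56)                          # 9 BasicBlocks per stage
    m2 = wide_resnet(num_classes=100, depth=28, widen_factor=10)   # WRN-28-10 for CIFAR-100
    m3 = resnext(num_classes=10, depth=29, width_per_group=4, groups=32, in_channels=3)
    return m1, m2, m3

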
@MODEL_REGISTRY.register
def resnet20(num_classes: int = 10,
             in_channels: int = 3
             ) -> ResNet:
    """ ResNet by He+16 """
    return resnet(num_classes, 20, in_channels)


@MODEL_REGISTRY.register
def resnet32(num_classes: int = 10,
             in_channels: int = 3
             ) -> ResNet:
    """ ResNet by He+16 """
    return resnet(num_classes, 32, in_channels)


@MODEL_REGISTRY.register
def resnet56(num_classes: int = 10,
             in_channels: int = 3
             ) -> ResNet:
    """ ResNet by He+16 """
    return resnet(num_classes, 56, in_channels)


@MODEL_REGISTRY.register
def resnet110(num_classes: int = 10,
              in_channels: int = 3
              ) -> ResNet:
    """ ResNet by He+16 """
    return resnet(num_classes, 110, in_channels)


@MODEL_REGISTRY.register
def se_resnet20(num_classes: int = 10,
                in_channels: int = 3
                ) -> ResNet:
    """ SEResNet by Hu+18 """
    return resnet(num_classes, 20, in_channels, block=partial(SEBasicBlock, reduction=16))


@MODEL_REGISTRY.register
def se_resnet56(num_classes: int = 10,
                in_channels: int = 3
                ) -> ResNet:
    """ SEResNet by Hu+18 """
    return resnet(num_classes, 56, in_channels, block=partial(SEBasicBlock, reduction=16))


@MODEL_REGISTRY.register
def wrn16_8(num_classes: int = 10,
            in_channels: int = 3
            ) -> ResNet:
    """ WideResNet by Zagoruyko&Komodakis 17 """
    return wide_resnet(num_classes, 16, 8, in_channels)


@MODEL_REGISTRY.register
def wrn28_2(num_classes: int = 10,
            in_channels: int = 3
            ) -> ResNet:
    """ WideResNet by Zagoruyko&Komodakis 17 """
    return wide_resnet(num_classes, 28, 2, in_channels)


@MODEL_REGISTRY.register
def wrn28_10(num_classes: int = 10,
             in_channels: int = 3
             ) -> ResNet:
    """ WideResNet by Zagoruyko&Komodakis 17 """
    return wide_resnet(num_classes, 28, 10, in_channels)


@MODEL_REGISTRY.register
def wrn40_2(num_classes: int = 10,
            in_channels: int = 3
            ) -> ResNet:
    """ WideResNet by Zagoruyko&Komodakis 17 """
    return wide_resnet(num_classes, 40, 2, in_channels)


@MODEL_REGISTRY.register
def resnext29_32x4d(num_classes: int = 10,
                    in_channels: int = 3
                    ) -> ResNet:
    """ ResNeXT by Xie+17 """
    return resnext(num_classes, 29, 4, 32, in_channels)


@MODEL_REGISTRY.register
def resnext29_8x64d(num_classes: int = 10,
                    in_channels: int = 3
                    ) -> ResNet:
    """ ResNeXT by Xie+17 """
    return resnext(num_classes, 29, 64, 8, in_channels)


@MODEL_REGISTRY.register
def wrn28_2_attention_pool(num_classes: int = 10,
                           in_channels: int = 3,
                           num_heads: int = 2
                           ) -> ResNet:
    return wide_resnet(num_classes, 28, 2, in_channels,
                       final_pool=AttentionPool2d(2 * 64, num_heads))


@MODEL_REGISTRY.register
def wrn28_10_attention_pool(num_classes: int = 10,
                            in_channels: int = 3,
                            num_heads: int = 10
                            ) -> ResNet:
    return wide_resnet(num_classes, 28, 10, in_channels,
                       final_pool=AttentionPool2d(10 * 64, num_heads))


class TVResNet(models.ResNet):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.maxpool = nn.Identity()


@MODEL_REGISTRY.register
def cifar_resnet18(num_classes: int = 10,
                   ) -> TVResNet:
    return TVResNet(models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes)


@MODEL_REGISTRY.register
def cifar_resnet50(num_classes: int = 10,
                   ) -> TVResNet:
    return TVResNet(models.resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
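

# Usage sketch (not part of the original module): the torchvision-backed variants keep
# torchvision's stem apart from the removed max-pooling, so they still accept 32x32
# CIFAR batches. Defined but never called; shapes are a hedged expectation.
def _tv_resnet_example():
    model = cifar_resnet18(num_classes=10)
    x = torch.randn(2, 3, 32, 32)
    return model(x).shape  # expected: torch.Size([2, 10])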