import math
import torch
from torch import nn
from torch.nn import functional as F
from homura.vision.models import MODEL_REGISTRY
from homura.vision.models._utils import init_parameters
__all__ = ["unet", "CustomUNet"]
class Upsample(nn.Module):
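    """
    Thin ``F.interpolate`` wrapper used by the decoder path. The doctest below
    is an illustrative sketch added for clarity, not an original example.

    >>> up = Upsample(scale_factor=2, mode="bilinear")
    >>> up(torch.randn(1, 3, 16, 16)).size()
    torch.Size([1, 3, 32, 32])
    """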
    def __init__(self, scale_factor, mode):
        super(Upsample, self).__init__()
        assert mode in ("nearest", "linear", "bilinear", "trilinear", "area")
        self._scale_factor = scale_factor
        self._mode = mode

    def forward(self, input):
        # align_corners may only be passed for the linear interpolation modes
        align_corners = False if self._mode in ("linear", "bilinear", "trilinear") else None
        return F.interpolate(input, scale_factor=self._scale_factor, mode=self._mode,
                             align_corners=align_corners)

class Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        """
        >>> a = torch.randn(1, 1, 128, 128)
        >>> encoder = Block(1, 64)
        >>> encoder(a).size()
        torch.Size([1, 64, 128, 128])
        """
        super().__init__()
        self.block = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
                                   nn.BatchNorm2d(out_channel),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
                                   nn.BatchNorm2d(out_channel),
                                   nn.ReLU(inplace=True))

    def forward(self, input):
        return self.block(input)

class UpsampleBlock(nn.Module):
    def __init__(self, in_channel, out_channel, upsample=True):
        """
        >>> a = torch.randn(1, 128, 64, 64)
        >>> bypass = torch.randn(1, 64, 128, 128)
        >>> decoder = UpsampleBlock(128, 64)
        >>> decoder(a, bypass).size()
        torch.Size([1, 64, 128, 128])
        """
        super().__init__()
        if upsample:
            self.upsample = nn.Sequential(Upsample(scale_factor=2, mode="bilinear"),
                                          nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1))
        else:
            # 2x2 transposed convolution with stride 2, i.e. the "up-conv" of the original paper
            self.upsample = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)
        self.decoder = Block(in_channel, out_channel)

    def forward(self, input, bypass):
        x = self.upsample(input)
        # pad the upsampled feature map so it matches the skip connection's spatial size
        _, _, i_h, i_w = x.shape
        _, _, b_h, b_w = bypass.shape
        pad = (math.ceil((b_w - i_w) / 2), math.floor((b_w - i_w) / 2),
               math.ceil((b_h - i_h) / 2), math.floor((b_h - i_h) / 2))
        x = F.pad(x, pad)
        x = self.decoder(torch.cat([x, bypass], dim=1))
        return x

class DownsampleBlock(Block):
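    """
    Max-pool by a factor of two, then apply the double convolution of
    ``Block``. The doctest below is an illustrative sketch added for clarity.

    >>> down = DownsampleBlock(64, 128)
    >>> down(torch.randn(1, 64, 128, 128)).size()
    torch.Size([1, 128, 64, 64])
    """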
    def forward(self, input):
        input = F.max_pool2d(input, 2, 2)
        return self.block(input)

class UNet(nn.Module):
    def __init__(self, num_classes, input_channels,
                 config=((64, 128, 256, 512, 1024),
                         (1024, 512, 256, 128, 64))):
        """
        UNet, proposed in Ronneberger et al. (2015)

        :param num_classes: number of output classes
        :param input_channels: number of input channels
        :param config: pair of (encoder channels, decoder channels)
        """
        super(UNet, self).__init__()
        encoder_config, decoder_config = config
        encoder_config = list(encoder_config)
        decoder_config = list(decoder_config)
        # zip (3, 64, 128, 256, 512) and (64, 128, 256, 512, 1024) into
        # (3, 64), (64, 128), (128, 256), (256, 512), (512, 1024)
        encoder_config = list(zip([input_channels] + encoder_config[:-1], encoder_config))
        # (1024, 512), (512, 256), (256, 128), (128, 64)
        decoder_config = list(zip(decoder_config, decoder_config[1:]))
        self.encoders = nn.ModuleList([Block(*encoder_config[0])] +
                                      [DownsampleBlock(i, j) for i, j in encoder_config[1:]])
        self.decoders = nn.ModuleList([UpsampleBlock(i, j) for i, j in decoder_config])
        # 1x1 convolution from the last decoder's channels to the class scores
        self.channel_conv = nn.Conv2d(decoder_config[-1][1], num_classes, kernel_size=1)
        init_parameters(self)

    def forward(self, input):
        down1 = self.encoders[0](input)
        down2 = self.encoders[1](down1)
        down3 = self.encoders[2](down2)
        down4 = self.encoders[3](down3)
        down5 = self.encoders[4](down4)
        up1 = self.decoders[0](down5, down4)
        up2 = self.decoders[1](up1, down3)
        up3 = self.decoders[2](up2, down2)
        up4 = self.decoders[3](up3, down1)
        return self.channel_conv(up4)

class CustomUNet(UNet):
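    """
    Same layers as :class:`UNet`, but the forward pass iterates over the
    encoder and decoder ``ModuleList``s instead of indexing fixed positions.
    The doctest below is an illustrative sketch added for clarity.

    >>> model = CustomUNet(num_classes=3, input_channels=3)
    >>> model(torch.randn(1, 3, 160, 160)).size()
    torch.Size([1, 3, 160, 160])
    """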
    def forward(self, input):
        x = [input]
        for enc in self.encoders:
            # collect every encoder output: [input, enc1(input), enc2(...), ...]
            x += [enc(x[-1])]
        # deepest feature map first, then the skip connections in reverse order
        x, *rest = reversed(x)
        for dec, _x in zip(self.decoders, rest):
            # fuse the upsampled features with the matching encoder output
            x = dec(x, _x)
        return self.channel_conv(x)

@MODEL_REGISTRY.register
def unet(num_classes, input_channels=3):
    return UNet(num_classes, input_channels)
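

# Quick smoke test (illustrative sketch, not part of the registered API): the
# default UNet keeps the input's spatial size and maps it to `num_classes`
# channels.
if __name__ == "__main__":
    model = unet(num_classes=2, input_channels=3)
    out = model(torch.randn(1, 3, 128, 128))
    assert out.shape == (1, 2, 128, 128)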