import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """VGG-style block: num_conv repetitions of 3x3 conv -> batch norm -> ReLU, optionally followed by 2x2 max pooling."""
    def __init__(self, in_features, out_features, num_conv, pool=False):
        super(ConvBlock, self).__init__()
        features = [in_features] + [out_features for i in range(num_conv)]
        layers = []
        for i in range(len(features) - 1):
            layers.append(nn.Conv2d(in_channels=features[i], out_channels=features[i+1], kernel_size=3, padding=1, bias=True))
            layers.append(nn.BatchNorm2d(num_features=features[i+1], affine=True, track_running_stats=True))
            layers.append(nn.ReLU())
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
        self.op = nn.Sequential(*layers)

    def forward(self, x):
        return self.op(x)


class ProjectorBlock(nn.Module):
    """1x1 convolution used to match channel dimensions between feature maps."""
    def __init__(self, in_features, out_features):
        super(ProjectorBlock, self).__init__()
        self.op = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=1, padding=0, bias=False)

    def forward(self, inputs):
        return self.op(inputs)


class LinearAttentionBlock(nn.Module):
    """Additive attention over a local feature map l conditioned on a global feature g (same shape):
    returns the raw compatibility map and either a softmax-weighted sum or a sigmoid-gated average of l."""
    def __init__(self, in_features, normalize_attn=True):
        super(LinearAttentionBlock, self).__init__()
        self.normalize_attn = normalize_attn
        self.op = nn.Conv2d(in_channels=in_features, out_channels=1, kernel_size=1, padding=0, bias=False)

    def forward(self, l, g):
        N, C, W, H = l.size()
        c = self.op(l + g)  # compatibility scores, batch_size x 1 x W x H
        if self.normalize_attn:
            a = F.softmax(c.view(N, 1, -1), dim=2).view(N, 1, W, H)
        else:
            a = torch.sigmoid(c)
        g = torch.mul(a.expand_as(l), l)
        if self.normalize_attn:
            g = g.view(N, C, -1).sum(dim=2)  # batch_size x C
        else:
            g = F.adaptive_avg_pool2d(g, (1, 1)).view(N, C)
        return c.view(N, 1, W, H), g


'''
Grid attention block
Reference papers:
    Attention-Gated Networks: https://arxiv.org/abs/1804.05338 & https://arxiv.org/abs/1808.08114
Reference code:
    https://github.com/ozan-oktay/Attention-Gated-Networks
'''
class GridAttentionBlock(nn.Module):
    """Grid attention: projects the local feature l and the coarser gating feature g to a common
    attention space, upsamples g if needed, and returns the compatibility map and the re-weighted descriptor."""
    def __init__(self, in_features_l, in_features_g, attn_features, up_factor, normalize_attn=False):
        super(GridAttentionBlock, self).__init__()
        self.up_factor = up_factor
        self.normalize_attn = normalize_attn
        self.W_l = nn.Conv2d(in_channels=in_features_l, out_channels=attn_features, kernel_size=1, padding=0, bias=False)
        self.W_g = nn.Conv2d(in_channels=in_features_g, out_channels=attn_features, kernel_size=1, padding=0, bias=False)
        self.phi = nn.Conv2d(in_channels=attn_features, out_channels=1, kernel_size=1, padding=0, bias=True)

    def forward(self, l, g):
        N, C, W, H = l.size()
        l_ = self.W_l(l)
        g_ = self.W_g(g)
        if self.up_factor > 1:
            g_ = F.interpolate(g_, scale_factor=self.up_factor, mode='bilinear', align_corners=False)
        c = self.phi(F.relu(l_ + g_))  # compatibility scores, batch_size x 1 x W x H
        # compute the attention map
        if self.normalize_attn:
            a = F.softmax(c.view(N, 1, -1), dim=2).view(N, 1, W, H)
        else:
            a = torch.sigmoid(c)
        # re-weight the local feature
        f = torch.mul(a.expand_as(l), l)  # batch_size x C x W x H
        if self.normalize_attn:
            output = f.view(N, C, -1).sum(dim=2)  # weighted sum
        else:
            output = F.adaptive_avg_pool2d(f, (1, 1)).view(N, C)
        return c.view(N, 1, W, H), output
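

# A minimal usage sketch of the blocks above: a shape check on dummy tensors. The feature sizes,
# pooling sizes, and the up_factor of 8 below are arbitrary assumptions chosen only so that the
# local and gating maps line up; they are not taken from any particular model configuration.
if __name__ == '__main__':
    x = torch.randn(2, 3, 32, 32)                              # dummy batch of 2 RGB images

    conv = ConvBlock(in_features=3, out_features=64, num_conv=2)
    local_feat = conv(x)                                       # (2, 64, 32, 32) local feature map

    # Build a coarse "global" feature and project it to the same channel count as the local map.
    global_feat = F.adaptive_avg_pool2d(local_feat, (4, 4))    # (2, 64, 4, 4)
    proj = ProjectorBlock(in_features=64, out_features=64)

    # LinearAttentionBlock adds l and g directly, so broadcast the gating feature spatially first.
    g_up = F.interpolate(proj(global_feat), size=local_feat.shape[2:], mode='bilinear', align_corners=False)
    lin_attn = LinearAttentionBlock(in_features=64, normalize_attn=True)
    attn_map, descriptor = lin_attn(local_feat, g_up)
    print(attn_map.shape, descriptor.shape)                    # (2, 1, 32, 32), (2, 64)

    # GridAttentionBlock upsamples the gating feature internally, so pass the coarse map directly.
    grid_attn = GridAttentionBlock(in_features_l=64, in_features_g=64, attn_features=32,
                                   up_factor=8, normalize_attn=True)
    attn_map, descriptor = grid_attn(local_feat, global_feat)
    print(attn_map.shape, descriptor.shape)                    # (2, 1, 32, 32), (2, 64)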