-
Notifications
You must be signed in to change notification settings - Fork 10
/
gan_resnet.py
88 lines (70 loc) · 2.42 KB
/
gan_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from torch import nn
from torch.nn import functional as F
import numpy as np
class Discriminator(nn.Module):
    """Resnet-style GAN discriminator emitting `nlabels` logits per image.

    Images have `B` channels and spatial size `size` x `size`. Each
    downsampling stage halves the resolution and doubles the feature
    width, capped at `nfilter_max`.
    """

    def __init__(self, B, nlabels, size, nfilter=64, nfilter_max=1024):
        super().__init__()
        base = self.s0 = 4            # spatial size left after all downsampling
        width = self.nf = nfilter
        width_cap = self.nf_max = nfilter_max

        # Number of halving stages needed to shrink `size` down to s0.
        # NOTE(review): assumes `size` is a power-of-two multiple of 4 — confirm at call sites.
        n_down = int(np.log2(size / base))
        self.nf0 = min(width_cap, width * 2 ** n_down)

        stages = [ResnetBlock(width, width)]
        for stage in range(n_down):
            c_in = min(width * 2 ** stage, width_cap)
            c_out = min(width * 2 ** (stage + 1), width_cap)
            stages.append(nn.AvgPool2d(3, stride=2, padding=1))
            stages.append(ResnetBlock(c_in, c_out))

        self.conv_img = nn.Conv2d(B, 1 * width, 3, padding=1)
        self.resnet = nn.Sequential(*stages)
        self.fc = nn.Linear(self.nf0 * base * base, nlabels)

    def forward(self, x, y):
        # `y` (labels) is only checked for batch-size agreement here;
        # the per-label selection below was disabled upstream.
        assert (x.size(0) == y.size(0))
        n = x.size(0)
        h = self.conv_img(x)
        h = self.resnet(h)
        h = h.view(n, self.nf0 * self.s0 * self.s0)
        return self.fc(actvn(h))
class ResnetBlock(nn.Module):
    """Pre-activation residual block: out = shortcut(x) + 0.1 * residual(x).

    When `fin != fout`, the shortcut is a learned bias-free 1x1
    convolution; otherwise it is the identity.
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        # Attributes
        self.is_bias = is_bias
        self.learned_shortcut = fin != fout
        self.fin = fin
        self.fout = fout
        # Hidden width defaults to the narrower of the two endpoints.
        self.fhidden = min(fin, fout) if fhidden is None else fhidden

        # Submodules
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False)

    def forward(self, x):
        residual = self.conv_1(actvn(self.conv_0(actvn(x))))
        # Residual branch is scaled by 0.1 before being added to the shortcut.
        return self._shortcut(x) + 0.1 * residual

    def _shortcut(self, x):
        # Identity unless channel counts differ, in which case project with 1x1 conv.
        return self.conv_s(x) if self.learned_shortcut else x
def actvn(x):
    """Leaky ReLU with negative slope 0.2 — the activation used throughout this model."""
    return F.leaky_relu(x, negative_slope=2e-1)