# wadain.py — forked from MingjieChen/DYGANVC (81 lines, 2.84 KB)
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
import math
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate (StyleGAN2).

    The weight is stored pre-divided by ``lr_mul`` and rescaled at call time
    by ``scale = (1 / sqrt(dim_in)) * lr_mul`` so every weight has the same
    effective learning rate regardless of fan-in.

    Args:
        dim_in: input feature dimension.
        dim_out: output feature dimension.
        bias: if True (default) add a learnable bias; if False no bias is
            applied.  (Previously this flag was silently ignored and the
            bias was always applied.)
        bias_init: constant used to initialise the bias.
        lr_mul: learning-rate multiplier applied to weight storage/scaling
            and to the bias at call time.
        activation: accepted for interface compatibility but not applied —
            TODO confirm whether an activation was ever intended here.
    """
    def __init__(self, dim_in, dim_out, bias = True, bias_init = 0, lr_mul = 1, activation = None):
        super().__init__()
        # Stored divided by lr_mul; undone by `scale` in forward().
        self.weight = nn.Parameter(torch.randn(dim_out, dim_in).div_(lr_mul), requires_grad = True)
        if bias:
            self.bias = nn.Parameter(torch.zeros(dim_out).fill_(bias_init), requires_grad = True)
        else:
            self.bias = None
        self.activation = activation  # stored but never used (see class docstring)
        self.scale = (1 / math.sqrt(dim_in)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, x):
        """Return ``x @ (weight * scale)^T + bias * lr_mul`` (bias optional)."""
        b = self.bias * self.lr_mul if self.bias is not None else None
        return F.linear(x, self.weight * self.scale, bias = b)
class Wadain(nn.Module):
    """Weight-adaptive instance normalisation, a StyleGAN2-style modulated conv1d.

    A per-sample style (scale ``s`` and shift ``beta``) is predicted from the
    target speaker embedding and folded into the convolution weight, which is
    then demodulated (mean-subtracted and norm-rescaled per output channel)
    and applied with a grouped conv so each batch item uses its own kernel.
    Padding is chosen so the time dimension is preserved for both even and
    odd kernel sizes.

    Args:
        dim_in: input channel count.
        dim_out: output channel count.
        kernel_size: 1-D convolution kernel size.
        use_act: if True, apply LeakyReLU(0.2) to the output.
        spk_emb_dim: dimensionality of the speaker embedding.
    """
    def __init__(self, dim_in, dim_out, kernel_size, use_act = True, spk_emb_dim = 128 ):
        super().__init__()
        self.use_act = use_act
        if self.use_act:
            self.act = nn.LeakyReLU(0.2)
        # (was assigned identically in both branches of an if/else)
        self.dim_out = dim_out
        # Style projections; bias_init=1 so modulation starts near identity.
        self.style_linear = EqualLinear( spk_emb_dim, dim_in, bias_init = 1)
        self.style_linear_beta = EqualLinear(spk_emb_dim, dim_in, bias_init = 1)
        # Shared base weight, broadcast over the batch in forward().
        self.weight = nn.Parameter(torch.randn(1, self.dim_out, dim_in, kernel_size), requires_grad = True)
        # NOTE(review): fan_in squares kernel_size as in 2-D conv even though
        # this is a 1-D conv — kept as-is to preserve the trained scaling.
        fan_in = dim_in * kernel_size **2
        self.scale = 1 / math.sqrt(fan_in)
        # Asymmetric padding keeps the output length equal to the input length.
        if kernel_size %2 ==0:
            self.padding = (kernel_size //2, kernel_size // 2 - 1)
        else:
            self.padding = (kernel_size // 2, kernel_size // 2)
        self.dim_in = dim_in
        self.kernel_size = kernel_size

    def forward(self, inputs):
        """Apply the speaker-modulated conv.

        Args:
            inputs: tuple ``(x, c_src, c_trg)`` with ``x`` of shape
                ``(batch, dim_in, t)`` and ``c_trg`` of shape
                ``(batch, spk_emb_dim)``; ``c_src`` is passed through untouched.

        Returns:
            Tuple ``(out, c_src, c_trg)`` with ``out`` of shape
            ``(batch, dim_out, t)``.
        """
        x, c_src, c_trg = inputs
        batch_size, in_channel, t = x.size()
        # Per-sample modulation scale and shift from the target speaker embedding.
        s = self.style_linear(c_trg).view(batch_size, 1, in_channel, 1)
        beta = self.style_linear_beta(c_trg).view(batch_size, 1, in_channel, 1)
        # Modulate: scale weights per sample.  Shape: (b, out, in, ks)
        weight = self.scale * (self.weight * s + beta)
        # Demodulate: norm factor is computed BEFORE mean subtraction (author's
        # design, kept as-is), then applied to the mean-subtracted weight.
        demod = torch.rsqrt(weight.pow(2).sum([2,3]) + 1e-8)
        demod_mean = torch.mean(weight.view(batch_size, self.dim_out, -1), dim = 2)
        weight = (weight - demod_mean.view(batch_size, self.dim_out, 1,1) ) * demod.view(batch_size, self.dim_out, 1,1)
        # Grouped-conv trick: fold the batch into groups so each sample is
        # convolved with its own modulated kernel.
        weight = weight.view(batch_size * self.dim_out, self.dim_in, self.kernel_size)
        x = x.reshape(1, batch_size * in_channel, t)
        x = F.pad(x, self.padding, mode = 'reflect')
        out = F.conv1d(x, weight, padding = 0, groups = batch_size)
        _, _, new_t = out.size()
        out = out.view(batch_size, self.dim_out, new_t)
        if self.use_act:
            out = self.act(out)
        return (out, c_src, c_trg)