Skip to content

Commit

Permalink
wrap crd
Browse files Browse the repository at this point in the history
  • Loading branch information
HobbitLong committed Oct 23, 2019
1 parent c86f297 commit 1fb2343
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 13 deletions.
57 changes: 53 additions & 4 deletions crd/criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,37 @@ class CRDLoss(nn.Module):
includes two symmetric parts:
(a) using teacher as anchor, choose positive and negatives over the student side
(b) using student as anchor, choose positive and negatives over the teacher side
Args:
opt.s_dim: the dimension of student's feature
opt.t_dim: the dimension of teacher's feature
opt.feat_dim: the dimension of the projection space
opt.nce_k: number of negatives paired with each positive
opt.nce_t: the temperature
opt.nce_m: the momentum for updating the memory buffer
opt.n_data: the number of samples in the training set, therefor the memory buffer is: opt.n_data x opt.feat_dim
"""
def __init__(self, opt, n_data):
def __init__(self, opt):
super(CRDLoss, self).__init__()
self.contrast = ContrastMemory(opt.feat_dim, n_data, opt.nce_k, opt.nce_t, opt.nce_m)
self.criterion_t = ContrastLoss(n_data)
self.criterion_s = ContrastLoss(n_data)
self.embed_s = Embed(opt.s_dim, opt.feat_dim)
self.embed_t = Embed(opt.t_dim, opt.feat_dim)
self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)
self.criterion_t = ContrastLoss(opt.n_data)
self.criterion_s = ContrastLoss(opt.n_data)

def forward(self, f_s, f_t, idx, contrast_idx=None):
"""
Args:
f_s: the feature of student network, size [batch_size, s_dim]
f_t: the feature of teacher network, size [batch_size, t_dim]
idx: the indices of these positive samples in the dataset, size [batch_size]
contrast_idx: the indices of negative samples, size [batch_size, nce_k]
Returns:
The contrastive loss
"""
f_s = self.embed_s(f_s)
f_t = self.embed_t(f_t)
out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
s_loss = self.criterion_s(out_s)
t_loss = self.criterion_t(out_t)
Expand Down Expand Up @@ -51,3 +74,29 @@ def forward(self, x):
loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz

return loss


class Embed(nn.Module):
"""Embedding module"""
def __init__(self, dim_in=1024, dim_out=128):
super(Embed, self).__init__()
self.linear = nn.Linear(dim_in, dim_out)
self.l2norm = Normalize(2)

def forward(self, x):
x = x.view(x.shape[0], -1)
x = self.linear(x)
x = self.l2norm(x)
return x


class Normalize(nn.Module):
"""normalization layer"""
def __init__(self, power=2):
super(Normalize, self).__init__()
self.power = power

def forward(self, x):
norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
out = x.div(norm)
return out
4 changes: 2 additions & 2 deletions helper/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o
f_t = feat_t[opt.hint_layer]
loss_kd = criterion_kd(f_s, f_t)
elif opt.distill == 'crd':
f_s = module_list[1](feat_s[-1])
f_t = module_list[2](feat_t[-1])
f_s = feat_s[-1]
f_t = feat_t[-1]
loss_kd = criterion_kd(f_s, f_t, index, contrast_idx)
elif opt.distill == 'attention':
g_s = feat_s[1:-1]
Expand Down
15 changes: 8 additions & 7 deletions train_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,14 @@ def main():
module_list.append(regress_s)
trainable_list.append(regress_s)
elif opt.distill == 'crd':
criterion_kd = CRDLoss(opt, n_data)
embed_s = Embed(feat_s[-1].shape[1], opt.feat_dim)
embed_t = Embed(feat_t[-1].shape[1], opt.feat_dim)
module_list.append(embed_s)
module_list.append(embed_t)
trainable_list.append(embed_s)
trainable_list.append(embed_t)
opt.s_dim = feat_s[-1].shape[1]
opt.t_dim = feat_t[-1].shape[1]
opt.n_data = n_data
criterion_kd = CRDLoss(opt)
module_list.append(criterion_kd.embed_s)
module_list.append(criterion_kd.embed_t)
trainable_list.append(criterion_kd.embed_s)
trainable_list.append(criterion_kd.embed_t)
elif opt.distill == 'attention':
criterion_kd = Attention()
elif opt.distill == 'nst':
Expand Down

0 comments on commit 1fb2343

Please sign in to comment.