Skip to content

Commit

Permalink
Add IRG Loss
Browse files Browse the repository at this point in the history
  • Loading branch information
triomino committed Jul 4, 2020
1 parent 34557d2 commit 17f66dc
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,6 @@ venv.bak/

# mypy
.mypy_cache/

# trained models
save/*
24 changes: 22 additions & 2 deletions dataset/cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,28 @@ def get_data_folder():

return data_folder

class CIFAR100BackCompat(datasets.CIFAR100):
    """CIFAR100 dataset exposing split-specific attribute aliases.

    ``train_labels`` / ``test_labels`` are read-only views of
    ``self.targets`` and ``train_data`` / ``test_data`` are read-only views
    of ``self.data``.  Presumably this lets code written against the older
    split-specific attribute names keep working -- confirm against the
    torchvision version in use.
    """

    def _compat_targets(self):
        # Single getter shared by both label aliases.
        return self.targets

    def _compat_data(self):
        # Single getter shared by both data aliases.
        return self.data

    train_labels = property(_compat_targets)
    test_labels = property(_compat_targets)
    train_data = property(_compat_data)
    test_data = property(_compat_data)

class CIFAR100Instance(datasets.CIFAR100):
class CIFAR100Instance(CIFAR100BackCompat):
"""CIFAR100Instance Dataset.
"""
def __getitem__(self, index):
Expand Down Expand Up @@ -106,7 +126,7 @@ def get_cifar100_dataloaders(batch_size=128, num_workers=8, is_instance=False):
return train_loader, test_loader


class CIFAR100InstanceSample(datasets.CIFAR100):
class CIFAR100InstanceSample(CIFAR100BackCompat):
"""
CIFAR100Instance+Sample Dataset
"""
Expand Down
68 changes: 68 additions & 0 deletions distiller_zoo/IRG.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F


class IRGLoss(nn.Module):
    """Knowledge Distillation via Instance Relationship Graph, CVPR2019.

    The loss is the sum of two weighted parts:

    * graph term (scaled by ``w_graph``): MSE between the max-normalised
      squared pairwise-distance matrices of student and teacher features,
      plus an MSE between the features themselves when their shapes match;
    * transform term (scaled by ``w_transform``): for each pair of
      intermediate layers, the amount of change between the two layers is
      matched between student and teacher.
    """

    def __init__(self, w_graph=1, w_transform=1):
        super(IRGLoss, self).__init__()
        self.mseloss = nn.MSELoss()
        self.w_graph = w_graph
        self.w_transform = w_transform

    def forward(self, f_s, f_t, transform_s, transform_t, no_edge_transform=False):
        """Compute the IRG loss.

        Args:
            f_s: student feature tensor; dimension 0 is treated as the
                instance dimension and the rest is flattened.
            f_t: teacher feature tensor, same convention as ``f_s``.
            transform_s: sequence of intermediate student features.
            transform_t: sequence of intermediate teacher features, aligned
                index-by-index with ``transform_s``.
            no_edge_transform: when true, the pairwise-distance part of the
                transform term is skipped.

        Returns:
            Scalar tensor holding the weighted loss.

        Note:
            Layers are consumed in disjoint pairs ``(0, 1), (2, 3), ...``.
            Entries beyond the shorter of the two sequences, and a trailing
            unpaired layer, are ignored.
        """
        edge_transform = not no_edge_transform
        # reshape (not view) so non-contiguous feature maps are accepted.
        student = f_s.reshape(f_s.shape[0], -1)
        teacher = f_t.reshape(f_t.shape[0], -1)

        # vertex and edge loss; the teacher side never needs gradients
        with torch.no_grad():
            t_d = self.pdist(teacher, squared=True, normalization='max')
        d = self.pdist(student, squared=True, normalization='max')
        loss = self.mseloss(d, t_d)
        if f_s.shape == f_t.shape:
            loss = loss + self.mseloss(f_s, f_t)
        loss = loss * self.w_graph

        # transform loss over disjoint consecutive layer pairs
        layers = list(zip(transform_s, transform_t))
        for (l1_s, l1_t), (l2_s, l2_t) in zip(layers[0::2], layers[1::2]):
            pair_loss = self.transform_loss(l1_s, l2_s, l1_t, l2_t, edge_transform)
            loss = loss + pair_loss * self.w_transform

        return loss

    def transform_loss(self, l1_s, l2_s, l1_t, l2_t, edge_transform=True):
        """Match the layer-to-layer change of the student to the teacher's.

        Args:
            l1_s, l2_s: student features at the first and second layer.
            l1_t, l2_t: teacher features at the first and second layer.
            edge_transform: include the term built from pairwise-distance
                matrices of each layer.

        Returns:
            Scalar tensor: the sum of the applicable terms, or a zero scalar
            tensor when no term applies.
        """
        terms = []
        if edge_transform:
            edge_s = self.mseloss(self._flat_sq_dist(l1_s), self._flat_sq_dist(l2_s))
            with torch.no_grad():
                edge_t = self.mseloss(self._flat_sq_dist(l1_t), self._flat_sq_dist(l2_t))
            terms.append(self.mseloss(edge_s, edge_t))
        # The feature-level term needs element-wise comparable layers.
        if l1_s.shape == l2_s.shape and l1_t.shape == l2_t.shape:
            with torch.no_grad():
                vertex_t = self.mseloss(l1_t, l2_t)
            terms.append(self.mseloss(self.mseloss(l1_s, l2_s), vertex_t))
        if not terms:
            # Keep the return type a tensor instead of the int 0 from sum([]).
            return l1_s.new_zeros(())
        return sum(terms)

    def _flat_sq_dist(self, feat):
        # Max-normalised squared pairwise distances of per-instance flattened features.
        return self.pdist(feat.reshape(feat.shape[0], -1), squared=True, normalization='max')

    @staticmethod
    def pdist(e, squared=False, eps=1e-12, normalization='max'):
        """Pairwise distances between the rows of a 2-D tensor.

        Args:
            e: 2-D tensor, one row per instance.
            squared: return squared distances instead of distances.
            eps: lower clamp for squared distances, also added to the
                normaliser to avoid division by zero.
            normalization: ``'max'`` divides by the largest entry; any other
                value leaves the matrix unnormalised.

        Returns:
            Square matrix of distances with an exactly-zero diagonal.
        """
        e_square = e.pow(2).sum(dim=1)
        prod = e @ e.t()
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, clamped against round-off.
        res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

        if not squared:
            res = res.sqrt()

        # The clamp lifts self-distances to eps; reset them to exactly 0.
        res = res.clone()
        idx = torch.arange(len(e), device=res.device)
        res[idx, idx] = 0

        if normalization == 'max':
            res = res / (res.max() + eps)

        return res
1 change: 1 addition & 0 deletions distiller_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
from .RKD import RKDLoss
from .SP import Similarity
from .VID import VIDLoss
from .IRG import IRGLoss
6 changes: 6 additions & 0 deletions helper/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o
f_s = feat_s[-1]
f_t = feat_t[-1]
loss_kd = criterion_kd(f_s, f_t)
elif opt.distill == 'irg':
transform_s = [feat_s[i] for i in opt.transform_layer_s]
transform_t = [feat_t[i] for i in opt.transform_layer_t]
f_s = feat_s[-1]
f_t = feat_t[-1]
loss_kd = criterion_kd(f_s, f_t, transform_s, transform_t, opt.no_edge_transform)
elif opt.distill == 'pkt':
f_s = feat_s[-1]
f_t = feat_t[-1]
Expand Down
14 changes: 12 additions & 2 deletions train_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from helper.util import adjust_learning_rate

from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss
from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss, IRGLoss
from distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss
from crd.criterion import CRDLoss

Expand Down Expand Up @@ -67,7 +67,7 @@ def parse_option():
# distillation
parser.add_argument('--distill', type=str, default='kd', choices=['kd', 'hint', 'attention', 'similarity',
'correlation', 'vid', 'crd', 'kdsvd', 'fsp',
'rkd', 'pkt', 'abound', 'factor', 'nst'])
'rkd', 'pkt', 'abound', 'factor', 'nst', 'irg'])
parser.add_argument('--trial', type=str, default='1', help='trial id')

parser.add_argument('-r', '--gamma', type=float, default=1, help='weight for classification')
Expand All @@ -87,6 +87,13 @@ def parse_option():
# hint layer
parser.add_argument('--hint_layer', default=2, type=int, choices=[0, 1, 2, 3, 4])

# transform layers for IRG
parser.add_argument('--transform_layer_t', nargs='+', type=int, default = [])
parser.add_argument('--transform_layer_s', nargs='+', type=int, default = [])

# switch for edge transformation
parser.add_argument('--no_edge_transform', action='store_true')

opt = parser.parse_args()

# set different learning rate from these 4 models
Expand Down Expand Up @@ -204,6 +211,8 @@ def main():
criterion_kd = Similarity()
elif opt.distill == 'rkd':
criterion_kd = RKDLoss()
elif opt.distill == 'irg':
criterion_kd = IRGLoss()
elif opt.distill == 'pkt':
criterion_kd = PKT()
elif opt.distill == 'kdsvd':
Expand Down Expand Up @@ -294,6 +303,7 @@ def main():
print("==> training...")

time1 = time.time()
# train_loss, train_acc = 0, 0
train_acc, train_loss = train(epoch, train_loader, module_list, criterion_list, optimizer, opt)
time2 = time.time()
print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
Expand Down

0 comments on commit 17f66dc

Please sign in to comment.