Skip to content

Commit

Permalink
Add HKD distiller (Heterogeneous Knowledge Distillation); fix CRD input handling
Browse files Browse the repository at this point in the history
  • Loading branch information
triomino committed Dec 30, 2020
1 parent 854717f commit 7bb58fc
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 50 deletions.
187 changes: 187 additions & 0 deletions distiller_zoo/HKD.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import numpy as np
import torch
import torch.nn as nn


# https://github.com/passalis/pkth/blob/master/nn/pkt_transfer.py
def _kernel_similarities(features, role, kernel_parameters):
    """Pairwise similarity matrix for one batch of features.

    The kernel is selected by ``kernel_parameters[role]`` where ``role`` is
    ``'teacher'`` or ``'student'``; per-kernel hyper-parameters are read from
    ``kernel_parameters[role + '_sigma']`` / ``kernel_parameters[role + '_d']``
    with the same defaults as before (1).

    Returns a single (n, n) tensor, except for the ``'combined'`` kernel which
    returns a ``(cosine, student_t)`` pair of matrices.
    """
    kind = kernel_parameters[role]
    if kind == 'rbf':
        dists = pairwise_distances(features)
        sigma = kernel_parameters.get(role + '_sigma', 1)
        return torch.exp(-dists / sigma)
    if kind == 'adaptive_rbf':
        dists = pairwise_distances(features)
        # Bandwidth tracks the current batch; detached so it acts as a constant.
        sigma = torch.mean(dists).detach()
        return torch.exp(-dists / sigma)
    if kind == 'cosine':
        return cosine_pairwise_similarities(features)
    if kind == 'student_t':
        dists = pairwise_distances(features)
        degree = kernel_parameters.get(role + '_d', 1)
        return 1.0 / (1 + dists ** degree)
    if kind == 'cauchy':
        dists = pairwise_distances(features)
        sigma = kernel_parameters.get(role + '_sigma', 1)
        return 1.0 / (1 + (dists ** 2 / sigma ** 2))
    if kind == 'combined':
        dists = pairwise_distances(features)
        degree = kernel_parameters.get(role + '_d', 1)
        return cosine_pairwise_similarities(features), 1.0 / (1 + dists ** degree)
    assert False, 'unknown kernel: %s' % kind


def _row_normalize(s):
    """Turn each row of a similarity matrix into a probability distribution."""
    return s / torch.sum(s, dim=1, keepdim=True)


# https://github.com/passalis/pkth/blob/master/nn/pkt_transfer.py
def prob_loss(teacher_features, student_features, eps=1e-6, kernel_parameters=None):
    """Probabilistic knowledge-transfer loss between teacher and student batches.

    Builds a pairwise similarity matrix for each batch (kernel chosen per side
    via ``kernel_parameters['teacher']`` / ``kernel_parameters['student']``),
    row-normalizes it into probabilities, and compares the two distributions
    with the divergence named by ``kernel_parameters['loss']`` (default: KL).

    :param teacher_features: (n, d_t) tensor of teacher features
    :param student_features: (n, d_s) tensor of student features
    :param eps: numerical-stability constant for the KL divergence
    :param kernel_parameters: dict of kernel/loss choices (see
        ``_kernel_similarities``); ``None`` means an empty dict
    :return: scalar loss tensor
    """
    # BUG FIX: the original signature used a mutable dict default
    # (kernel_parameters={}), which is shared across calls.
    if kernel_parameters is None:
        kernel_parameters = {}

    teacher_s = _kernel_similarities(teacher_features, 'teacher', kernel_parameters)
    student_s = _kernel_similarities(student_features, 'student', kernel_parameters)

    if kernel_parameters['teacher'] == 'combined':
        # Two similarity matrices per side; normalize each into probabilities.
        teacher_s_1, teacher_s_2 = (_row_normalize(s) for s in teacher_s)
        student_s_1, student_s_2 = (_row_normalize(s) for s in student_s)
    else:
        teacher_s = _row_normalize(teacher_s)
        student_s = _row_normalize(student_s)

    loss_name = kernel_parameters.get('loss', 'kl')
    if loss_name == 'kl':
        loss = teacher_s * torch.log(eps + (teacher_s) / (eps + student_s))
    elif loss_name == 'abs':
        loss = torch.abs(teacher_s - student_s)
    elif loss_name == 'squared':
        loss = (teacher_s - student_s) ** 2
    elif loss_name == 'jeffreys':
        loss = (teacher_s - student_s) * (torch.log(teacher_s) - torch.log(student_s))
    elif loss_name == 'exponential':
        loss = teacher_s * (torch.log(teacher_s) - torch.log(student_s)) ** 2
    elif loss_name == 'kagan':
        loss = ((teacher_s - student_s) ** 2) / teacher_s
    elif loss_name == 'combined':
        # Jeffrey's divergence on both combined-kernel matrices.
        loss1 = (teacher_s_1 - student_s_1) * (torch.log(teacher_s_1) - torch.log(student_s_1))
        loss2 = (teacher_s_2 - student_s_2) * (torch.log(teacher_s_2) - torch.log(student_s_2))
        return torch.mean(loss1) + torch.mean(loss2)
    else:
        assert False, 'unknown loss: %s' % loss_name

    return torch.mean(loss)


def pairwise_distances(a, b=None, eps=1e-6):
    """Euclidean distance between every row of `a` and every row of `b`.

    When `b` is omitted the self-distances of `a` are computed. `eps` is
    added under the square root so the gradient is finite at zero distance.

    :param a: (n, d) tensor
    :param b: (m, d) tensor, or None for b = a
    :param eps: numerical-stability constant
    :return: (n, m) tensor of distances
    """
    if b is None:
        b = a

    # ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2 <a_i, b_j>, via broadcasting.
    sq_a = torch.sum(a ** 2, dim=1, keepdim=True)      # (n, 1)
    sq_b = torch.sum(b ** 2, dim=1, keepdim=True).t()  # (1, m)
    cross = torch.mm(a, b.t())

    squared = sq_a + sq_b - 2 * cross
    # Clamp away small negative values caused by floating-point cancellation.
    squared = torch.clamp(squared, min=0, max=np.inf)
    return torch.sqrt(squared + eps)


def cosine_pairwise_similarities(features, eps=1e-6, normalized=True):
    """Cosine similarity between every pair of rows of `features`.

    Rows are l2-normalized with an `eps` in the denominator to avoid division
    by zero; NaN entries (e.g. propagated from NaN inputs) are zeroed out.
    When `normalized` is True the result is mapped from [-1, 1] to [0, 1].

    :param features: (n, d) tensor
    :param eps: numerical-stability constant for the norm
    :param normalized: rescale similarities into [0, 1]
    :return: (n, n) similarity tensor
    """
    norms = torch.sqrt(torch.sum(features ** 2, dim=1, keepdim=True))
    unit = features / (norms + eps)
    unit[unit != unit] = 0  # NaN != NaN: zero any NaN entries defensively
    sims = torch.mm(unit, unit.t())
    if normalized:
        sims = (sims + 1.0) / 2.0
    return sims


class HKDLoss(nn.Module):
    """Heterogeneous Knowledge Distillation using Information Flow Modeling, CVPR2020

    Sums the probabilistic transfer loss over matching (teacher, student)
    intermediate feature pairs, weighting the i-th pair by
    init_weight * decay**i so earlier layers dominate.
    """
    def __init__(self, init_weight=100, decay=0.7):
        super(HKDLoss, self).__init__()
        self.init_weight = init_weight  # weight applied to the first feature pair
        self.decay = decay              # per-layer multiplicative decay of that weight

    def forward(self, f_s, f_t):
        """Compute the HKD loss over lists of student/teacher feature maps.

        :param f_s: list of student feature tensors (any shape; flattened per sample)
        :param f_t: list of teacher feature tensors, paired with f_s by position
        :return: scalar loss (0 if the lists are empty)
        """
        kernel_parameters = {'teacher': 'combined', 'student': 'combined', 'loss': 'combined'}
        loss = 0
        weight = self.init_weight
        for teacher, student in zip(f_t, f_s):
            teacher = teacher.view(teacher.shape[0], -1)
            student = student.view(student.shape[0], -1)
            # BUG FIX: the decayed weight was computed but never applied to
            # layers after the first, making `decay` a no-op; apply it to
            # every layer's term.
            loss = loss + weight * prob_loss(teacher, student, kernel_parameters=kernel_parameters)
            weight *= self.decay
        return loss
1 change: 1 addition & 0 deletions distiller_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .VID import VIDLoss
from .IRG import IRGLoss
from .SemCKD import SemCKDLoss
from .HKD import HKDLoss
16 changes: 10 additions & 6 deletions helper/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt):

# input = input.float()
if opt.gpu is not None:
input = input.cuda(opt.gpu, non_blocking=True)
input = input.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
if torch.cuda.is_available():
target = target.cuda(opt.gpu, non_blocking=True)
target = target.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)

# ===================forward=====================
output = model(input)
Expand Down Expand Up @@ -106,9 +106,9 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o
input, target = data[0]['data'], data[0]['label'].squeeze().long()

if opt.gpu is not None:
input = input.cuda(opt.gpu, non_blocking=True)
input = input.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
if torch.cuda.is_available():
target = target.cuda(opt.gpu, non_blocking=True)
target = target.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
if opt.distill in ['crd']:
index = index.cuda()
contrast_idx = contrast_idx.cuda()
Expand Down Expand Up @@ -161,6 +161,10 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o
f_s = feat_s[-1]
f_t = feat_t[-1]
loss_kd = criterion_kd(f_s, f_t)
elif opt.distill == 'hkd':
f_s = feat_s[1:-1]
f_t = feat_t[1:-1]
loss_kd = criterion_kd(f_s, f_t)
elif opt.distill == 'correlation':
f_s = module_list[1](feat_s[-1])
f_t = module_list[2](feat_t[-1])
Expand Down Expand Up @@ -225,9 +229,9 @@ def validate(val_loader, model, criterion, opt):
input, target = batch_data[0]['data'], batch_data[0]['label'].squeeze().long()

if opt.gpu is not None:
input = input.cuda(opt.gpu, non_blocking=True)
input = input.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
if torch.cuda.is_available():
target = target.cuda(opt.gpu, non_blocking=True)
target = target.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)

# compute output
output = model(input)
Expand Down
13 changes: 10 additions & 3 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from .resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4
from .resnetv2 import resnet18, resnet34, resnet50, wide_resnet50_2, resnext50_32x4d
from .resnet import resnet8x4_double
from .resnetv2 import resnet18, resnet34, resnet50, wide_resnet50_2, resnext50_32x4d, resnet34x4, resnet18x2
from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2
from .vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn
from .vggv2 import vgg13_bn as vgg13_imagenet
from .vggv2 import vgg13_bn as vgg13_imagenet, vgg11_bn as vgg11_imagenet
from .mobilenetv2 import mobile_half
from .mobilenetv2_imagenet import mobilenet_v2
from .ShuffleNetv1 import ShuffleV1
from .ShuffleNetv2 import ShuffleV2
from .ShuffleNetv2_Imagenet import shufflenet_v2_x1_0 as ShuffleNetV2Imagenet
from .ShuffleNetv2_Imagenet import shufflenet_v2_x1_0 as ShuffleNetV2Imagenet, shufflenet_v2_x0_5, shufflenet_v2_x2_0

model_dict = {
'resnet8': resnet8,
Expand All @@ -17,12 +18,15 @@
'resnet44': resnet44,
'resnet56': resnet56,
'ResNet18': resnet18,
'ResNet18Double': resnet18x2,
'ResNet34': resnet34,
'ResNet50': resnet50,
'resnet110': resnet110,
'resnet8x4': resnet8x4,
'resnet8x4_double': resnet8x4_double,
'resnet32x4': resnet32x4,
'resnext50_32x4d': resnext50_32x4d,
'resnet34x4': resnet34x4,
'wrn_16_1': wrn_16_1,
'wrn_16_2': wrn_16_2,
'wrn_40_1': wrn_40_1,
Expand All @@ -34,9 +38,12 @@
'vgg16': vgg16_bn,
'vgg19': vgg19_bn,
'vgg13_imagenet': vgg13_imagenet,
'vgg11_imagenet': vgg11_imagenet,
'MobileNetV2': mobile_half,
'MobileNetV2_Imagenet': mobilenet_v2,
'ShuffleV1': ShuffleV1,
'ShuffleV2': ShuffleV2,
'ShuffleV2_Imagenet': ShuffleNetV2Imagenet,
'shufflenet_v2_x0_5': shufflenet_v2_x0_5,
'shufflenet_v2_x2_0': shufflenet_v2_x2_0,
}
4 changes: 4 additions & 0 deletions models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ def resnet8x4(**kwargs):
return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs)


def resnet8x4_double(**kwargs):
    # 8-layer ResNet with channel widths [64, 128, 256, 512] — twice the
    # widths of resnet8x4 ([32, 64, 128, 256]); uses basic blocks.
    return ResNet(8, [64, 128, 256, 512], 'basicblock', **kwargs)


def resnet32x4(**kwargs):
    # 32-layer ResNet with 4x channel widths [32, 64, 128, 256]; basic blocks.
    return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs)

Expand Down
12 changes: 12 additions & 0 deletions models/resnetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,18 @@ def resnet18(pretrained=False, progress=True, **kwargs):
**kwargs)


def resnet18x2(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model with doubled channel width (multiplier=2), from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, multiplier=2,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Expand Down
7 changes: 5 additions & 2 deletions models/vggv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ def forward(self, x, is_feat=False):
for module in self.features[left:right]:
x = module(x)
hidden_layers.append(x)
for module in self.features[self.split[-1]:]:
x = module(x)
if self.split[-1] < len(self.features):
for module in self.features[self.split[-1]:]:
x = module(x)
hidden_layers.append(x)

x = self.avgpool(x)
x = torch.flatten(x, 1)
Expand Down Expand Up @@ -107,6 +109,7 @@ def make_layers(cfg, split, batch_norm=False):
}

splits = {
'A': [0, 2, 4, 7, 10],
'B': [0, 3, 6, 9, 12]
}

Expand Down
26 changes: 15 additions & 11 deletions scripts/run_cifar_distill.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,28 @@
# use resnet32x4 and resnet8x4 as an example

# kd
python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill kd --model_s resnet8x4 -r 1 -a 1 -b 0 --trial 1
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill kd --model_s resnet8x4 -r 1 -a 1 -b 0 --trial 1
# FitNet
python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill hint --model_s resnet8x4 -r 1 -a 1 -b 100 --trial 1
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill hint --model_s resnet8x4 -r 1 -a 1 -b 100 --trial 1
# AT
python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill attention --model_s resnet8x4 -r 1 -a 1 -b 1000 --trial 1
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill attention --model_s resnet8x4 -r 1 -a 1 -b 1000 --trial 1
# SP
python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill similarity --model_s resnet8x4 -r 1 -a 1 -b 3000 --trial 1
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill similarity --model_s resnet8x4 -r 1 -a 1 -b 3000 --trial 1
# CC
python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill correlation --model_s resnet8x4 -r 1 -a 1 -b 0.02 --trial 1
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill correlation --model_s resnet8x4 -r 1 -a 1 -b 0.02 --trial 1
# VID
python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill vid --model_s resnet8x4 -r 1 -a 1 -b 1 --trial 1
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill vid --model_s resnet8x4 -r 1 -a 1 -b 1 --trial 1
# RKD
python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill rkd --model_s resnet8x4 -r 1 -a 1 -b 1 --trial 1
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill rkd --model_s resnet8x4 -r 1 -a 1 -b 1 --trial 1
# PKT
python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill pkt --model_s resnet8x4 -r 1 -a 1 -b 30000 --trial 1
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill pkt --model_s resnet8x4 -r 1 -a 1 -b 30000 --trial 1
# CRD
python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill crd --model_s resnet8x4 -r 1 -a 1 -b 0.8 --trial 1
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill crd --model_s resnet8x4 -r 1 -a 1 -b 0.8 --trial 1
# IRG
python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill irg --model_s resnet8x4 -r 1 -a 1 -b 0.8 --trial 1
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill irg --model_s resnet8x4 -r 1 -a 1 -b 0.8 --trial 1
# SemCKD
python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill semckd --model_s resnet8x4 -r 1 -a 1 -b 400 --trial 1
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill semckd --model_s resnet8x4 -r 1 -a 1 -b 400 --trial 1
# HKD(two pass)
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill pkt --model_s resnet8x4_double -r 1 -a 1 -b 30000 --trial 1
python train_student.py --path-t ./save/student_model/S:resnet8x4_double_T:resnet32x4_cifar100_pkt_r:1.0_a:1.0_b:30000.0_1/resnet8x4_double_best.pth \
--distill hkd --model_s resnet8x4 -r 1 -a 1 -b 1 --trial 1
Loading

0 comments on commit 7bb58fc

Please sign in to comment.