
Commit

update
HobbitLong committed Oct 22, 2019
1 parent 107dbfc commit b2487ab
Showing 7 changed files with 999 additions and 0 deletions.
13 changes: 13 additions & 0 deletions distiller_zoo/__init__.py
@@ -0,0 +1,13 @@
from .AB import ABLoss
from .AT import Attention
from .CC import Correlation
from .FitNet import HintLoss
from .FSP import FSP
from .FT import FactorTransfer
from .KD import DistillKL
from .KDSVD import KDSVD
from .NST import NSTLoss
from .PKT import PKT
from .RKD import RKDLoss
from .SP import Similarity
from .VID import VIDLoss
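
The package exposes one loss class per distillation method (attention transfer, FitNet hints, similarity preserving, and so on). As a rough usage sketch only: the constructor arguments shown here (a temperature T for DistillKL, an exponent p for Attention) follow common knowledge-distillation conventions and are assumptions, not confirmed by this diff.

from distiller_zoo import DistillKL, Attention

# hypothetical instantiation; argument names are assumed
criterion_div = DistillKL(T=4)   # KL divergence between softened student/teacher logits
criterion_at = Attention(p=2)    # attention-transfer loss over intermediate feature maps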
Empty file added helper/__init__.py
Empty file.
271 changes: 271 additions & 0 deletions helper/loops.py
@@ -0,0 +1,271 @@
from __future__ import print_function, division

import sys
import time
import torch

from .util import AverageMeter, accuracy


def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt):
    """vanilla training"""
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for idx, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        input = input.float()
        if torch.cuda.is_available():
            input = input.cuda()
            target = target.cuda()

        # ===================forward=====================
        output = model(input)
        loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1[0], input.size(0))
        top5.update(acc5[0], input.size(0))

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - end)
        end = time.time()

        # tensorboard logger
        pass

        # print info
        if idx % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, idx, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
            sys.stdout.flush()

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, losses.avg


def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, opt):
    """One epoch distillation"""
    # set modules as train()
    for module in module_list:
        module.train()
    # set teacher as eval()
    module_list[-1].eval()

    if opt.distill == 'abound':
        module_list[1].eval()
    elif opt.distill == 'factor':
        module_list[2].eval()

    criterion_cls = criterion_list[0]
    criterion_div = criterion_list[1]
    criterion_kd = criterion_list[2]

    model_s = module_list[0]
    model_t = module_list[-1]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for idx, data in enumerate(train_loader):
        if opt.distill in ['contrast', 'infonce']:
            input, target, index, contrast_idx = data
        else:
            input, target, index = data
        data_time.update(time.time() - end)

        input = input.float()
        if torch.cuda.is_available():
            input = input.cuda()
            target = target.cuda()
            index = index.cuda()
            if opt.distill in ['contrast', 'infonce']:
                contrast_idx = contrast_idx.cuda()

        # ===================forward=====================
        preact = False
        if opt.distill in ['abound']:
            preact = True
        feat_s, logit_s = model_s(input, is_feat=True, preact=preact)
        with torch.no_grad():
            feat_t, logit_t = model_t(input, is_feat=True, preact=preact)
            feat_t = [f.detach() for f in feat_t]

        # cls + kl div
        loss_cls = criterion_cls(logit_s, target)
        loss_div = criterion_div(logit_s, logit_t)

        # other kd beyond KL divergence
        if opt.distill == 'kd':
            loss_kd = 0
        elif opt.distill == 'hint':
            f_s = module_list[1](feat_s[opt.hint_layer])
            f_t = feat_t[opt.hint_layer]
            loss_kd = criterion_kd(f_s, f_t)
        elif opt.distill == 'contrast':
            f_s = module_list[1](feat_s[-1])
            f_t = module_list[2](feat_t[-1])
            loss_kd = criterion_kd(f_s, f_t, index, contrast_idx)
        elif opt.distill == 'infonce':
            f_s = module_list[1](feat_s[-1])
            f_t = module_list[2](feat_t[-1])
            loss_kd = criterion_kd(f_s, f_t, index, contrast_idx)
        elif opt.distill == 'softmax':
            f_s = module_list[1](feat_s[-1])
            f_t = module_list[2](feat_t[-1])
            loss_kd = criterion_kd(f_s, f_t)
        elif opt.distill == 'attention':
            g_s = feat_s[1:-1]
            g_t = feat_t[1:-1]
            loss_group = criterion_kd(g_s, g_t)
            loss_kd = sum(loss_group)
        elif opt.distill == 'nst':
            g_s = feat_s[1:-1]
            g_t = feat_t[1:-1]
            loss_group = criterion_kd(g_s, g_t)
            loss_kd = sum(loss_group)
        elif opt.distill == 'similarity':
            g_s = [feat_s[-2]]
            g_t = [feat_t[-2]]
            loss_group = criterion_kd(g_s, g_t)
            loss_kd = sum(loss_group)
        elif opt.distill == 'rkd':
            f_s = feat_s[-1]
            f_t = feat_t[-1]
            loss_kd = criterion_kd(f_s, f_t)
        elif opt.distill == 'pkt':
            f_s = feat_s[-1]
            f_t = feat_t[-1]
            loss_kd = criterion_kd(f_s, f_t)
        elif opt.distill == 'kdsvd':
            g_s = feat_s[1:-1]
            g_t = feat_t[1:-1]
            loss_group = criterion_kd(g_s, g_t)
            loss_kd = sum(loss_group)
        elif opt.distill == 'correlation':
            f_s = module_list[1](feat_s[-1])
            f_t = module_list[2](feat_t[-1])
            loss_kd = criterion_kd(f_s, f_t)
        elif opt.distill == 'vid':
            g_s = feat_s[1:-1]
            g_t = feat_t[1:-1]
            loss_group = [c(f_s, f_t) for f_s, f_t, c in zip(g_s, g_t, criterion_kd)]
            loss_kd = sum(loss_group)
        elif opt.distill == 'abound':
            # can also add loss to this stage
            loss_kd = 0
        elif opt.distill == 'fsp':
            # can also add loss to this stage
            loss_kd = 0
        elif opt.distill == 'factor':
            factor_s = module_list[1](feat_s[-2])
            factor_t = module_list[2](feat_t[-2], is_factor=True)
            loss_kd = criterion_kd(factor_s, factor_t)
        else:
            raise NotImplementedError(opt.distill)

        loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd

        acc1, acc5 = accuracy(logit_s, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1[0], input.size(0))
        top5.update(acc5[0], input.size(0))

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if idx % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, idx, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
            sys.stdout.flush()

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, losses.avg
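
# Layout expected by train_distill above (inferred from its indexing):
#   module_list    = [student, <auxiliary modules, if any>, teacher]
#   criterion_list = [classification loss, KL-divergence loss, method-specific KD loss]
# and the total objective is opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd.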


def validate(val_loader, model, criterion, opt):
    """validation"""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for idx, (input, target) in enumerate(val_loader):

            input = input.float()
            if torch.cuda.is_available():
                input = input.cuda()
                target = target.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1[0], input.size(0))
            top5.update(acc5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       idx, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, top5.avg, losses.avg
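
A minimal sketch of how these loops might be driven for plain KD. The opt attributes, the nn.ModuleList packing, the DistillKL temperature, and the requirement that train_loader yield (input, target, index) triples are all inferred from the code above or hypothetical, not part of this commit; model_s, model_t, train_loader, and val_loader are assumed to be built elsewhere.

import torch
import torch.nn as nn

from helper.loops import train_distill, validate
from distiller_zoo import DistillKL

class Opt:                              # hypothetical stand-in for the argparse options the loops read
    distill = 'kd'
    gamma, alpha, beta = 1.0, 0.9, 0.0  # weights applied in train_distill
    print_freq = 100
    epochs = 240

opt = Opt()
module_list = nn.ModuleList([model_s, model_t])            # student first, teacher last
criterion_list = nn.ModuleList([nn.CrossEntropyLoss(),     # criterion_cls
                                DistillKL(T=4),            # criterion_div (temperature assumed)
                                DistillKL(T=4)])            # criterion_kd (unused when distill == 'kd')
optimizer = torch.optim.SGD(model_s.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)

for epoch in range(1, opt.epochs + 1):
    train_acc, train_loss = train_distill(epoch, train_loader, module_list, criterion_list, optimizer, opt)
    test_acc, test_acc_top5, test_loss = validate(val_loader, model_s, nn.CrossEntropyLoss(), opt)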
94 changes: 94 additions & 0 deletions helper/pretrain.py
@@ -0,0 +1,94 @@
from __future__ import print_function, division

import time
import sys
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
from .util import AverageMeter


def init(model_s, model_t, init_modules, criterion, train_loader, logger, opt):
    model_t.eval()
    model_s.eval()
    init_modules.train()

    if torch.cuda.is_available():
        model_s.cuda()
        model_t.cuda()
        init_modules.cuda()
        cudnn.benchmark = True

    if opt.model_s in ['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
                       'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2'] and \
            opt.distill == 'factor':
        lr = 0.01
    else:
        lr = opt.learning_rate
    optimizer = optim.SGD(init_modules.parameters(),
                          lr=lr,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    for epoch in range(1, opt.init_epochs + 1):
        batch_time.reset()
        data_time.reset()
        losses.reset()
        end = time.time()
        for idx, data in enumerate(train_loader):
            if opt.distill in ['contrast', 'infonce']:
                input, target, index, contrast_idx = data
            else:
                input, target, index = data
            data_time.update(time.time() - end)

            input = input.float()
            if torch.cuda.is_available():
                input = input.cuda()
                target = target.cuda()
                index = index.cuda()
                if opt.distill in ['contrast', 'infonce']:
                    contrast_idx = contrast_idx.cuda()

            # ============= forward ==============
            preact = (opt.distill == 'abound')
            feat_s, _ = model_s(input, is_feat=True, preact=preact)
            with torch.no_grad():
                feat_t, _ = model_t(input, is_feat=True, preact=preact)
                feat_t = [f.detach() for f in feat_t]

            if opt.distill == 'abound':
                g_s = init_modules[0](feat_s[1:-1])
                g_t = feat_t[1:-1]
                loss_group = criterion(g_s, g_t)
                loss = sum(loss_group)
            elif opt.distill == 'factor':
                f_t = feat_t[-2]
                _, f_t_rec = init_modules[0](f_t)
                loss = criterion(f_t_rec, f_t)
            elif opt.distill == 'fsp':
                loss_group = criterion(feat_s[:-1], feat_t[:-1])
                loss = sum(loss_group)
            else:
                raise NotImplementedError('Not supported in init training: {}'.format(opt.distill))

            losses.update(loss.item(), input.size(0))

            # ===================backward=====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

        # end of epoch
        logger.log_value('init_train_loss', losses.avg, epoch)
        print('Epoch: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'losses: {losses.val:.3f} ({losses.avg:.3f})'.format(
               epoch, opt.init_epochs, batch_time=batch_time, losses=losses))
        sys.stdout.flush()
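
The branches above only handle 'abound', 'factor', and 'fsp', so init() is a warm-up stage that fits the auxiliary modules those methods require before the main distillation epochs. A minimal sketch of the intended call, assuming init_modules has been packed into an nn.ModuleList elsewhere and logger is a tensorboard-style object exposing log_value (as used above); the connector module and criterion_init name are hypothetical.

import torch.nn as nn
from helper.pretrain import init

init_modules = nn.ModuleList([connector])   # hypothetical connector/paraphraser built elsewhere
init(model_s, model_t, init_modules, criterion_init, train_loader, logger, opt)
# after this warm-up, the regular train_distill() epochs run with the fitted modules in module_list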