Commit b2487ab (parent 107dbfc), showing 7 changed files with 999 additions and 0 deletions.
@@ -0,0 +1,13 @@
from .AB import ABLoss
from .AT import Attention
from .CC import Correlation
from .FitNet import HintLoss
from .FSP import FSP
from .FT import FactorTransfer
from .KD import DistillKL
from .KDSVD import KDSVD
from .NST import NSTLoss
from .PKT import PKT
from .RKD import RKDLoss
from .SP import Similarity
from .VID import VIDLoss
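Note: for orientation, below is a minimal sketch of the temperature-scaled KL loss that KD.DistillKL presumably implements (standard Hinton-style knowledge distillation). The class body, the temperature argument T, and the T**2 rescaling are assumptions based on the usual formulation, not a transcript of the repository's module:

import torch.nn as nn
import torch.nn.functional as F

class DistillKLSketch(nn.Module):
    """Hypothetical stand-in for KD.DistillKL: KL divergence between softened logits."""
    def __init__(self, T=4.0):
        super(DistillKLSketch, self).__init__()
        self.T = T  # softening temperature (assumed hyperparameter)

    def forward(self, logit_s, logit_t):
        p_s = F.log_softmax(logit_s / self.T, dim=1)
        p_t = F.softmax(logit_t / self.T, dim=1)
        # scale by T^2 so gradient magnitude stays comparable across temperatures
        return F.kl_div(p_s, p_t, reduction='batchmean') * (self.T ** 2)

In train_distill below, a criterion of this shape is what gets called as criterion_div(logit_s, logit_t).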
Empty file.
@@ -0,0 +1,271 @@
from __future__ import print_function, division

import sys
import time
import torch

from .util import AverageMeter, accuracy


def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt):
    """vanilla training"""
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for idx, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        input = input.float()
        if torch.cuda.is_available():
            input = input.cuda()
            target = target.cuda()

        # ===================forward=====================
        output = model(input)
        loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1[0], input.size(0))
        top5.update(acc5[0], input.size(0))

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - end)
        end = time.time()

        # tensorboard logger
        pass

        # print info
        if idx % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, idx, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
            sys.stdout.flush()

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, losses.avg


def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, opt):
    """One epoch distillation"""
    # set modules as train()
    for module in module_list:
        module.train()
    # set teacher as eval()
    module_list[-1].eval()

    if opt.distill == 'abound':
        module_list[1].eval()
    elif opt.distill == 'factor':
        module_list[2].eval()

    criterion_cls = criterion_list[0]
    criterion_div = criterion_list[1]
    criterion_kd = criterion_list[2]

    model_s = module_list[0]
    model_t = module_list[-1]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for idx, data in enumerate(train_loader):
        if opt.distill in ['contrast', 'infonce']:
            input, target, index, contrast_idx = data
        else:
            input, target, index = data
        data_time.update(time.time() - end)

        input = input.float()
        if torch.cuda.is_available():
            input = input.cuda()
            target = target.cuda()
            index = index.cuda()
            if opt.distill in ['contrast', 'infonce']:
                contrast_idx = contrast_idx.cuda()

        # ===================forward=====================
        preact = False
        if opt.distill in ['abound']:
            preact = True
        feat_s, logit_s = model_s(input, is_feat=True, preact=preact)
        with torch.no_grad():
            feat_t, logit_t = model_t(input, is_feat=True, preact=preact)
            feat_t = [f.detach() for f in feat_t]

        # cls + kl div
        loss_cls = criterion_cls(logit_s, target)
        loss_div = criterion_div(logit_s, logit_t)

        # other kd beyond KL divergence
        if opt.distill == 'kd':
            loss_kd = 0
        elif opt.distill == 'hint':
            f_s = module_list[1](feat_s[opt.hint_layer])
            f_t = feat_t[opt.hint_layer]
            loss_kd = criterion_kd(f_s, f_t)
        elif opt.distill == 'contrast':
            f_s = module_list[1](feat_s[-1])
            f_t = module_list[2](feat_t[-1])
            loss_kd = criterion_kd(f_s, f_t, index, contrast_idx)
        elif opt.distill == 'infonce':
            f_s = module_list[1](feat_s[-1])
            f_t = module_list[2](feat_t[-1])
            loss_kd = criterion_kd(f_s, f_t, index, contrast_idx)
        elif opt.distill == 'softmax':
            f_s = module_list[1](feat_s[-1])
            f_t = module_list[2](feat_t[-1])
            loss_kd = criterion_kd(f_s, f_t)
        elif opt.distill == 'attention':
            g_s = feat_s[1:-1]
            g_t = feat_t[1:-1]
            loss_group = criterion_kd(g_s, g_t)
            loss_kd = sum(loss_group)
        elif opt.distill == 'nst':
            g_s = feat_s[1:-1]
            g_t = feat_t[1:-1]
            loss_group = criterion_kd(g_s, g_t)
            loss_kd = sum(loss_group)
        elif opt.distill == 'similarity':
            g_s = [feat_s[-2]]
            g_t = [feat_t[-2]]
            loss_group = criterion_kd(g_s, g_t)
            loss_kd = sum(loss_group)
        elif opt.distill == 'rkd':
            f_s = feat_s[-1]
            f_t = feat_t[-1]
            loss_kd = criterion_kd(f_s, f_t)
        elif opt.distill == 'pkt':
            f_s = feat_s[-1]
            f_t = feat_t[-1]
            loss_kd = criterion_kd(f_s, f_t)
        elif opt.distill == 'kdsvd':
            g_s = feat_s[1:-1]
            g_t = feat_t[1:-1]
            loss_group = criterion_kd(g_s, g_t)
            loss_kd = sum(loss_group)
        elif opt.distill == 'correlation':
            f_s = module_list[1](feat_s[-1])
            f_t = module_list[2](feat_t[-1])
            loss_kd = criterion_kd(f_s, f_t)
        elif opt.distill == 'vid':
            g_s = feat_s[1:-1]
            g_t = feat_t[1:-1]
            loss_group = [c(f_s, f_t) for f_s, f_t, c in zip(g_s, g_t, criterion_kd)]
            loss_kd = sum(loss_group)
        elif opt.distill == 'abound':
            # can also add loss to this stage
            loss_kd = 0
        elif opt.distill == 'fsp':
            # can also add loss to this stage
            loss_kd = 0
        elif opt.distill == 'factor':
            factor_s = module_list[1](feat_s[-2])
            factor_t = module_list[2](feat_t[-2], is_factor=True)
            loss_kd = criterion_kd(factor_s, factor_t)
        else:
            raise NotImplementedError(opt.distill)

        loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd

        acc1, acc5 = accuracy(logit_s, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1[0], input.size(0))
        top5.update(acc5[0], input.size(0))

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if idx % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, idx, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
            sys.stdout.flush()

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, losses.avg


def validate(val_loader, model, criterion, opt):
    """validation"""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for idx, (input, target) in enumerate(val_loader):

            input = input.float()
            if torch.cuda.is_available():
                input = input.cuda()
                target = target.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1[0], input.size(0))
            top5.update(acc5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       idx, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, top5.avg, losses.avg
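The three loops above lean on AverageMeter and accuracy imported from .util. A plausible minimal reconstruction, inferred purely from how they are used here (meter.val / meter.avg / update(val, n) / reset(), and acc1[0] indexing into a returned list), follows; the repository's actual helpers may differ in detail:

import torch

class AverageMeter(object):
    """Tracks the latest value (.val) and the running average (.avg)."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent; returns one 1-element tensor per requested k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # rank the maxk highest-scoring classes per sample
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res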
@@ -0,0 +1,94 @@
from __future__ import print_function, division

import time
import sys
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
from .util import AverageMeter


def init(model_s, model_t, init_modules, criterion, train_loader, logger, opt):
    model_t.eval()
    model_s.eval()
    init_modules.train()

    if torch.cuda.is_available():
        model_s.cuda()
        model_t.cuda()
        init_modules.cuda()
        cudnn.benchmark = True

    if opt.model_s in ['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
                       'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2'] and \
            opt.distill == 'factor':
        lr = 0.01
    else:
        lr = opt.learning_rate
    optimizer = optim.SGD(init_modules.parameters(),
                          lr=lr,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    for epoch in range(1, opt.init_epochs + 1):
        batch_time.reset()
        data_time.reset()
        losses.reset()
        end = time.time()
        for idx, data in enumerate(train_loader):
            if opt.distill in ['contrast', 'infonce']:
                input, target, index, contrast_idx = data
            else:
                input, target, index = data
            data_time.update(time.time() - end)

            input = input.float()
            if torch.cuda.is_available():
                input = input.cuda()
                target = target.cuda()
                index = index.cuda()
                if opt.distill in ['contrast', 'infonce']:
                    contrast_idx = contrast_idx.cuda()

            # ============= forward ==============
            preact = (opt.distill == 'abound')
            feat_s, _ = model_s(input, is_feat=True, preact=preact)
            with torch.no_grad():
                feat_t, _ = model_t(input, is_feat=True, preact=preact)
                feat_t = [f.detach() for f in feat_t]

            if opt.distill == 'abound':
                g_s = init_modules[0](feat_s[1:-1])
                g_t = feat_t[1:-1]
                loss_group = criterion(g_s, g_t)
                loss = sum(loss_group)
            elif opt.distill == 'factor':
                f_t = feat_t[-2]
                _, f_t_rec = init_modules[0](f_t)
                loss = criterion(f_t_rec, f_t)
            elif opt.distill == 'fsp':
                loss_group = criterion(feat_s[:-1], feat_t[:-1])
                loss = sum(loss_group)
            else:
                raise NotImplementedError('Not supported in init training: {}'.format(opt.distill))

            losses.update(loss.item(), input.size(0))

            # ===================backward=====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

        # end of epoch
        logger.log_value('init_train_loss', losses.avg, epoch)
        print('Epoch: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'losses: {losses.val:.3f} ({losses.avg:.3f})'.format(
               epoch, opt.init_epochs, batch_time=batch_time, losses=losses))
        sys.stdout.flush()
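Reading train_distill and init together: module_list is ordered [student, auxiliary module(s), ..., teacher] and criterion_list is [classification, KL divergence, distillation-specific]. The wiring sketch below shows one plausible assembly for opt.distill == 'hint' (FitNet-style hints). The 1x1-conv regressor, its channel sizes, and the MSE hint criterion are illustrative assumptions, and model_s, model_t, and opt are presumed to exist in the caller:

import torch
import torch.nn as nn

# Hypothetical wiring for opt.distill == 'hint'; ordering inferred from train_distill above.
regressor = nn.Conv2d(32, 64, kernel_size=1)  # assumed shape adapter: student -> teacher channels

module_list = nn.ModuleList([model_s, regressor, model_t])  # teacher last; train_distill sets it to eval()
criterion_list = nn.ModuleList([
    nn.CrossEntropyLoss(),    # criterion_cls
    DistillKLSketch(T=4.0),   # criterion_div (sketched earlier in this page)
    nn.MSELoss(),             # criterion_kd regressing f_s onto f_t at opt.hint_layer
])

# Only student + regressor parameters are optimized; the frozen teacher is excluded.
optimizer = torch.optim.SGD(list(model_s.parameters()) + list(regressor.parameters()),
                            lr=opt.learning_rate,
                            momentum=opt.momentum,
                            weight_decay=opt.weight_decay)

# Inside train_distill the three terms are then combined as:
#   loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd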