Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
HobbitLong committed Oct 22, 2019
1 parent b2487ab commit 1143dfe
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 58 deletions.
12 changes: 2 additions & 10 deletions helper/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o

end = time.time()
for idx, data in enumerate(train_loader):
if opt.distill in ['contrast', 'infonce']:
if opt.distill in ['contrast']:
input, target, index, contrast_idx = data
else:
input, target, index = data
Expand All @@ -104,7 +104,7 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o
input = input.cuda()
target = target.cuda()
index = index.cuda()
if opt.distill in ['contrast', 'infonce']:
if opt.distill in ['contrast']:
contrast_idx = contrast_idx.cuda()

# ===================forward=====================
Expand All @@ -131,14 +131,6 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o
f_s = module_list[1](feat_s[-1])
f_t = module_list[2](feat_t[-1])
loss_kd = criterion_kd(f_s, f_t, index, contrast_idx)
elif opt.distill == 'infonce':
f_s = module_list[1](feat_s[-1])
f_t = module_list[2](feat_t[-1])
loss_kd = criterion_kd(f_s, f_t, index, contrast_idx)
elif opt.distill == 'softmax':
f_s = module_list[1](feat_s[-1])
f_t = module_list[2](feat_t[-1])
loss_kd = criterion_kd(f_s, f_t)
elif opt.distill == 'attention':
g_s = feat_s[1:-1]
g_t = feat_t[1:-1]
Expand Down
33 changes: 33 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from .resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4
from .resnetv2 import ResNet50
from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2
from .vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn
from .mobilenetv2 import mobile_half, mobile_full
from .ShuffleNetv1 import ShuffleV1
from .ShuffleNetv2 import ShuffleV2

model_dict = {
'resnet8': resnet8,
'resnet14': resnet14,
'resnet20': resnet20,
'resnet32': resnet32,
'resnet44': resnet44,
'resnet56': resnet56,
'resnet110': resnet110,
'resnet8x4': resnet8x4,
'resnet32x4': resnet32x4,
'ResNet50': ResNet50,
'wrn_16_1': wrn_16_1,
'wrn_16_2': wrn_16_2,
'wrn_40_1': wrn_40_1,
'wrn_40_2': wrn_40_2,
'vgg8': vgg8_bn,
'vgg11': vgg11_bn,
'vgg13': vgg13_bn,
'vgg16': vgg16_bn,
'vgg19': vgg19_bn,
'mobile_half': mobile_half,
'mobile_full': mobile_full,
'ShuffleV1': ShuffleV1,
'ShuffleV2': ShuffleV2,
}
19 changes: 2 additions & 17 deletions train_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,15 @@
import argparse
import socket
import time
import sys

import tensorboard_logger as tb_logger
import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn

from models.resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4
from models.resnetv2 import ResNet50
from models.wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2
from models.vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn
from models.mobilenetv2 import mobile_half, mobile_full
from models.ShuffleNetv1 import ShuffleV1
from models.ShuffleNetv2 import ShuffleV2

from models import model_dict
from models.util import Embed, ConvReg, LinearEmbed
from models.util import Connector, Translator, Paraphraser

Expand All @@ -37,15 +31,6 @@
from helper.loops import train_distill as train, validate
from helper.pretrain import init

model_dict = {
'resnet8': resnet8, 'resnet14': resnet14, 'resnet20': resnet20, 'resnet32': resnet32, 'resnet44': resnet44,
'resnet56': resnet56, 'resnet110': resnet110, 'resnet8x4': resnet8x4, 'resnet32x4': resnet32x4,
'wrn_16_1': wrn_16_1, 'wrn_16_2': wrn_16_2, 'wrn_40_1': wrn_40_1, 'wrn_40_2': wrn_40_2,
'vgg19': vgg19_bn, 'vgg16': vgg16_bn, 'vgg13': vgg13_bn, 'vgg11': vgg11_bn, 'vgg8': vgg8_bn,
'ResNet50': ResNet50, 'mobile_half': mobile_half, 'mobile_full': mobile_full,
'ShuffleV1': ShuffleV1, 'ShuffleV2': ShuffleV2,
}


def parse_option():

Expand Down
32 changes: 1 addition & 31 deletions train_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,14 @@
import argparse
import socket
import time
import sys

import tensorboard_logger as tb_logger
import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn

from models.resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4
from models.wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2
from models.vgg import vgg8_bn, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from models.mobilenetv2 import mobile_half, mobile_full
from models.ShuffleNetv1 import ShuffleV1
from models.ShuffleNetv2 import ShuffleV2
from models import model_dict

from dataset.cifar100 import get_cifar100_dataloaders

Expand Down Expand Up @@ -97,30 +91,6 @@ def main():
raise NotImplementedError(opt.dataset)

# model
model_dict = {
'resnet8': resnet8,
'resnet14': resnet14,
'resnet20': resnet20,
'resnet32': resnet32,
'resnet44': resnet44,
'resnet56': resnet56,
'resnet110': resnet110,
'resnet8x4': resnet8x4,
'resnet32x4': resnet32x4,
'wrn_16_1': wrn_16_1,
'wrn_16_2': wrn_16_2,
'wrn_40_1': wrn_40_1,
'wrn_40_2': wrn_40_2,
'vgg8': vgg8_bn,
'vgg11': vgg11_bn,
'vgg13': vgg13_bn,
'vgg16': vgg16_bn,
'vgg19': vgg19_bn,
'mobile_half': mobile_half,
'mobile_full': mobile_full,
'ShuffleV1': ShuffleV1,
'ShuffleV2': ShuffleV2,
}
model = model_dict[opt.model](num_classes=n_cls)

# optimizer
Expand Down

0 comments on commit 1143dfe

Please sign in to comment.