diff --git a/helper/loops.py b/helper/loops.py index a590f0d..3dc5ac4 100644 --- a/helper/loops.py +++ b/helper/loops.py @@ -93,7 +93,7 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o end = time.time() for idx, data in enumerate(train_loader): - if opt.distill in ['contrast', 'infonce']: + if opt.distill in ['contrast']: input, target, index, contrast_idx = data else: input, target, index = data @@ -104,7 +104,7 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o input = input.cuda() target = target.cuda() index = index.cuda() - if opt.distill in ['contrast', 'infonce']: + if opt.distill in ['contrast']: contrast_idx = contrast_idx.cuda() # ===================forward===================== @@ -131,14 +131,6 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o f_s = module_list[1](feat_s[-1]) f_t = module_list[2](feat_t[-1]) loss_kd = criterion_kd(f_s, f_t, index, contrast_idx) - elif opt.distill == 'infonce': - f_s = module_list[1](feat_s[-1]) - f_t = module_list[2](feat_t[-1]) - loss_kd = criterion_kd(f_s, f_t, index, contrast_idx) - elif opt.distill == 'softmax': - f_s = module_list[1](feat_s[-1]) - f_t = module_list[2](feat_t[-1]) - loss_kd = criterion_kd(f_s, f_t) elif opt.distill == 'attention': g_s = feat_s[1:-1] g_t = feat_t[1:-1] diff --git a/models/__init__.py b/models/__init__.py index e69de29..91f87a5 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -0,0 +1,33 @@ +from .resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4 +from .resnetv2 import ResNet50 +from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2 +from .vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn +from .mobilenetv2 import mobile_half, mobile_full +from .ShuffleNetv1 import ShuffleV1 +from .ShuffleNetv2 import ShuffleV2 + +model_dict = { + 'resnet8': resnet8, + 'resnet14': resnet14, + 'resnet20': resnet20, + 'resnet32': resnet32, + 'resnet44': resnet44, + 'resnet56': resnet56, + 'resnet110': resnet110, + 'resnet8x4': resnet8x4, + 'resnet32x4': resnet32x4, + 'ResNet50': ResNet50, + 'wrn_16_1': wrn_16_1, + 'wrn_16_2': wrn_16_2, + 'wrn_40_1': wrn_40_1, + 'wrn_40_2': wrn_40_2, + 'vgg8': vgg8_bn, + 'vgg11': vgg11_bn, + 'vgg13': vgg13_bn, + 'vgg16': vgg16_bn, + 'vgg19': vgg19_bn, + 'mobile_half': mobile_half, + 'mobile_full': mobile_full, + 'ShuffleV1': ShuffleV1, + 'ShuffleV2': ShuffleV2, +} diff --git a/train_student.py b/train_student.py index 1d384cd..f736786 100644 --- a/train_student.py +++ b/train_student.py @@ -8,7 +8,6 @@ import argparse import socket import time -import sys import tensorboard_logger as tb_logger import torch @@ -16,13 +15,8 @@ import torch.nn as nn import torch.backends.cudnn as cudnn -from models.resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4 -from models.resnetv2 import ResNet50 -from models.wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2 -from models.vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn -from models.mobilenetv2 import mobile_half, mobile_full -from models.ShuffleNetv1 import ShuffleV1 -from models.ShuffleNetv2 import ShuffleV2 + +from models import model_dict from models.util import Embed, ConvReg, LinearEmbed from models.util import Connector, Translator, Paraphraser @@ -37,15 +31,6 @@ from helper.loops import train_distill as train, validate from helper.pretrain import init -model_dict = { - 'resnet8': resnet8, 'resnet14': resnet14, 'resnet20': resnet20, 'resnet32': resnet32, 'resnet44': resnet44, - 'resnet56': resnet56, 'resnet110': resnet110, 'resnet8x4': resnet8x4, 'resnet32x4': resnet32x4, - 'wrn_16_1': wrn_16_1, 'wrn_16_2': wrn_16_2, 'wrn_40_1': wrn_40_1, 'wrn_40_2': wrn_40_2, - 'vgg19': vgg19_bn, 'vgg16': vgg16_bn, 'vgg13': vgg13_bn, 'vgg11': vgg11_bn, 'vgg8': vgg8_bn, - 'ResNet50': ResNet50, 'mobile_half': mobile_half, 'mobile_full': mobile_full, - 'ShuffleV1': ShuffleV1, 'ShuffleV2': ShuffleV2, -} - def parse_option(): diff --git a/train_teacher.py b/train_teacher.py index 5846577..40b284b 100644 --- a/train_teacher.py +++ b/train_teacher.py @@ -4,7 +4,6 @@ import argparse import socket import time -import sys import tensorboard_logger as tb_logger import torch @@ -12,12 +11,7 @@ import torch.nn as nn import torch.backends.cudnn as cudnn -from models.resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4 -from models.wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2 -from models.vgg import vgg8_bn, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn -from models.mobilenetv2 import mobile_half, mobile_full -from models.ShuffleNetv1 import ShuffleV1 -from models.ShuffleNetv2 import ShuffleV2 +from models import model_dict from dataset.cifar100 import get_cifar100_dataloaders @@ -97,30 +91,6 @@ def main(): raise NotImplementedError(opt.dataset) # model - model_dict = { - 'resnet8': resnet8, - 'resnet14': resnet14, - 'resnet20': resnet20, - 'resnet32': resnet32, - 'resnet44': resnet44, - 'resnet56': resnet56, - 'resnet110': resnet110, - 'resnet8x4': resnet8x4, - 'resnet32x4': resnet32x4, - 'wrn_16_1': wrn_16_1, - 'wrn_16_2': wrn_16_2, - 'wrn_40_1': wrn_40_1, - 'wrn_40_2': wrn_40_2, - 'vgg8': vgg8_bn, - 'vgg11': vgg11_bn, - 'vgg13': vgg13_bn, - 'vgg16': vgg16_bn, - 'vgg19': vgg19_bn, - 'mobile_half': mobile_half, - 'mobile_full': mobile_full, - 'ShuffleV1': ShuffleV1, - 'ShuffleV2': ShuffleV2, - } model = model_dict[opt.model](num_classes=n_cls) # optimizer