Skip to content

Commit

Permalink
modify
Browse files Browse the repository at this point in the history
  • Loading branch information
DefangChen committed Dec 29, 2020
1 parent e7c30c3 commit 854717f
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions train_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from helper.util import adjust_learning_rate, save_dict_to_json, reduce_tensor

from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss, SemCKDLoss, IRGLoss
from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, PKT, Correlation, VIDLoss, RKDLoss, SemCKDLoss, IRGLoss
from crd.criterion import CRDLoss

from helper.loops import train_distill as train, validate
Expand All @@ -41,7 +41,7 @@ def parse_option():

parser = argparse.ArgumentParser('argument for training')

# baisc
# basic
parser.add_argument('--print-freq', type=int, default=200, help='print frequency')
parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
Expand Down Expand Up @@ -165,13 +165,6 @@ def load_teacher(model_path, n_cls, gpu=None):
print('==> loading teacher model')
model_t = get_teacher_name(model_path)
model = model_dict[model_t](num_classes=n_cls)
# pre-trained model download from AAAI-2020
#if n_cls == 1000:
#model = nn.DataParallel(model)
#model.load_state_dict(torch.load(model_path)['state_dict'])
#model = model.module
# pre-trained model saved from train_teacher.py
#else:
# TODO: reduce size of the teacher saved in train_teacher.py
map_location = None if gpu is None else {'cuda:0': 'cuda:%d' % gpu}
model.load_state_dict(torch.load(model_path, map_location=map_location)['model'])
Expand Down

0 comments on commit 854717f

Please sign in to comment.