Skip to content

Commit

Permalink
replace 'contrast
Browse files Browse the repository at this point in the history
  • Loading branch information
HobbitLong committed Oct 22, 2019
1 parent 1143dfe commit c513466
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions helper/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o

end = time.time()
for idx, data in enumerate(train_loader):
if opt.distill in ['contrast']:
if opt.distill in ['crd']:
input, target, index, contrast_idx = data
else:
input, target, index = data
Expand All @@ -104,7 +104,7 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o
input = input.cuda()
target = target.cuda()
index = index.cuda()
if opt.distill in ['contrast']:
if opt.distill in ['crd']:
contrast_idx = contrast_idx.cuda()

# ===================forward=====================
Expand All @@ -127,7 +127,7 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o
f_s = module_list[1](feat_s[opt.hint_layer])
f_t = feat_t[opt.hint_layer]
loss_kd = criterion_kd(f_s, f_t)
elif opt.distill == 'contrast':
elif opt.distill == 'crd':
f_s = module_list[1](feat_s[-1])
f_t = module_list[2](feat_t[-1])
loss_kd = criterion_kd(f_s, f_t, index, contrast_idx)
Expand Down
4 changes: 2 additions & 2 deletions helper/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def init(model_s, model_t, init_modules, criterion, train_loader, logger, opt):
losses.reset()
end = time.time()
for idx, data in enumerate(train_loader):
if opt.distill in ['contrast', 'infonce']:
if opt.distill in ['crd']:
input, target, index, contrast_idx = data
else:
input, target, index = data
Expand All @@ -50,7 +50,7 @@ def init(model_s, model_t, init_modules, criterion, train_loader, logger, opt):
input = input.cuda()
target = target.cuda()
index = index.cuda()
if opt.distill in ['contrast', 'infonce']:
if opt.distill in ['crd']:
contrast_idx = contrast_idx.cuda()

# ============= forward ==============
Expand Down
2 changes: 1 addition & 1 deletion train_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def main():

# dataloader
if opt.dataset == 'cifar100':
if opt.distill in ['contrast']:
if opt.distill in ['crd']:
train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(batch_size=opt.batch_size,
num_workers=opt.num_workers,
k=opt.nce_k,
Expand Down

0 comments on commit c513466

Please sign in to comment.