distributed training for kd
triomino committed Jul 10, 2020
1 parent d5d625e · commit 23d693c
Showing 5 changed files with 170 additions and 95 deletions.
17 changes: 5 additions & 12 deletions dataset/imagenet.py
@@ -178,7 +178,7 @@ def get_dataloader_sample(dataset='imagenet', batch_size=128, num_workers=8, is_
     return train_loader, test_loader, len(train_set), len(train_set.classes)


-def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16, is_instance=False, use_lmdb=False, multiprocessing_distributed=False):
+def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16, use_lmdb=False, multiprocessing_distributed=False):
     """
     Data Loader for imagenet
     """
@@ -208,14 +208,10 @@ def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16,
     train_lmdb_path = os.path.join(data_folder, 'train.lmdb')
     test_lmdb_path = os.path.join(data_folder, 'val.lmdb')

-    if is_instance:
-        train_set = ImageFolderInstance(train_folder, transform=train_transform)
-        n_data = len(train_set)
+    if use_lmdb:
+        train_set = ImageFolderLMDB(train_lmdb_path, transform=train_transform)
     else:
-        if use_lmdb:
-            train_set = ImageFolderLMDB(train_lmdb_path, transform=train_transform)
-        else:
-            train_set = datasets.ImageFolder(train_folder, transform=train_transform)
+        train_set = datasets.ImageFolder(train_folder, transform=train_transform)

     if use_lmdb:
         test_set = ImageFolderLMDB(test_lmdb_path, transform=test_transform)
@@ -243,7 +239,4 @@ def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16,
                              pin_memory=True,
                              sampler=test_sampler)

-    if is_instance:
-        return train_loader, test_loader, n_data
-    else:
-        return train_loader, test_loader, train_sampler
+    return train_loader, test_loader, train_sampler
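
The train_sampler returned above is what a distributed caller needs in order to reshuffle each process's shard every epoch. Its construction sits outside the hunks shown in this commit, so the snippet below is only a minimal sketch of the usual DistributedSampler pattern, with illustrative names (make_distributed_loaders is not a function in this repository):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_distributed_loaders(train_set, test_set, batch_size, num_workers, distributed):
    # Each process sees its own shard of the training set; shuffling is done
    # by the sampler, so the DataLoader itself must not shuffle.
    train_sampler = DistributedSampler(train_set) if distributed else None
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=(train_sampler is None),
                              num_workers=num_workers,
                              pin_memory=True,
                              sampler=train_sampler)
    test_loader = DataLoader(test_set,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=num_workers,
                             pin_memory=True)
    return train_loader, test_loader, train_sampler

The usual reason to hand the sampler back to the caller is so the training script can call train_sampler.set_epoch(epoch) at the start of every epoch; without that, every process would draw the same shard order in every epoch.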
40 changes: 27 additions & 13 deletions helper/loops.py
@@ -46,12 +46,13 @@ def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt):
         # print info
         if idx % opt.print_freq == 0:
             print('Epoch: [{0}][{1}/{2}]\t'
+                  'GPU {3}\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                   'Loss {loss.avg:.4f}\t'
                   'Acc@1 {top1.avg:.3f}\t'
                   'Acc@5 {top5.avg:.3f}'.format(
-                   epoch, idx, len(train_loader), batch_time=batch_time,
+                   epoch, idx, len(train_loader), opt.gpu, batch_time=batch_time,
                    data_time=data_time, loss=losses, top1=top1, top5=top5))
             sys.stdout.flush()

@@ -77,30 +78,36 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o
     model_s = module_list[0]
     model_t = module_list[-1]

+    batch_time = AverageMeter()
+    data_time = AverageMeter()
     losses = AverageMeter()
     top1 = AverageMeter()
     top5 = AverageMeter()

+    end = time.time()
     for idx, data in enumerate(train_loader):
+        data_time.update(time.time() - end)

         if opt.distill in ['crd']:
             input, target, index, contrast_idx = data
         else:
-            input, target, index = data
+            input, target = data

+        if target.shape[0] < opt.batch_size:
+            continue
         # input = input.float()

+        if opt.gpu is not None:
+            input = input.cuda(opt.gpu, non_blocking=True)
         if torch.cuda.is_available():
-            input = input.cuda()
-            target = target.cuda()
-            index = index.cuda()
+            target = target.cuda(opt.gpu, non_blocking=True)
             if opt.distill in ['crd']:
+                index = index.cuda()
                 contrast_idx = contrast_idx.cuda()

         # ===================forward=====================
-        preact = False
-        feat_s, logit_s = model_s(input, is_feat=True, preact=preact)
+        feat_s, logit_s = model_s(input, is_feat=True)
         with torch.no_grad():
-            feat_t, logit_t = model_t(input, is_feat=True, preact=preact)
+            feat_t, logit_t = model_t(input, is_feat=True)
             feat_t = [f.detach() for f in feat_t]

         # cls + kl div
@@ -164,6 +171,8 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o
         metrics = accuracy(logit_s, target, topk=(1, 5))
         top1.update(metrics[0].item(), input.size(0))
         top5.update(metrics[1].item(), input.size(0))
+        batch_time.update(time.time() - end)
+        end = time.time()

         # ===================backward=====================
         optimizer.zero_grad()
@@ -173,13 +182,17 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, o
         # print info
         if idx % opt.print_freq == 0:
             print('Epoch: [{0}][{1}/{2}]\t'
+                  'GPU {3}\t'
+                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                   'Loss {loss.avg:.4f}\t'
                   'Acc@1 {top1.avg:.3f}\t'
                   'Acc@5 {top5.avg:.3f}'.format(
-                   epoch, idx, len(train_loader), loss=losses, top1=top1, top5=top5))
+                   epoch, idx, len(train_loader), opt.gpu, loss=losses, top1=top1, top5=top5,
+                   batch_time=batch_time, data_time=data_time))
             sys.stdout.flush()

-    return top1.avg, top5.avg, losses.avg
+    return top1.avg, top5.avg, losses.avg, data_time.avg


 def validate(val_loader, model, criterion, opt, meter_queue = None):
@@ -217,11 +230,12 @@ def validate(val_loader, model, criterion, opt, meter_queue = None):

             if idx % opt.print_freq == 0:
                 print('Test: [{0}/{1}]\t'
+                      'GPU: {2}\t'
                       'Time: {batch_time.avg:.3f}\t'
                       'Loss {loss.avg:.4f}\t'
                       'Acc@1 {top1.avg:.3f}\t'
                       'Acc@5 {top5.avg:.3f}'.format(
-                       idx, len(val_loader), batch_time=batch_time, loss=losses,
+                       idx, len(val_loader), opt.gpu, batch_time=batch_time, loss=losses,
                        top1=top1, top5=top5))

     if opt.multiprocessing_distributed:
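
These loops read opt.gpu, which only makes sense once each training process has been pinned to a single device before they are called. train_student.py is not part of the hunks shown in this commit, so the following is only a generic sketch of the per-process setup this logging assumes; main_worker, the Namespace fields, and the spawn arguments are illustrative rather than the repository's exact code.

from argparse import Namespace

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def main_worker(gpu, ngpus_per_node, opt):
    # One process per GPU: opt.gpu is the id the print statements above report.
    opt.gpu = gpu
    torch.cuda.set_device(gpu)
    dist.init_process_group(backend='nccl',
                            init_method=opt.dist_url,
                            world_size=ngpus_per_node,
                            rank=gpu)
    # Model construction, DistributedDataParallel wrapping, and the calls to
    # train_distill / validate would follow here.


if __name__ == '__main__':
    # Illustrative options; the real script parses them from the command line
    # (see stu.sh below).
    opt = Namespace(dist_url='tcp://127.0.0.1:23453', gpu=None)
    ngpus = torch.cuda.device_count()
    mp.spawn(main_worker, nprocs=ngpus, args=(ngpus, opt))

Pinning the device first and then moving batches with input.cuda(opt.gpu, non_blocking=True), as the training loop now does, pairs with the pin_memory=True data loaders so host-to-device copies can overlap with compute.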
2 changes: 1 addition & 1 deletion models/resnetv2.py
@@ -247,7 +247,7 @@ def forward(self, x, is_feat=False):
             x = torch.flatten(x, 1)
             f5 = x
             x = self.fc(x)
-            return [f0, f1, f2, f3, f4, f5], out
+            return [f0, f1, f2, f3, f4, f5], x
         else:
             return self._forward_impl(x)

4 changes: 4 additions & 0 deletions stu.sh
@@ -0,0 +1,4 @@
+python train_student.py --path-t ./save/models/ResNet34_vanilla/resnet34_transformed.pth \
+    --batch_size 256 --epochs 90 --dataset imagenet --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23453 \
+    --print-freq 100 --num_workers 16 --distill kd --model_s ResNet18 -r 1 -a 1 -b 0 --trial 0 \
+    --multiprocessing-distributed --learning_rate 0.1 --lr_decay_epochs 30,60 --weight_decay 1e-4
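
In RepDistiller-style code, --distill kd with -r 1 -a 1 -b 0 trains the student on an equally weighted sum of cross-entropy and temperature-scaled KL distillation, with no third loss term. The sketch below shows that combination; the function name and the temperature default are illustrative, and the exact attribute names used in this fork are an assumption.

import torch.nn.functional as F


def kd_total_loss(logit_s, logit_t, target, gamma=1.0, alpha=1.0, beta=0.0, T=4.0):
    """gamma / alpha / beta correspond to the -r / -a / -b flags above."""
    loss_cls = F.cross_entropy(logit_s, target)
    p_s = F.log_softmax(logit_s / T, dim=1)
    p_t = F.softmax(logit_t / T, dim=1)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures (Hinton et al.).
    loss_div = F.kl_div(p_s, p_t, reduction='batchmean') * T * T
    loss_kd = 0.0  # -b 0 disables any extra distillation loss for plain KD
    return gamma * loss_cls + alpha * loss_div + beta * loss_kd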