
Commit

deal with last batch
triomino committed Jul 17, 2020
1 parent 4db1cd4 commit 9c473ad
Showing 6 changed files with 18 additions and 28 deletions.

dataset/imagenet_dali.py: 4 changes (2 additions & 2 deletions)
@@ -122,7 +122,7 @@ def get_dali_data_loader(args):
                           shard_id=args.rank,
                           num_shards=args.world_size)
     pipe.build()
-    train_loader = DALIClassificationIterator(pipe, reader_name="Reader", fill_last_batch=False)
+    train_loader = DALIClassificationIterator(pipe, reader_name="Reader", fill_last_batch=True, last_batch_padded=False)

     pipe = HybridValPipe(batch_size=args.batch_size,
                          num_threads=args.num_workers,
@@ -133,6 +133,6 @@ def get_dali_data_loader(args):
                          shard_id=args.rank,
                          num_shards=args.world_size)
     pipe.build()
-    val_loader = DALIClassificationIterator(pipe, reader_name="Reader", fill_last_batch=False)
+    val_loader = DALIClassificationIterator(pipe, reader_name="Reader", fill_last_batch=False, last_batch_padded=False)

     return train_loader, val_loader
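
The training loader now pads its final batch while the validation loader keeps it partial. As I read DALI's (since-deprecated) flags, fill_last_batch=True repeats samples so that every rank sees full-size batches and the same number of iterations, while the validation side avoids scoring any image twice. A rough sketch of the arithmetic, using assumed ImageNet-1k sizes that are not part of this commit:

    TRAIN_SAMPLES = 1_281_167   # assumed ImageNet-1k train size
    VAL_SAMPLES = 50_000        # assumed validation size
    WORLD_SIZE = 8
    BATCH = 256

    def last_batch_size(num_samples, world_size, batch_size):
        shard = -(-num_samples // world_size)       # ceil: samples per rank
        return shard % batch_size or batch_size     # size of that rank's final batch

    # Training: the final batch falls well short of BATCH, so padding it keeps
    # every rank in lockstep with identically shaped batches.
    print(last_batch_size(TRAIN_SAMPLES, WORLD_SIZE, BATCH))   # 146

    # Validation: the short batch is kept as-is; validate() then weights its
    # averages by per-rank sample counts instead of padding.
    print(last_batch_size(VAL_SAMPLES, WORLD_SIZE, BATCH))     # 106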

distiller_zoo/KD.py: 2 changes (1 addition & 1 deletion)
@@ -13,5 +13,5 @@ def __init__(self, T):
     def forward(self, y_s, y_t):
         p_s = F.log_softmax(y_s/self.T, dim=1)
         p_t = F.softmax(y_t/self.T, dim=1)
-        loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
+        loss = nn.KLDivLoss(reduction='batchmean')(p_s, p_t) * (self.T**2)
         return loss
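
A quick self-contained check (mine, not part of the commit) that the new batchmean form matches the old sum-then-divide expression; reduction='sum' is the current spelling of the deprecated size_average=False:

    import torch
    import torch.nn.functional as F
    from torch import nn

    torch.manual_seed(0)
    T = 4.0
    y_s, y_t = torch.randn(5, 100), torch.randn(5, 100)   # batch of 5 student/teacher logits

    p_s = F.log_softmax(y_s / T, dim=1)
    p_t = F.softmax(y_t / T, dim=1)

    old = F.kl_div(p_s, p_t, reduction='sum') * (T ** 2) / y_s.shape[0]
    new = nn.KLDivLoss(reduction='batchmean')(p_s, p_t) * (T ** 2)
    assert torch.allclose(old, new)   # identical up to floating-point error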

helper/loops.py: 17 changes (12 additions & 5 deletions)
@@ -105,10 +105,6 @@ def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, opt):
         else:
             input, target = data[0]['data'], data[0]['label'].squeeze().long()

-        # TODO: how to deal with the last batch
-        # if target.shape[0] < opt.batch_size:
-        #     continue
-
         if opt.gpu is not None:
             input = input.cuda(opt.gpu, non_blocking=True)
         if torch.cuda.is_available():
@@ -257,5 +253,16 @@ def validate(val_loader, model, criterion, opt):
                       'Acc@5 {top5.avg:.3f}'.format(
                        idx, n_batch, opt.gpu, batch_time=batch_time, loss=losses,
                        top1=top1, top5=top5))
+
+
+    if opt.multiprocessing_distributed:
+        # Batch size may not be equal across multiple gpus
+        total_metrics = torch.tensor([top1.sum, top5.sum, losses.sum]).to(opt.gpu)
+        count_metrics = torch.tensor([top1.count, top5.count, losses.count]).to(opt.gpu)
+        total_metrics = reduce_tensor(total_metrics, 1)  # here world_size=1, because they should be summed up
+        count_metrics = reduce_tensor(count_metrics, 1)
+        ret = []
+        for s, n in zip(total_metrics.tolist(), count_metrics.tolist()):
+            ret.append(s / (1.0 * n))
+        return ret

     return top1.avg, top5.avg, losses.avg
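
The added block all-reduces per-rank totals and sample counts and only then divides, rather than averaging the per-rank averages. A tiny numeric sketch (mine, not from the commit) of why that matters once the last batch leaves ranks with unequal counts:

    # Two ranks, one of which received the short final batch.
    rank_correct = [900, 80]    # correctly classified samples per rank
    rank_count = [1000, 100]    # samples evaluated per rank

    naive = sum(c / n for c, n in zip(rank_correct, rank_count)) / 2
    weighted = sum(rank_correct) / sum(rank_count)

    print(naive)     # 0.85   -> over-weights the small rank
    print(weighted)  # 0.8909 -> matches scoring all 1100 samples together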

helper/util.py: 12 changes (3 additions & 9 deletions)
@@ -41,13 +41,6 @@ def update(self, val, n=1):
         self.count += n
         self.avg = self.sum / self.count

-    def merge(self, peer):
-        self.val = peer.val
-        self.sum += peer.sum
-        self.count += peer.count
-        self.avg = self.sum / self.count
-
-
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
@@ -86,10 +79,11 @@ def load_json_to_dict(json_path):
         params = json.load(f)
     return params

-def reduce_tensor(tensor, world_size = 1):
+def reduce_tensor(tensor, world_size = 1, op='avg'):
     rt = tensor.clone()
     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    rt /= world_size
+    if world_size > 1:
+        rt = torch.true_divide(rt, world_size)
     return rt

 if __name__ == '__main__':
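
One note on the division change (my reading, not stated in the commit): the tensors now passed through reduce_tensor include integer sample counts, and in-place division of an integer tensor either truncates or raises depending on the PyTorch version, while torch.true_divide always yields a floating-point result. Passing world_size=1, as validate() does, leaves the all-reduced sums untouched so they can be divided by the summed counts afterwards.

    import torch

    counts = torch.tensor([1000, 100])     # int64, like AverageMeter.count
    # counts /= 2                          # truncates or errors on integer tensors
    print(torch.true_divide(counts, 2))    # tensor([500., 50.]) -> float result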

train_student.py: 7 changes (0 additions & 7 deletions)
@@ -347,9 +347,6 @@ def main_worker(gpu, ngpus_per_node, opt):

     # validate teacher accuracy
     teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
-    teacher_acc = torch.tensor([teacher_acc]).cuda(opt.gpu, non_blocking=True)
-    reduced = reduce_tensor(teacher_acc, opt.world_size)
-    teacher_acc = reduced.item()

     if opt.dali is not None:
         val_loader.reset()
@@ -386,10 +383,6 @@ def main_worker(gpu, ngpus_per_node, opt):
             train_loader.reset()
             val_loader.reset()

-        metrics = torch.tensor([test_acc, test_acc_top5, test_loss]).cuda(opt.gpu, non_blocking=True)
-        reduced = reduce_tensor(metrics, opt.world_size)
-        test_acc, test_acc_top5, test_loss = reduced.tolist()
-
         if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
             print(' ** Acc@1 {:.3f}, Acc@5 {:.3f}'.format(test_acc, test_acc_top5))

train_teacher.py: 4 changes (0 additions & 4 deletions)
@@ -225,10 +225,6 @@ def main_worker(gpu, ngpus_per_node, opt):
             train_loader.reset()
             val_loader.reset()

-        metrics = torch.tensor([test_acc, test_acc_top5, test_loss]).cuda(opt.gpu, non_blocking=True)
-        reduced = reduce_tensor(metrics, opt.world_size)
-        test_acc, test_acc_top5, test_loss = reduced.tolist()
-
         if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
             print(' ** Acc@1 {:.3f}, Acc@5 {:.3f}'.format(test_acc, test_acc_top5))
