
Commit

remove apex
triomino committed Jul 14, 2020
1 parent dc34553 commit ed14d11
Showing 2 changed files with 11 additions and 6 deletions.
9 changes: 6 additions & 3 deletions train_student.py
@@ -16,7 +16,7 @@
 import torch.backends.cudnn as cudnn
 import tensorboard_logger as tb_logger
 
-import apex
+# import apex
 
 from models import model_dict
 from models.util import Embed, ConvReg, LinearEmbed, SelfA
@@ -303,8 +303,11 @@ def main_worker(gpu, ngpus_per_node, opt):
             module_list.cuda(opt.gpu)
             distributed_modules = []
             for module in module_list:
-                DDP = torch.nn.parallel.DistributedDataParallel if opt.dali is None else apex.parallel.DistributedDataParallel
-                distributed_modules.append(DDP(module, delay_allreduce=True))
+                # TODO: test whether apex is faster
+                # DDP = torch.nn.parallel.DistributedDataParallel if opt.dali is None else apex.parallel.DistributedDataParallel
+                # distributed_modules.append(DDP(module, delay_allreduce=True))
+                DDP = torch.nn.parallel.DistributedDataParallel
+                distributed_modules.append(DDP(module, device_ids=[opt.gpu]))
             module_list = distributed_modules
             criterion_list.cuda(opt.gpu)
         else:
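Note: the replacement above can also be written with a guarded fallback, so the script runs whether or not apex is installed. This is a minimal illustrative sketch, not code from the commit; the helper name wrap_distributed is hypothetical, and it assumes torch.distributed.init_process_group has already been called by the caller.

import torch

# apex is optional; fall back to native DDP when it is absent
try:
    from apex.parallel import DistributedDataParallel as ApexDDP
    HAS_APEX = True
except ImportError:
    HAS_APEX = False

def wrap_distributed(module_list, gpu, use_apex=False):
    """Wrap each module for one-process-per-GPU distributed training."""
    wrapped = []
    for module in module_list:
        module.cuda(gpu)
        if use_apex and HAS_APEX:
            # apex's wrapper defers the gradient all-reduce until backward() finishes
            wrapped.append(ApexDDP(module, delay_allreduce=True))
        else:
            # native DDP only needs the local device id
            wrapped.append(torch.nn.parallel.DistributedDataParallel(
                module, device_ids=[gpu]))
    return wrapped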
8 changes: 5 additions & 3 deletions train_teacher.py
@@ -12,7 +12,7 @@
 import torch.backends.cudnn as cudnn
 import tensorboard_logger as tb_logger
 
-import apex
+# import apex
 
 from models import model_dict
 
@@ -162,8 +162,10 @@ def main_worker(gpu, ngpus_per_node, opt):
             # ourselves based on the total number of GPUs we have
             opt.batch_size = int(opt.batch_size / ngpus_per_node)
             opt.num_workers = int((opt.num_workers + ngpus_per_node - 1) / ngpus_per_node)
-            DDP = torch.nn.parallel.DistributedDataParallel if opt.dali is None else apex.parallel.DistributedDataParallel
-            model = DDP(model, delay_allreduce=True)
+            # DDP = torch.nn.parallel.DistributedDataParallel if opt.dali is None else apex.parallel.DistributedDataParallel
+            # model = DDP(model, delay_allreduce=True)
+            DDP = torch.nn.parallel.DistributedDataParallel
+            model = DDP(model, device_ids=[opt.gpu])
         else:
             print('multiprocessing_distributed must be with a specified gpu id')
     else:
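Note: the context lines above split the global batch size and loader workers evenly across the per-GPU worker processes. A self-contained sketch of that single-node setup follows; the helper name setup_worker is hypothetical, and init_method='env://' assumes MASTER_ADDR and MASTER_PORT are set in the environment.

import torch
import torch.distributed as dist

def setup_worker(gpu, ngpus_per_node, batch_size=64, num_workers=8):
    # One process per GPU on a single node: the local gpu id doubles as
    # the global rank, and world_size equals the number of GPUs.
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=ngpus_per_node, rank=gpu)
    torch.cuda.set_device(gpu)
    # Divide the global batch evenly; round the loader workers up so
    # every process gets at least its share.
    per_gpu_batch = batch_size // ngpus_per_node
    per_gpu_workers = (num_workers + ngpus_per_node - 1) // ngpus_per_node
    return per_gpu_batch, per_gpu_workers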
