Commit
prepare for FSDP
moskomule committed Dec 6, 2021
1 parent 49158a6 commit 71d661b
Showing 1 changed file with 6 additions and 2 deletions.
homura/trainers.py — 8 changes: 6 additions & 2 deletions
@@ -62,6 +62,7 @@ def __init__(self,
                  profile: bool = False,
                  dist_kwargs: Optional[dict] = None,
                  prof_kwargs: Optional[dict] = None,
+                 disable_auto_ddp: bool = False,
                  **kwargs):
 
         if kwargs.get("update_scheduler_by_epoch"):
@@ -104,7 +105,8 @@ def __init__(self,
             raise TypeError(f"Unknown type for `model`. Expected nn.Module or dict[str, Module], but got {type(model)}")
 
         if "cuda" in str(self.device):
-            self.model.to(self.device)
+            if not disable_auto_ddp:
+                self.model.to(self.device)
             torch.backends.cudnn.benchmark = not disable_cudnn_benchmark
             self._cuda_nonblocking = not disable_cuda_nonblocking
             self.logger.debug(f"cuda: True, cudnn.benchmark: {not disable_cudnn_benchmark}, "
@@ -114,7 +116,7 @@ def __init__(self,
             # usually, this is not expected
             self.logger.info(f"cuda: False (torch.cuda.is_available()={torch.cuda.is_available()})")
 
-        if is_distributed():
+        if is_distributed() and not disable_auto_ddp:
             dist_kwargs = dist_kwargs or {}
             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[rank], **dist_kwargs)
             self.logger.debug(f"model converted to DistributedDataParallel at rank={rank}")
@@ -124,6 +126,8 @@ def __init__(self,
             self.accessible_model = self.model.module
         else:
             self.accessible_model = self.model
+        if disable_auto_ddp:
+            self.logger.info("self.accessible_model need to be set manually")
 
         self.optimizer = optimizer
         self.scheduler = scheduler