Source code for homura.utils.distributed

""" Helper functions to make distributed training easy
"""

import builtins
import os as python_os
import warnings
from functools import wraps
from typing import Callable, Optional

from torch import distributed
from torch.cuda import device_count

from homura.liblog import get_logger
from .environment import get_args, get_environ

logger = get_logger("homura.distributed")
original_print = builtins.print


def is_horovod_available() -> bool:
    warnings.warn("horovod is no longer supported by homura", DeprecationWarning)
    return False


def is_distributed_available() -> bool:
    return distributed.is_available()


def is_distributed() -> bool:
    """ Check if the process is distributed by checking whether the world size is larger than 1.
    """

    return get_world_size() > 1


def get_local_rank() -> int:
    """ Get the local rank of the process, i.e., the index of the process within its node.
    """

    return int(get_environ('LOCAL_RANK', 0))


def get_global_rank() -> int:
    """ Get the global rank of the process. 0 if the process is the master.
    """

    return int(get_environ('RANK', 0))


def is_master() -> bool:
    return get_global_rank() == 0


def get_num_nodes() -> int:
    """ Get the number of nodes. Note that this function assumes all nodes have the same number of processes.
    """

    if not is_distributed():
        return 1
    else:
        return get_world_size() // device_count()


def get_world_size() -> int:
    """ Get the world size, i.e., the total number of processes.
    """

    return int(python_os.environ.get("WORLD_SIZE", 1))


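# Worked example (illustrative, not part of homura): on a job launched with 2 nodes and
# 4 GPUs per node, the launcher sets WORLD_SIZE=8 and each node sees device_count() == 4,
# so get_num_nodes() returns 8 // 4 == 2. This arithmetic only holds when every node runs
# exactly one process per visible GPU.

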
def _print_if_master(self, *args, sep=' ', end='\n', file=None) -> None:
    if is_master():
        original_print(self, *args, sep=sep, end=end, file=file)


def distributed_print(self, *args, sep=' ', end='\n', file=None) -> None:
    """ Print on any node; when distributed, the message is prefixed with the global rank.
    """

    if is_distributed():
        self = f"[rank={get_global_rank()}] {self}"
    original_print(self, *args, sep=sep, end=end, file=file)


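# Usage sketch (illustrative): unlike the substituted builtin print, distributed_print
# emits on every rank, e.g. distributed_print("loss", 0.5) prints "[rank=3] loss 0.5"
# on the process whose global rank is 3.

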
def init_distributed(use_horovod: bool = False,
                     backend: Optional[str] = None,
                     init_method: Optional[str] = None,
                     disable_distributed_print: bool = False
                     ) -> None:
    """ Simple initializer for distributed training. This function substitutes the builtin print
    function with `_print_if_master` so that only the master process prints.

    :param use_horovod: Deprecated; horovod is no longer supported
    :param backend: backend of torch.distributed.init_process_group
    :param init_method: init_method of torch.distributed.init_process_group
    :param disable_distributed_print: If True, keep the original builtin print on every process
    :return: None
    """

    if not is_distributed_available():
        raise RuntimeError('Distributed training is not available on this machine')

    if use_horovod:
        raise DeprecationWarning("horovod is no longer supported by homura")

    # default values
    backend = backend or "nccl"
    init_method = init_method or "env://"

    if not is_distributed():
        raise RuntimeError(f"For distributed training, use `python -m torch.distributed.launch "
                           f"--nproc_per_node={device_count()} {get_args()}` ...")

    if not distributed.is_initialized():
        distributed.init_process_group(backend=backend, init_method=init_method)
    logger.info("Distributed initialized")

    if not disable_distributed_print:
        builtins.print = _print_if_master


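# Usage sketch (illustrative only, not part of homura): a minimal entry point for a script
# started with `python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py`
# (or torchrun), so that RANK, LOCAL_RANK and WORLD_SIZE are already set by the launcher.
# The function name and the local torch import are assumptions made for the example.
def _example_init_main() -> None:
    import torch

    init_distributed(backend="nccl", init_method="env://")
    # pin this process to its local GPU before building models or tensors
    torch.cuda.set_device(get_local_rank())
    # after init_distributed, builtins.print only emits on the master process
    print(f"initialized: world_size={get_world_size()}, nodes={get_num_nodes()}")

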
def distributed_ready_main(func: Callable = None,
                           backend: Optional[str] = None,
                           init_method: Optional[str] = None,
                           disable_distributed_print: bool = False
                           ) -> Callable:
    """ Wrap a main function to make it ready for distributed training.
    """

    init_distributed(backend=backend,
                     init_method=init_method,
                     disable_distributed_print=disable_distributed_print)

    @wraps(func)
    def inner(*args, **kwargs):
        return func(*args, **kwargs)

    return inner


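# Usage sketch (illustrative only, not part of homura): decorate the training entry point so
# that init_distributed runs before it. Because the decorator calls init_distributed at
# decoration time, the script must already have been started by the launcher when the
# decorated module is imported. The names below are assumptions made for the example.
#
#     @distributed_ready_main
#     def main():
#         print(f"running on rank {get_global_rank()} of {get_world_size()}")
#
#     if __name__ == "__main__":
#         main()

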
def if_is_master(func: Callable
                 ) -> Callable:
    """ Wrap a void function so that it runs only on the master process::

        @if_is_master
        def print_master(message):
            print(message)

    :param func: Any function
    """

    @wraps(func)
    def inner(*args, **kwargs) -> None:
        if is_master():
            return func(*args, **kwargs)

    return inner


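# Usage sketch (illustrative only, not part of homura): a helper that writes a checkpoint
# only on the master process, so multiple ranks do not write the same file concurrently.
# The function name and the local torch import are assumptions made for the example.
@if_is_master
def _example_save_checkpoint(model, path: str) -> None:
    import torch

    # executed only when get_global_rank() == 0; other ranks return None
    torch.save(model.state_dict(), path)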