From 0130273eb568788f3c9d88a42a4810468b47e750 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 23 Feb 2023 18:42:17 +0100 Subject: [PATCH] Trainer: auto default (#16847) --- .../source-pytorch/accelerators/gpu_basic.rst | 39 ++-- .../source-pytorch/accelerators/hpu_basic.rst | 40 ++-- .../source-pytorch/accelerators/ipu_basic.rst | 25 ++- .../source-pytorch/accelerators/tpu_basic.rst | 40 ++-- docs/source-pytorch/common/trainer.rst | 4 +- src/lightning/pytorch/CHANGELOG.md | 4 + .../connectors/accelerator_connector.py | 73 +++---- src/lightning/pytorch/trainer/trainer.py | 11 +- tests/tests_pytorch/accelerators/test_gpu.py | 2 +- tests/tests_pytorch/accelerators/test_tpu.py | 4 +- .../callbacks/test_early_stopping.py | 10 +- tests/tests_pytorch/callbacks/test_pruning.py | 2 +- .../callbacks/test_stochastic_weight_avg.py | 4 +- tests/tests_pytorch/conftest.py | 35 ++-- tests/tests_pytorch/models/test_amp.py | 5 +- .../strategies/test_ddp_spawn_strategy.py | 3 +- .../strategies/test_ddp_strategy.py | 5 +- .../strategies/test_deepspeed_strategy.py | 4 +- .../connectors/test_accelerator_connector.py | 198 ++++++++++++++++-- .../logging_/test_train_loop_logging.py | 2 +- tests/tests_pytorch/trainer/test_trainer.py | 18 +- 21 files changed, 336 insertions(+), 192 deletions(-) diff --git a/docs/source-pytorch/accelerators/gpu_basic.rst b/docs/source-pytorch/accelerators/gpu_basic.rst index 2d8ad71cb486c..852c898419c5a 100644 --- a/docs/source-pytorch/accelerators/gpu_basic.rst +++ b/docs/source-pytorch/accelerators/gpu_basic.rst @@ -14,30 +14,31 @@ A Graphics Processing Unit (GPU), is a specialized hardware accelerator designed ---- -Train on 1 GPU --------------- - -Make sure you're running on a machine with at least one GPU. There's no need to specify any NVIDIA flags -as Lightning will do it for you. - -.. testcode:: - :skipif: torch.cuda.device_count() < 1 - - trainer = Trainer(accelerator="gpu", devices=1) - ----------------- - - .. _multi_gpu: -Train on multiple GPUs ----------------------- +Train on GPUs +------------- -To use multiple GPUs, set the number of devices in the Trainer or the index of the GPUs. +The Trainer will run on all available GPUs by default. Make sure you're running on a machine with at least one GPU. +There's no need to specify any NVIDIA flags as Lightning will do it for you. -.. code:: +.. code-block:: python + + # run on as many GPUs as available by default + trainer = Trainer(accelerator="auto", devices="auto", strategy="auto") + # equivalent to + trainer = Trainer() - trainer = Trainer(accelerator="gpu", devices=4) + # run on one GPU + trainer = Trainer(accelerator="gpu", devices=1) + # run on multiple GPUs + trainer = Trainer(accelerator="gpu", devices=8) + # choose the number of devices automatically + trainer = Trainer(accelerator="gpu", devices="auto") + +.. note:: + Setting ``accelerator="gpu"`` will also automatically choose the "mps" device on Apple sillicon GPUs. + If you want to avoid this, you can set ``accelerator="cuda"`` instead. Choosing GPU devices ^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source-pytorch/accelerators/hpu_basic.rst b/docs/source-pytorch/accelerators/hpu_basic.rst index 2ee36fee2361d..28e4c93996f11 100644 --- a/docs/source-pytorch/accelerators/hpu_basic.rst +++ b/docs/source-pytorch/accelerators/hpu_basic.rst @@ -25,25 +25,30 @@ For more information, check out `Gaudi Architecture 1`` parameter with HPUs enables the Habana accelerator for distributed training. 
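+For example, a minimal sketch (assuming a machine with Gaudi devices and the Habana integration installed):
+
+.. code-block:: python
+
+    # all available HPUs, equivalent to the new defaults on a Gaudi host
+    trainer = Trainer(accelerator="hpu", devices="auto")
+
+    # 8 HPUs on a single node; any ``devices>1`` value runs distributed
+    trainer = Trainer(accelerator="hpu", devices=8)
+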
+It uses :class:`~pytorch_lightning.strategies.hpu_parallel.HPUParallelStrategy` internally which is based on DDP +strategy with the addition of Habana's collective communication library (HCCL) to support scale-up within a node and +scale-out across multiple nodes. ---- @@ -81,19 +86,6 @@ On Node 2: ---- -Select Gaudis automatically ---------------------------- - -Lightning can automatically detect the number of Gaudi devices to run on. This setting is enabled by default if the devices argument is missing. - -.. code-block:: python - - # equivalent - trainer = Trainer(accelerator="hpu") - trainer = Trainer(accelerator="hpu", devices="auto") - ----- - How to access HPUs ------------------ diff --git a/docs/source-pytorch/accelerators/ipu_basic.rst b/docs/source-pytorch/accelerators/ipu_basic.rst index 06cd056029bcc..e065a365afddc 100644 --- a/docs/source-pytorch/accelerators/ipu_basic.rst +++ b/docs/source-pytorch/accelerators/ipu_basic.rst @@ -24,23 +24,26 @@ See the `Graphcore Glossary None: self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1 @@ -344,12 +340,6 @@ def _check_device_config_and_set_final_flags( f" using {accelerator_name} accelerator." ) - if self._devices_flag == "auto" and self._accelerator_flag is None: - raise MisconfigurationException( - f"You passed `devices={devices}` but haven't specified" - " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu'|'mps')` for the devices mapping." - ) - def _set_accelerator_if_ipu_strategy_is_passed(self) -> None: # current logic only apply to object config # TODO this logic should apply to both str and object config @@ -357,18 +347,17 @@ def _set_accelerator_if_ipu_strategy_is_passed(self) -> None: self._accelerator_flag = "ipu" def _choose_auto_accelerator(self) -> str: - """Choose the accelerator type (str) based on availability when ``accelerator='auto'``.""" - if self._accelerator_flag == "auto": - if TPUAccelerator.is_available(): - return "tpu" - if _IPU_AVAILABLE: - return "ipu" - if HPUAccelerator.is_available(): - return "hpu" - if MPSAccelerator.is_available(): - return "mps" - if CUDAAccelerator.is_available(): - return "cuda" + """Choose the accelerator type (str) based on availability.""" + if TPUAccelerator.is_available(): + return "tpu" + if IPUAccelerator.is_available(): + return "ipu" + if HPUAccelerator.is_available(): + return "hpu" + if MPSAccelerator.is_available(): + return "mps" + if CUDAAccelerator.is_available(): + return "cuda" return "cpu" @staticmethod @@ -377,14 +366,12 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" - raise MisconfigurationException("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: if isinstance(self._accelerator_flag, Accelerator): self.accelerator: Accelerator = self._accelerator_flag else: - assert self._accelerator_flag is not None self.accelerator = AcceleratorRegistry.get(self._accelerator_flag) accelerator_cls = self.accelerator.__class__ @@ -407,7 +394,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: self._parallel_devices = accelerator_cls.get_parallel_devices(self._devices_flag) def _set_devices_flag_if_auto_passed(self) -> None: - if self._devices_flag == "auto" or self._devices_flag is None: + if self._devices_flag == "auto": self._devices_flag = self.accelerator.auto_device_count() def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: @@ -588,7 +575,7 @@ def _lazy_init_strategy(self) -> None: raise 
MisconfigurationException( f"`Trainer(strategy={self.strategy.strategy_name!r})` is not compatible with an interactive" " environment. Run your code as a script, or choose one of the compatible strategies:" - f" `Fabric(strategy=None|'dp'|'ddp_notebook')`." + f" `Fabric(strategy='dp'|'ddp_notebook')`." " In case you are spawning processes yourself, make sure to include the Trainer" " creation inside the worker function." ) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 168d5c1ad3678..7f4fb4742d660 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -96,7 +96,7 @@ def __init__( gradient_clip_val: Optional[Union[int, float]] = None, gradient_clip_algorithm: Optional[str] = None, num_nodes: int = 1, - devices: Optional[Union[List[int], str, int]] = None, + devices: Union[List[int], str, int] = "auto", enable_progress_bar: bool = True, overfit_batches: Union[int, float] = 0.0, check_val_every_n_epoch: Optional[int] = 1, @@ -113,8 +113,8 @@ def __init__( limit_predict_batches: Optional[Union[int, float]] = None, val_check_interval: Optional[Union[int, float]] = None, log_every_n_steps: int = 50, - accelerator: Optional[Union[str, Accelerator]] = None, - strategy: Optional[Union[str, Strategy]] = None, + accelerator: Union[str, Accelerator] = "auto", + strategy: Union[str, Strategy] = "auto", sync_batchnorm: bool = False, precision: _PRECISION_INPUT = "32-true", enable_model_summary: bool = True, @@ -259,9 +259,8 @@ def __init__( sampler was already added, Lightning will not replace the existing one. For iterable-style datasets, we don't do this automatically. - strategy: Supports different training strategies with aliases - as well custom strategies. - Default: ``None``. + strategy: Supports different training strategies with aliases as well custom strategies. + Default: ``"auto"``. sync_batchnorm: Synchronize batch norm layers between process groups/whole world. Default: ``False``. diff --git a/tests/tests_pytorch/accelerators/test_gpu.py b/tests/tests_pytorch/accelerators/test_gpu.py index 49d6998443954..7da631c2649fd 100644 --- a/tests/tests_pytorch/accelerators/test_gpu.py +++ b/tests/tests_pytorch/accelerators/test_gpu.py @@ -68,4 +68,4 @@ def test_gpu_availability(): @RunIf(min_cuda_gpus=1) def test_warning_if_gpus_not_used(): with pytest.warns(UserWarning, match="GPU available but not used. Set `accelerator` and `devices`"): - Trainer() + Trainer(accelerator="cpu") diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index 39ed1f6c14c8c..b1d8ef2a0d3b2 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -88,7 +88,7 @@ def test_accelerator_cpu_when_tpu_available(tpu_available): @RunIf(skip_windows=True) -@pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", None)]) +@pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", "auto")]) @mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks") def test_accelerator_tpu(_, accelerator, devices, tpu_available): assert TPUAccelerator.is_available() @@ -299,7 +299,7 @@ def __instancecheck__(self, instance): def test_warning_if_tpus_not_used(tpu_available): with pytest.warns(UserWarning, match="TPU available but not used. 
Set `accelerator` and `devices`"): - Trainer() + Trainer(accelerator="cpu") @RunIf(tpu=True, standalone=True) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 4a55aa8625970..4d8198d99e0e5 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -374,8 +374,8 @@ def on_train_end(self) -> None: @pytest.mark.parametrize( "callbacks, expected_stop_epoch, check_on_train_epoch_end, strategy, devices, dist_diverge_epoch", [ - ([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, None, 1, None), - ([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, None, 1, None), + ([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "auto", 1, None), + ([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "auto", 1, None), pytest.param( [EarlyStopping("abc", patience=1), EarlyStopping("cba")], 2, False, "ddp_spawn", 2, 2, **_SPAWN_MARK ), @@ -385,8 +385,8 @@ def on_train_end(self) -> None: pytest.param( [EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_spawn", 2, None, **_SPAWN_MARK ), - ([EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, None, 1, None), - ([EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, None, 1, None), + ([EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, "auto", 1, None), + ([EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, "auto", 1, None), pytest.param( [EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, @@ -412,7 +412,7 @@ def test_multiple_early_stopping_callbacks( callbacks: List[EarlyStopping], expected_stop_epoch: int, check_on_train_epoch_end: bool, - strategy: Optional[str], + strategy: str, devices: int, dist_diverge_epoch: Optional[int], ): diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py index 4953420c70e6d..f6aa1c7f0e3b6 100644 --- a/tests/tests_pytorch/callbacks/test_pruning.py +++ b/tests/tests_pytorch/callbacks/test_pruning.py @@ -63,7 +63,7 @@ def train_with_pruning_callback( use_global_unstructured=False, pruning_fn="l1_unstructured", use_lottery_ticket_hypothesis=False, - strategy=None, + strategy="auto", accelerator="cpu", devices=1, ): diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index 35d88587100a7..d03e5ea5c064c 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -144,7 +144,7 @@ def on_train_end(self, trainer, pl_module): def train_with_swa( tmpdir, batchnorm=True, - strategy=None, + strategy="auto", accelerator="cpu", devices=1, interval="epoch", @@ -295,7 +295,7 @@ def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False) "default_root_dir": tmpdir, "max_epochs": 5, "accelerator": "cpu", - "strategy": "ddp_spawn" if ddp else None, + "strategy": "ddp_spawn" if ddp else "auto", "devices": 2 if ddp else 1, "limit_train_batches": 5, "limit_val_batches": 0, diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index d2723e67fa348..a3851dd5ab425 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -171,24 +171,33 @@ def mps_count_4(monkeypatch): 
mock_mps_count(monkeypatch, 4) +def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None: + monkeypatch.setattr(lightning.pytorch.accelerators.tpu, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.pytorch.strategies.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.pytorch.strategies.single_tpu, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.pytorch.plugins.precision.tpu, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.pytorch.strategies.launchers.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.accelerators.tpu, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.strategies.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value) + + @pytest.fixture(scope="function") def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(lightning.pytorch.accelerators.tpu, "_XLA_AVAILABLE", True) - monkeypatch.setattr(lightning.pytorch.strategies.xla, "_XLA_AVAILABLE", True) - monkeypatch.setattr(lightning.pytorch.strategies.single_tpu, "_XLA_AVAILABLE", True) - monkeypatch.setattr(lightning.pytorch.plugins.precision.tpu, "_XLA_AVAILABLE", True) - monkeypatch.setattr(lightning.pytorch.strategies.launchers.xla, "_XLA_AVAILABLE", True) - monkeypatch.setattr(lightning.fabric.accelerators.tpu, "_XLA_AVAILABLE", True) - monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", True) - monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", True) - monkeypatch.setattr(lightning.fabric.strategies.xla, "_XLA_AVAILABLE", True) - monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", True) + mock_xla_available(monkeypatch) + + +def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None: + mock_xla_available(monkeypatch, value) + monkeypatch.setattr(lightning.pytorch.accelerators.tpu.TPUAccelerator, "is_available", lambda: value) + monkeypatch.setattr(lightning.fabric.accelerators.tpu.TPUAccelerator, "is_available", lambda: value) @pytest.fixture(scope="function") -def tpu_available(xla_available, monkeypatch) -> None: - monkeypatch.setattr(lightning.pytorch.accelerators.tpu.TPUAccelerator, "is_available", lambda: True) - monkeypatch.setattr(lightning.fabric.accelerators.tpu.TPUAccelerator, "is_available", lambda: True) +def tpu_available(monkeypatch) -> None: + mock_tpu_available(monkeypatch) @pytest.fixture diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index 618a79bb5421a..ecd1340165a01 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -82,19 +82,18 @@ def test_amp_cpus(tmpdir, strategy, precision, devices): trainer.predict(model) -@pytest.mark.parametrize("strategy", [None, "ddp_spawn"]) @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))]) @pytest.mark.parametrize( "devices", (pytest.param(1, marks=RunIf(min_cuda_gpus=1)), pytest.param(2, marks=RunIf(min_cuda_gpus=2))) ) -def test_amp_gpus(tmpdir, strategy, precision, devices): +def test_amp_gpus(tmpdir, precision, devices): """Make sure combinations of AMP and strategies work if supported.""" trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", 
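+        # the multi-GPU case uses spawn-based DDP, while a single GPU relies on the new "auto" strategy default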
devices=devices, - strategy=("ddp_spawn" if strategy is None and devices > 1 else strategy), + strategy=("ddp_spawn" if devices > 1 else "auto"), precision=precision, ) diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py index 174a75bc73a18..1b7a75fdda553 100644 --- a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py +++ b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py @@ -122,13 +122,14 @@ def test_ddp_spawn_configure_ddp(tmpdir): @mock.patch("torch.distributed.init_process_group") -def test_ddp_spawn_strategy_set_timeout(mock_init_process_group, cuda_count_2, mps_count_0): +def test_ddp_spawn_strategy_set_timeout(mock_init_process_group): """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function.""" test_timedelta = timedelta(seconds=30) model = BoringModel() ddp_spawn_strategy = DDPSpawnStrategy(timeout=test_timedelta) trainer = Trainer( max_epochs=1, + accelerator="cpu", strategy=ddp_spawn_strategy, ) # test wrap the model if fitting diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py index 9d1718d578cf6..9a352f750ed38 100644 --- a/tests/tests_pytorch/strategies/test_ddp_strategy.py +++ b/tests/tests_pytorch/strategies/test_ddp_strategy.py @@ -115,7 +115,7 @@ def creates_processes_externally(self): @RunIf(skip_windows=True) -def test_ddp_configure_ddp(cuda_count_2, mps_count_0): +def test_ddp_configure_ddp(mps_count_0): """Tests with ddp strategy.""" model = BoringModel() ddp_strategy = DDPStrategy() @@ -229,13 +229,14 @@ def node_rank(self): @mock.patch("torch.distributed.init_process_group") -def test_ddp_strategy_set_timeout(mock_init_process_group, cuda_count_2, mps_count_0): +def test_ddp_strategy_set_timeout(mock_init_process_group): """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function.""" test_timedelta = timedelta(seconds=30) model = BoringModel() ddp_strategy = DDPStrategy(timeout=test_timedelta) trainer = Trainer( max_epochs=1, + accelerator="cpu", strategy=ddp_strategy, ) # test wrap the model if fitting diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index 76f248bb5264e..90cd89de75c4b 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -498,7 +498,7 @@ def test_deepspeed_multigpu_single_file(tmpdir): """Test to ensure that DeepSpeed loads from a single file checkpoint.""" model = BoringModel() checkpoint_path = os.path.join(tmpdir, "model.pt") - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator="cpu", devices=1) trainer.fit(model) trainer.save_checkpoint(checkpoint_path) @@ -712,6 +712,8 @@ def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): fast_dev_run=True, enable_progress_bar=False, enable_model_summary=False, + accelerator="cpu", + devices=1, ) trainer.fit(model) trainer.save_checkpoint(checkpoint_path) diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index b2c88a50d9c4b..3ae92163f01dc 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -13,6 +13,7 @@ # 
limitations under the License import inspect import os +import sys from typing import Any, Dict from unittest import mock from unittest.mock import Mock @@ -30,8 +31,9 @@ TorchElasticEnvironment, XLAEnvironment, ) +from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch import Trainer -from lightning.pytorch.accelerators import TPUAccelerator +from lightning.pytorch.accelerators import HPUAccelerator, IPUAccelerator, TPUAccelerator from lightning.pytorch.accelerators.accelerator import Accelerator from lightning.pytorch.accelerators.cpu import CPUAccelerator from lightning.pytorch.accelerators.cuda import CUDAAccelerator @@ -43,7 +45,9 @@ DDPStrategy, DeepSpeedStrategy, FSDPStrategy, + IPUStrategy, SingleDeviceStrategy, + SingleHPUStrategy, SingleTPUStrategy, XLAStrategy, ) @@ -51,6 +55,7 @@ from lightning.pytorch.strategies.hpu_parallel import HPUParallelStrategy from lightning.pytorch.trainer.connectors.accelerator_connector import _set_torch_flags, AcceleratorConnector from lightning.pytorch.utilities.exceptions import MisconfigurationException +from tests_pytorch.conftest import mock_cuda_count, mock_mps_count, mock_tpu_available, mock_xla_available from tests_pytorch.helpers.runif import RunIf @@ -62,15 +67,13 @@ def test_accelerator_choice_cpu(tmpdir): @RunIf(tpu=True, standalone=True) @pytest.mark.parametrize( - ["accelerator", "devices"], [("tpu", None), ("tpu", 1), ("tpu", [1]), ("tpu", 8), ("auto", 1), ("auto", 8)] + ["accelerator", "devices"], [("tpu", "auto"), ("tpu", 1), ("tpu", [1]), ("tpu", 8), ("auto", 1), ("auto", 8)] ) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_accelerator_choice_tpu(accelerator, devices): connector = AcceleratorConnector(accelerator=accelerator, devices=devices) assert isinstance(connector.accelerator, TPUAccelerator) - if devices is None or (isinstance(devices, int) and devices > 1): - # accelerator=tpu, devices=None (default) maps to devices=auto (8) and then chooses XLAStrategy - # This behavior may change in the future: https://github.com/Lightning-AI/lightning/issues/10606 + if devices == "auto" or (isinstance(devices, int) and devices > 1): assert isinstance(connector.strategy, XLAStrategy) assert isinstance(connector.strategy.cluster_environment, XLAEnvironment) assert isinstance(connector.cluster_environment, XLAEnvironment) @@ -338,7 +341,7 @@ def test_set_devices_if_none_cpu(): pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)), ), ) -@pytest.mark.parametrize("accelerator", ["mps", "auto", "gpu", None, MPSAccelerator()]) +@pytest.mark.parametrize("accelerator", ["mps", "auto", "gpu", MPSAccelerator()]) def test_invalid_ddp_strategy_with_mps(accelerator, strategy, strategy_class, mps_count_1, cuda_count_0): with pytest.raises(ValueError, match="strategies from the DDP family are not supported"): Trainer(accelerator=accelerator, strategy=strategy) @@ -440,7 +443,7 @@ def test_strategy_choice_ddp_cuda(strategy, expected_cls, mps_count_0, cuda_coun @pytest.mark.parametrize("job_name,expected_env", [("some_name", SLURMEnvironment), ("bash", LightningEnvironment)]) -@pytest.mark.parametrize("strategy", [None, "ddp", DDPStrategy]) +@pytest.mark.parametrize("strategy", ["auto", "ddp", DDPStrategy]) def test_strategy_choice_ddp_slurm(cuda_count_2, strategy, job_name, expected_env): if strategy and not isinstance(strategy, str): strategy = strategy() @@ -541,7 +544,7 @@ def test_strategy_choice_ddp_cpu_kubeflow(cuda_count_0): }, ) 
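+# `setup_distributed` is mocked below so choosing the DDP strategy never initializes a real process group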
@mock.patch("lightning.pytorch.strategies.DDPStrategy.setup_distributed", autospec=True) -@pytest.mark.parametrize("strategy", [None, "ddp", DDPStrategy()]) +@pytest.mark.parametrize("strategy", ["auto", "ddp", DDPStrategy()]) def test_strategy_choice_ddp_cpu_slurm(cuda_count_0, strategy): trainer = Trainer(fast_dev_run=True, strategy=strategy, accelerator="cpu", devices=2) assert isinstance(trainer.accelerator, CPUAccelerator) @@ -572,23 +575,31 @@ def test_unsupported_tpu_choice(tpu_available): Trainer(accelerator="tpu", precision="16-mixed", strategy="ddp") -@mock.patch("lightning.pytorch.accelerators.ipu.IPUAccelerator.is_available", return_value=True) -def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch): - import lightning.pytorch.accelerators.ipu as ipu_ - import lightning.pytorch.strategies.ipu as ipu +def mock_ipu_available(monkeypatch, value=True): + monkeypatch.setattr(lightning.pytorch.accelerators.ipu, "_IPU_AVAILABLE", value) + monkeypatch.setattr(lightning.pytorch.strategies.ipu, "_IPU_AVAILABLE", value) - monkeypatch.setattr(ipu_, "_IPU_AVAILABLE", True) - monkeypatch.setattr(ipu, "_IPU_AVAILABLE", True) + +def test_unsupported_ipu_choice(monkeypatch): + mock_ipu_available(monkeypatch) with pytest.raises(ValueError, match=r"accelerator='ipu', precision='bf16-mixed'\)` is not supported"): Trainer(accelerator="ipu", precision="bf16-mixed") with pytest.raises(ValueError, match=r"accelerator='ipu', precision='64-true'\)` is not supported"): Trainer(accelerator="ipu", precision="64-true") -@mock.patch("lightning.pytorch.accelerators.tpu._XLA_AVAILABLE", return_value=False) -@mock.patch("lightning.pytorch.accelerators.ipu._IPU_AVAILABLE", return_value=False) -@mock.patch("lightning.pytorch.accelerators.hpu._HPU_AVAILABLE", return_value=False) -def test_devices_auto_choice_cpu(cuda_count_0, *_): +def mock_hpu_available(monkeypatch, value=True): + monkeypatch.setattr(lightning.pytorch.accelerators.hpu, "_HPU_AVAILABLE", value) + monkeypatch.setattr(lightning.pytorch.accelerators.hpu.HPUAccelerator, "is_available", lambda: value) + monkeypatch.setattr(lightning.pytorch.strategies.hpu_parallel, "_HPU_AVAILABLE", value) + monkeypatch.setattr(lightning.pytorch.strategies.single_hpu, "_HPU_AVAILABLE", value) + monkeypatch.setattr(lightning.pytorch.plugins.precision.hpu, "_HPU_AVAILABLE", value) + + +def test_devices_auto_choice_cpu(monkeypatch, cuda_count_0): + mock_ipu_available(monkeypatch, False) + mock_hpu_available(monkeypatch, False) + mock_xla_available(monkeypatch, False) trainer = Trainer(accelerator="auto", devices="auto") assert trainer.num_devices == 1 @@ -761,10 +772,8 @@ def test_gpu_accelerator_misconfiguration_exception(*_): Trainer(accelerator="gpu") -@mock.patch("lightning.pytorch.accelerators.hpu.HPUAccelerator.is_available", return_value=True) -@mock.patch("lightning.pytorch.strategies.hpu_parallel._HPU_AVAILABLE", return_value=True) -@mock.patch("lightning.pytorch.plugins.precision.hpu._HPU_AVAILABLE", return_value=True) -def test_accelerator_specific_checkpoint_io(*_): +def test_accelerator_specific_checkpoint_io(monkeypatch): + mock_hpu_available(monkeypatch) ckpt_plugin = TorchCheckpointIO() trainer = Trainer(accelerator="hpu", strategy=HPUParallelStrategy(), plugins=[ckpt_plugin]) assert trainer.strategy.checkpoint_io is ckpt_plugin @@ -815,3 +824,148 @@ def test_colossalai_external_strategy(monkeypatch): trainer = Trainer(strategy="colossalai", precision="16-mixed") assert isinstance(trainer.strategy, ColossalAIStrategy) + + +def 
test_connector_auto_selection(monkeypatch): + import lightning.fabric # avoid breakage with standalone package + + def _mock_tpu_available(value): + mock_tpu_available(monkeypatch, value) + monkeypatch.setitem(sys.modules, "torch_xla", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock()) + monkeypatch.setattr(lightning.fabric.plugins.environments.XLAEnvironment, "node_rank", lambda *_: 0) + + # CPU + with monkeypatch.context(): + mock_cuda_count(monkeypatch, 0) + mock_mps_count(monkeypatch, 0) + mock_tpu_available(monkeypatch, False) + mock_ipu_available(monkeypatch, False) + mock_hpu_available(monkeypatch, False) + connector = AcceleratorConnector() + assert isinstance(connector.accelerator, CPUAccelerator) + assert isinstance(connector.strategy, SingleDeviceStrategy) + assert connector._devices_flag == 1 + + # single CUDA + with monkeypatch.context(): + mock_cuda_count(monkeypatch, 1) + mock_mps_count(monkeypatch, 0) + mock_tpu_available(monkeypatch, False) + mock_ipu_available(monkeypatch, False) + mock_hpu_available(monkeypatch, False) + connector = AcceleratorConnector() + assert isinstance(connector.accelerator, CUDAAccelerator) + assert isinstance(connector.strategy, SingleDeviceStrategy) + assert connector._devices_flag == [0] + + # multi CUDA + with monkeypatch.context(): + mock_cuda_count(monkeypatch, 4) + mock_mps_count(monkeypatch, 0) + mock_tpu_available(monkeypatch, False) + mock_ipu_available(monkeypatch, False) + mock_hpu_available(monkeypatch, False) + connector = AcceleratorConnector() + assert isinstance(connector.accelerator, CUDAAccelerator) + assert isinstance(connector.strategy, DDPStrategy) + assert connector._devices_flag == list(range(4)) + + # MPS (there's no distributed) + with monkeypatch.context(): + mock_cuda_count(monkeypatch, 0) + mock_mps_count(monkeypatch, 1) + mock_tpu_available(monkeypatch, False) + mock_ipu_available(monkeypatch, False) + mock_hpu_available(monkeypatch, False) + connector = AcceleratorConnector() + assert isinstance(connector.accelerator, MPSAccelerator) + assert isinstance(connector.strategy, SingleDeviceStrategy) + assert connector._devices_flag == [0] + + # single TPU + with monkeypatch.context(): + mock_cuda_count(monkeypatch, 0) + mock_mps_count(monkeypatch, 0) + mock_ipu_available(monkeypatch, False) + mock_hpu_available(monkeypatch, False) + _mock_tpu_available(True) + # TPUAccelerator.auto_device_count always returns 8, but in case this changes in the future... 
+ monkeypatch.setattr(lightning.pytorch.accelerators.TPUAccelerator, "auto_device_count", lambda *_: 1) + monkeypatch.setattr(torch, "device", Mock()) + connector = AcceleratorConnector() + assert isinstance(connector.accelerator, TPUAccelerator) + assert isinstance(connector.strategy, SingleTPUStrategy) + assert connector._devices_flag == 1 + + monkeypatch.undo() # for some reason `.context()` is not working properly + assert lightning.fabric.accelerators.TPUAccelerator.auto_device_count() == 8 + + # Multi TPU + with monkeypatch.context(): + if _IS_WINDOWS: + # simulate fork support on windows + monkeypatch.setattr(torch.multiprocessing, "get_all_start_methods", lambda: ["fork", "spawn"]) + mock_cuda_count(monkeypatch, 0) + mock_mps_count(monkeypatch, 0) + _mock_tpu_available(True) + mock_ipu_available(monkeypatch, False) + mock_hpu_available(monkeypatch, False) + connector = AcceleratorConnector() + assert isinstance(connector.accelerator, TPUAccelerator) + assert isinstance(connector.strategy, XLAStrategy) + assert connector._devices_flag == 8 + + # Single/Multi IPU: strategy is the same + with monkeypatch.context(): + mock_cuda_count(monkeypatch, 0) + mock_mps_count(monkeypatch, 0) + mock_tpu_available(monkeypatch, False) + mock_ipu_available(monkeypatch, True) + mock_hpu_available(monkeypatch, False) + connector = AcceleratorConnector() + assert isinstance(connector.accelerator, IPUAccelerator) + assert isinstance(connector.strategy, IPUStrategy) + assert connector._devices_flag == 4 + + # Single HPU + with monkeypatch.context(): + mock_cuda_count(monkeypatch, 0) + mock_mps_count(monkeypatch, 0) + mock_tpu_available(monkeypatch, False) + mock_ipu_available(monkeypatch, False) + mock_hpu_available(monkeypatch, True) + monkeypatch.setattr(lightning.pytorch.accelerators.hpu.HPUAccelerator, "auto_device_count", lambda *_: 1) + connector = AcceleratorConnector() + assert isinstance(connector.accelerator, HPUAccelerator) + assert isinstance(connector.strategy, SingleHPUStrategy) + assert connector._devices_flag == 1 + + monkeypatch.undo() # for some reason `.context()` is not working properly + + # Multi HPU + with monkeypatch.context(): + mock_cuda_count(monkeypatch, 0) + mock_mps_count(monkeypatch, 0) + mock_tpu_available(monkeypatch, False) + mock_ipu_available(monkeypatch, False) + mock_hpu_available(monkeypatch, True) + connector = AcceleratorConnector() + assert isinstance(connector.accelerator, HPUAccelerator) + assert isinstance(connector.strategy, HPUParallelStrategy) + assert connector._devices_flag == 8 + + # TPU and CUDA: prefers TPU + with monkeypatch.context(): + if _IS_WINDOWS: + # simulate fork support on windows + monkeypatch.setattr(torch.multiprocessing, "get_all_start_methods", lambda: ["fork", "spawn"]) + mock_cuda_count(monkeypatch, 2) + mock_mps_count(monkeypatch, 0) + _mock_tpu_available(True) + mock_ipu_available(monkeypatch, False) + mock_hpu_available(monkeypatch, False) + connector = AcceleratorConnector() + assert isinstance(connector.accelerator, TPUAccelerator) + assert isinstance(connector.strategy, XLAStrategy) + assert connector._devices_flag == 8 diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index 4d311f792de1a..dd3756e814057 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -363,7 +363,7 @@ def test_logging_sync_dist_true(tmpdir, devices, accelerator): 
limit_train_batches=3, limit_val_batches=3, enable_model_summary=False, - strategy="ddp_spawn" if use_multiple_devices else None, + strategy="ddp_spawn" if use_multiple_devices else "auto", accelerator=accelerator, devices=devices, ) diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index bf80f15b44c68..3bb34392abc5c 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -1292,9 +1292,9 @@ def on_predict_epoch_end(self, trainer, pl_module): def predict( tmpdir, - strategy=None, - accelerator=None, - devices=None, + strategy="auto", + accelerator="auto", + devices="auto", model=None, plugins=None, datamodule=True, @@ -1629,7 +1629,9 @@ def on_predict_start(self) -> None: assert not self.training -@pytest.mark.parametrize("strategy,devices", [(None, 1), pytest.param("ddp_spawn", 1, marks=RunIf(skip_windows=True))]) +@pytest.mark.parametrize( + "strategy,devices", [("auto", 1), pytest.param("ddp_spawn", 1, marks=RunIf(skip_windows=True))] +) def test_model_in_correct_mode_during_stages(tmpdir, strategy, devices): model = TrainerStagesModel() trainer = Trainer(default_root_dir=tmpdir, strategy=strategy, accelerator="cpu", devices=devices, fast_dev_run=True) @@ -1776,7 +1778,7 @@ def on_exception(self, *_): self.exceptions += 1 -@pytest.mark.parametrize("strategy", [None, pytest.param("ddp_spawn", marks=RunIf(skip_windows=True, mps=False))]) +@pytest.mark.parametrize("strategy", ["auto", pytest.param("ddp_spawn", marks=RunIf(skip_windows=True, mps=False))]) def test_error_handling_all_stages(tmpdir, strategy): model = TrainerStagesErrorsModel() counter = ExceptionCounter() @@ -1871,13 +1873,13 @@ def training_step(self, batch, batch_idx): @pytest.mark.parametrize( ["trainer_kwargs", "strategy_cls", "strategy_name", "accelerator_cls", "devices"], [ - ({"strategy": None}, SingleDeviceStrategy, "single_device", CPUAccelerator, 1), + ({"strategy": "auto"}, SingleDeviceStrategy, "single_device", CPUAccelerator, 1), pytest.param({"strategy": "ddp"}, DDPStrategy, "ddp", CPUAccelerator, 1, marks=RunIf(mps=False)), pytest.param( {"strategy": "ddp", "num_nodes": 2}, DDPStrategy, "ddp", CPUAccelerator, 1, marks=RunIf(mps=False) ), ( - {"strategy": None, "accelerator": "cuda", "devices": 1}, + {"strategy": "auto", "accelerator": "cuda", "devices": 1}, SingleDeviceStrategy, "single_device", CUDAAccelerator, @@ -1891,7 +1893,7 @@ def training_step(self, batch, batch_idx): CUDAAccelerator, 1, ), - ({"strategy": None, "accelerator": "cuda", "devices": 2}, DDPStrategy, "ddp", CUDAAccelerator, 2), + ({"strategy": "auto", "accelerator": "cuda", "devices": 2}, DDPStrategy, "ddp", CUDAAccelerator, 2), ({"strategy": "ddp", "accelerator": "cuda", "devices": 2}, DDPStrategy, "ddp", CUDAAccelerator, 2), ({"strategy": "ddp", "accelerator": "cpu", "devices": 2}, DDPStrategy, "ddp", CPUAccelerator, 2), (