From 6dad88e9355490729d86a73e3ae18dc7656d8620 Mon Sep 17 00:00:00 2001 From: "Deng, Weishi" Date: Tue, 21 Feb 2023 12:11:20 +0000 Subject: [PATCH] enable xpu for single card training on intel gpus [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci change version comparison to base version number [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci enable bf16 for xpu, enable ccl [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci switch from deprecated set_default_tensor_type to set_default_dtype switch to info to print ipex and torch-ccl version number fix set_default_dtype incorrect argument error --- src/lightning/fabric/accelerators/__init__.py | 1 + src/lightning/fabric/accelerators/xpu.py | 259 ++++++++++++++++++ src/lightning/fabric/cli.py | 6 +- src/lightning/fabric/connector.py | 6 + .../fabric/utilities/device_dtype_mixin.py | 19 ++ .../fabric/utilities/device_parser.py | 34 ++- src/lightning/fabric/utilities/distributed.py | 16 +- src/lightning/fabric/utilities/enums.py | 1 + src/lightning/fabric/utilities/imports.py | 4 +- .../pytorch/accelerators/__init__.py | 1 + src/lightning/pytorch/accelerators/xpu.py | 162 +++++++++++ .../pytorch/plugins/precision/double.py | 8 +- .../connectors/accelerator_connector.py | 24 +- src/lightning/pytorch/trainer/setup.py | 13 +- src/lightning/pytorch/utilities/imports.py | 2 +- 15 files changed, 527 insertions(+), 29 deletions(-) create mode 100644 src/lightning/fabric/accelerators/xpu.py create mode 100644 src/lightning/pytorch/accelerators/xpu.py diff --git a/src/lightning/fabric/accelerators/__init__.py b/src/lightning/fabric/accelerators/__init__.py index 4480814d207d10..29b6ab0fe83508 100644 --- a/src/lightning/fabric/accelerators/__init__.py +++ b/src/lightning/fabric/accelerators/__init__.py @@ -16,6 +16,7 @@ from lightning.fabric.accelerators.mps import MPSAccelerator # noqa: F401 from lightning.fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators from lightning.fabric.accelerators.tpu import TPUAccelerator # noqa: F401 +from lightning.fabric.accelerators.xpu import XPUAccelerator # noqa: F401 _ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators" ACCELERATOR_REGISTRY = _AcceleratorRegistry() diff --git a/src/lightning/fabric/accelerators/xpu.py b/src/lightning/fabric/accelerators/xpu.py new file mode 100644 index 00000000000000..9ebd062c1d7a57 --- /dev/null +++ b/src/lightning/fabric/accelerators/xpu.py @@ -0,0 +1,259 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import lru_cache +from typing import Dict, List, Optional, Union + +import torch +from lightning_utilities.core.rank_zero import rank_zero_info + +from lightning.fabric.accelerators.accelerator import Accelerator +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13 + +try: + import intel_extension_for_pytorch as ipex + + rank_zero_info(f"Using Intel® Extension for PyTorch* {ipex.__version__}") +except ImportError: + pass + + +class XPUAccelerator(Accelerator): + """Accelerator for Intel XPU devices.""" + + def setup_device(self, device: torch.device) -> None: + """ + Raises: + ValueError: + If the selected device is not of type XPU. + """ + if device.type != "xpu": + raise ValueError(f"Device should be XPU, got {device} instead.") + _check_xpu_math_precision(device) + torch.xpu.set_device(device) + + def teardown(self) -> None: + # clean up memory + torch.xpu.empty_cache() + + @staticmethod + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + """Accelerator device parsing logic.""" + from lightning.fabric.utilities.device_parser import _parse_gpu_ids + + return _parse_gpu_ids(devices, include_xpu=True) + + @staticmethod + def get_parallel_devices(devices: List[int]) -> List[torch.device]: + """Gets parallel devices for the Accelerator.""" + return [torch.device("xpu", i) for i in devices] + + @staticmethod + def auto_device_count() -> int: + """Get the devices when set to auto.""" + return num_xpu_devices() + + @staticmethod + def is_available() -> bool: + return num_xpu_devices() > 0 + + @classmethod + def register_accelerators(cls, accelerator_registry: Dict) -> None: + accelerator_registry.register( + "xpu", + cls, + description=cls.__class__.__name__, + ) + + +def find_usable_xpu_devices(num_devices: int = -1) -> List[int]: + """Returns a list of all available and usable XPU GPU devices. + + A GPU is considered usable if we can successfully move a tensor to the device, and this is what this function + tests for each GPU on the system until the target number of usable devices is found. + + A subset of GPUs on the system might be used by other processes, and if the GPU is configured to operate in + 'exclusive' mode (configurable by the admin), then only one process is allowed to occupy it. + + Args: + num_devices: The number of devices you want to request. By default, this function will return as many as there + are usable XPU GPU devices available. + + Warning: + If multiple processes call this function at the same time, there can be race conditions in the case where + both processes determine that the device is unoccupied, leading into one of them crashing later on. + """ + visible_devices = _get_all_visible_xpu_devices() + if not visible_devices: + raise ValueError( + f"You requested to find {num_devices} devices but there are no visible XPU devices on this machine." + ) + if num_devices > len(visible_devices): + raise ValueError( + f"You requested to find {num_devices} devices but this machine only has {len(visible_devices)} GPUs." + ) + + available_devices = [] + unavailable_devices = [] + + for gpu_idx in visible_devices: + try: + torch.tensor(0, device=torch.device("xpu", gpu_idx)) + except RuntimeError: + unavailable_devices.append(gpu_idx) + continue + + available_devices.append(gpu_idx) + if len(available_devices) == num_devices: + # exit early if we found the right number of GPUs + break + + if len(available_devices) != num_devices: + raise RuntimeError( + f"You requested to find {num_devices} devices but only {len(available_devices)} are currently available." + f" The devices {unavailable_devices} are occupied by other processes and can't be used at the moment." + ) + return available_devices + + +def _get_all_visible_xpu_devices() -> List[int]: + """Returns a list of all visible XPU GPU devices. + + Devices masked by the environment variabale ``CUDA_VISIBLE_DEVICES`` won't be returned here. For example, assume you + have 8 physical GPUs. If ``CUDA_VISIBLE_DEVICES="1,3,6"``, then this function will return the list ``[0, 1, 2]`` + because these are the three visible GPUs after applying the mask ``CUDA_VISIBLE_DEVICES``. + """ + return list(range(num_xpu_devices())) + + +## TODO: Remove once minimum supported PyTorch version is 2.0 +# @contextmanager +# def _patch_cuda_is_available() -> Generator: +# """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if +# possible.""" +# if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_2_0: +# # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding +# # otherwise, patching is_available could lead to attribute errors or infinite recursion +# orig_check = torch.cuda.is_available +# torch.cuda.is_available = is_cuda_available +# try: +# yield +# finally: +# torch.cuda.is_available = orig_check +# else: +# yield + + +@lru_cache(1) +def num_xpu_devices() -> int: + """Returns the number of available XPU devices. + + Unlike :func:`torch.xpu.device_count`, this function does its best not to create a XPU context for fork support, if + the platform allows it. + """ + if _TORCH_GREATER_EQUAL_1_13: + try: + return torch.xpu.device_count() + except AttributeError: + return 0 + + ## Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879 + ## TODO: Remove once minimum supported PyTorch version is 1.13 + # nvml_count = _device_count_nvml() + # return torch.cuda.device_count() if nvml_count < 0 else nvml_count + return 0 + + +def is_xpu_available() -> bool: + """Returns a bool indicating if XPU is currently available. + + Unlike :func:`torch.xpu.is_available`, this function does its best not to create a XPU context for fork support, if + the platform allows it. + """ + ## We set `PYTORCH_NVML_BASED_CUDA_CHECK=1` in lightning.fabric.__init__.py + # return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_2_0 else num_cuda_devices() > 0 + return torch.xpu.is_available() + + +## TODO: Remove once minimum supported PyTorch version is 1.13 +# def _parse_visible_devices() -> Set[int]: +# """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" +# var = os.getenv("CUDA_VISIBLE_DEVICES") +# if var is None: +# return {x for x in range(64)} +# +# def _strtoul(s: str) -> int: +# """Return -1 or integer sequence string starts with.""" +# if len(s) == 0: +# return -1 +# for idx, c in enumerate(s): +# if not c.isdigit(): +# break +# if idx + 1 == len(s): +# idx += 1 +# return int(s[:idx]) if idx > 0 else -1 +# +# # CUDA_VISIBLE_DEVICES uses something like strtoul +# # which makes `1gpu2,2ampere` is equivalent to `1,2` +# rc: Set[int] = set() +# for elem in var.split(","): +# rc.add(_strtoul(elem.strip())) +# return rc + + +## TODO: Remove once minimum supported PyTorch version is 1.13 +# def _raw_device_count_nvml() -> int: +# """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" +# from ctypes import c_int, CDLL +# +# nvml_h = CDLL("libnvidia-ml.so.1") +# rc = nvml_h.nvmlInit() +# if rc != 0: +# warnings.warn("Can't initialize NVML") +# return -1 +# dev_arr = (c_int * 1)(-1) +# rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr) +# if rc != 0: +# warnings.warn("Can't get nvml device count") +# return -1 +# del nvml_h +# return dev_arr[0] + + +## TODO: Remove once minimum supported PyTorch version is 1.13 +# def _device_count_nvml() -> int: +# """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879.""" +# try: +# raw_cnt = _raw_device_count_nvml() +# if raw_cnt <= 0: +# return raw_cnt +# return len(set(range(raw_cnt)).intersection(_parse_visible_devices())) +# except OSError: +# return -1 +# except AttributeError: +# return -1 + + +def _check_xpu_math_precision(device: torch.device) -> None: + if not _TORCH_GREATER_EQUAL_1_12: + # before 1.12, tf32 was used by default + return + # check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and + # `set_float32_matmul_precision` + if torch.xpu.get_fp32_math_mode() == torch.xpu.FP32MathMode.FP32: # default + rank_zero_info( + f"You are using an XPU device ({torch.xpu.get_device_name(device)!r}). To properly utilize computation " + "power, you can set `torch.xpu.set_fp32_math_mode(mode=torch.xpu.FP32MathMode.FP32, device='cpu')` " + "which will trade-off precision for performance. For more details, read https://intel.github.io/" + "intel-extension-for-pytorch/xpu/latest/tutorials/api_doc.html#torch.xpu.set_fp32_math_mode" + ) diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index cd715b79cf673f..71c16bbf0a962b 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -20,7 +20,7 @@ from lightning_utilities.core.imports import RequirementCache from typing_extensions import get_args -from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator +from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator, XPUAccelerator from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS from lightning.fabric.strategies import STRATEGY_REGISTRY from lightning.fabric.utilities.device_parser import _parse_gpu_ids @@ -29,7 +29,7 @@ _CLICK_AVAILABLE = RequirementCache("click") -_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") +_SUPPORTED_ACCELERATORS = ("cpu", "xpu", "gpu", "cuda", "mps", "tpu") def _get_supported_strategies() -> List[str]: @@ -147,6 +147,8 @@ def _get_num_processes(accelerator: str, devices: str) -> int: """Parse the `devices` argument to determine how many processes need to be launched on the current machine.""" if accelerator == "gpu": parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True) + elif accelerator == "xpu": + parsed_devices = XPUAccelerator.parse_devices(devices) elif accelerator == "cuda": parsed_devices = CUDAAccelerator.parse_devices(devices) elif accelerator == "mps": diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index f9cdb56617554a..eef785abd63641 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -23,6 +23,7 @@ from lightning.fabric.accelerators.cuda import CUDAAccelerator from lightning.fabric.accelerators.mps import MPSAccelerator from lightning.fabric.accelerators.tpu import TPUAccelerator +from lightning.fabric.accelerators.xpu import XPUAccelerator from lightning.fabric.plugins import ( CheckpointIO, DeepSpeedPrecision, @@ -318,6 +319,8 @@ def _choose_auto_accelerator(self) -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if XPUAccelerator.is_available(): + return "xpu" return "cpu" @staticmethod @@ -326,6 +329,9 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if XPUAccelerator.is_available(): + return "xpu" + raise RuntimeError("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: diff --git a/src/lightning/fabric/utilities/device_dtype_mixin.py b/src/lightning/fabric/utilities/device_dtype_mixin.py index 30a7f0f8258ff0..d6ebdf9ee76f62 100644 --- a/src/lightning/fabric/utilities/device_dtype_mixin.py +++ b/src/lightning/fabric/utilities/device_dtype_mixin.py @@ -53,6 +53,25 @@ def to(self, *args: Any, **kwargs: Any) -> Self: # type: ignore[valid-type] self.__update_properties(device=device, dtype=dtype) return super().to(*args, **kwargs) + def xpu(self, device: Optional[Union[torch.device, int]] = None) -> Self: # type: ignore[valid-type] + """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers + different objects. So it should be called before constructing optimizer if the module will live on GPU + while being optimized. + + Arguments: + device: If specified, all parameters will be copied to that device. If `None`, the current XPU device + index will be used. + + Returns: + Module: self + """ + if device is None: + device = torch.device("xpu", torch.xpu.current_device()) + elif isinstance(device, int): + device = torch.device("xpu", index=device) + self.__update_properties(device=device) + return super().xpu(device=device) + def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # type: ignore[valid-type] """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py index 3ec00e9e5a9a9f..1a05d21be10eb8 100644 --- a/src/lightning/fabric/utilities/device_parser.py +++ b/src/lightning/fabric/utilities/device_parser.py @@ -49,6 +49,7 @@ def _determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]: def _parse_gpu_ids( gpus: Optional[Union[int, str, List[int]]], + include_xpu: bool = False, include_cuda: bool = False, include_mps: bool = False, ) -> Optional[List[int]]: @@ -62,6 +63,7 @@ def _parse_gpu_ids( indicates specific GPUs to use. An int of 0 means that no GPUs should be used. Any int N > 0 indicates that GPUs [0..N) should be used. + include_xpu: A boolean value indicating whether to include XPU devices for GPU parsing. include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing. include_mps: A boolean value indicating whether to include MPS devices for GPU parsing. @@ -73,7 +75,7 @@ def _parse_gpu_ids( If no GPUs are available but the value of gpus variable indicates request for GPUs .. note:: - ``include_cuda`` and ``include_mps`` default to ``False`` so that you only + ``include_xpu``, ``include_cuda`` and ``include_mps`` default to ``False`` so that you only have to specify which device type to use and all other devices are not disabled. """ # Check that gpus param is None, Int, String or Sequence of Ints @@ -86,14 +88,17 @@ def _parse_gpu_ids( # We know the user requested GPUs therefore if some of the # requested GPUs are not available an exception is thrown. gpus = _normalize_parse_gpu_string_input(gpus) - gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps) + gpus = _normalize_parse_gpu_input_to_list( + gpus, include_xpu=include_xpu, include_cuda=include_cuda, include_mps=include_mps + ) if not gpus: raise MisconfigurationException("GPUs requested but none are available.") if ( TorchElasticEnvironment.detect() and len(gpus) != 1 - and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1 + and len(_get_all_available_gpus(include_xpu=include_xpu, include_cuda=include_cuda, include_mps=include_mps)) + == 1 ): # Omit sanity check on torchelastic because by default it shows one visible GPU per process return gpus @@ -101,7 +106,7 @@ def _parse_gpu_ids( # Check that GPUs are unique. Duplicate GPUs are not supported by the backend. _check_unique(gpus) - return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) + return _sanitize_gpu_ids(gpus, include_xpu=include_xpu, include_cuda=include_cuda, include_mps=include_mps) def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]: @@ -114,7 +119,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in return int(s.strip()) -def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _sanitize_gpu_ids( + gpus: List[int], include_xpu: bool = False, include_cuda: bool = False, include_mps: bool = False +) -> List[int]: """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the GPUs is not available. @@ -128,9 +135,11 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: MisconfigurationException: If machine has fewer available GPUs than requested. """ - if sum((include_cuda, include_mps)) == 0: + if sum((include_xpu, include_cuda, include_mps)) == 0: raise ValueError("At least one gpu type should be specified!") - all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + all_available_gpus = _get_all_available_gpus( + include_xpu=include_xpu, include_cuda=include_cuda, include_mps=include_mps + ) for gpu in gpus: if gpu not in all_available_gpus: raise MisconfigurationException( @@ -140,7 +149,7 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: def _normalize_parse_gpu_input_to_list( - gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool + gpus: Union[int, List[int], Tuple[int, ...]], include_xpu: bool, include_cuda: bool, include_mps: bool ) -> Optional[List[int]]: assert gpus is not None if isinstance(gpus, (MutableSequence, tuple)): @@ -150,19 +159,22 @@ def _normalize_parse_gpu_input_to_list( if not gpus: # gpus==0 return None if gpus == -1: - return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + return _get_all_available_gpus(include_xpu=include_xpu, include_cuda=include_cuda, include_mps=include_mps) return list(range(gpus)) -def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _get_all_available_gpus( + include_xpu: bool = False, include_cuda: bool = False, include_mps: bool = False +) -> List[int]: """ Returns: A list of all available GPUs """ + xpu_gpus = accelerators.xpu._get_all_visible_xpu_devices() if include_xpu else [] cuda_gpus = accelerators.cuda._get_all_visible_cuda_devices() if include_cuda else [] mps_gpus = accelerators.mps._get_all_available_mps_gpus() if include_mps else [] - return cuda_gpus + mps_gpus + return xpu_gpus + cuda_gpus + mps_gpus def _check_unique(device_ids: List[int]) -> None: diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index 2b574e020b4d91..e9cd8461540f8a 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -15,6 +15,13 @@ from lightning.fabric.utilities.rank_zero import rank_zero_info from lightning.fabric.utilities.types import ReduceOp +try: + import oneccl_bindings_for_pytorch as torch_ccl + + rank_zero_info(f"Using Intel® oneCCL Bindings for PyTorch* {torch_ccl.__version__}") +except ImportError: + pass + if torch.distributed.is_available(): from torch.distributed import group else: @@ -223,7 +230,7 @@ def _init_dist_connection( Args: cluster_environment: ``ClusterEnvironment`` instance - torch_distributed_backend: Backend to use (includes `nccl` and `gloo`) + torch_distributed_backend: Backend to use (includes `nccl`, `ccl` and `gloo`) global_rank: Rank of the current process world_size: Number of processes in the group kwargs: Kwargs for ``init_process_group`` @@ -254,7 +261,12 @@ def _init_dist_connection( def _get_default_process_group_backend_for_device(device: torch.device) -> str: - return "nccl" if device.type == "cuda" else "gloo" + if device.type == "cuda": + return "nccl" + elif device.type == "xpu": + return "ccl" + else: + return "gloo" class _DatasetSamplerWrapper(Dataset): diff --git a/src/lightning/fabric/utilities/enums.py b/src/lightning/fabric/utilities/enums.py index 91915c5127c253..daa993abb7a2ed 100644 --- a/src/lightning/fabric/utilities/enums.py +++ b/src/lightning/fabric/utilities/enums.py @@ -33,6 +33,7 @@ class _AcceleratorType(LightningEnum): """Define Accelerator type by its nature.""" CPU = "CPU" + XPU = "XPU" CUDA = "CUDA" TPU = "TPU" MPS = "MPS" diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index a58a940152bbfc..dd952b920a7b4b 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -25,6 +25,6 @@ # 2. The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383 _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) -_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0") -_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0") +_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0", use_base_version=True) +_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True) _TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True) diff --git a/src/lightning/pytorch/accelerators/__init__.py b/src/lightning/pytorch/accelerators/__init__.py index 57d45bcb6448a2..9aeaee3e4f2d50 100644 --- a/src/lightning/pytorch/accelerators/__init__.py +++ b/src/lightning/pytorch/accelerators/__init__.py @@ -19,6 +19,7 @@ from lightning.pytorch.accelerators.ipu import IPUAccelerator # noqa: F401 from lightning.pytorch.accelerators.mps import MPSAccelerator # noqa: F401 from lightning.pytorch.accelerators.tpu import TPUAccelerator # noqa: F401 +from lightning.pytorch.accelerators.xpu import XPUAccelerator # noqa: F401 ACCELERATORS_BASE_MODULE = "lightning.pytorch.accelerators" AcceleratorRegistry = _AcceleratorRegistry() diff --git a/src/lightning/pytorch/accelerators/xpu.py b/src/lightning/pytorch/accelerators/xpu.py new file mode 100644 index 00000000000000..594de83e8ccc32 --- /dev/null +++ b/src/lightning/pytorch/accelerators/xpu.py @@ -0,0 +1,162 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import shutil +import subprocess +from typing import Any, Dict, List, Optional, Union + +import torch + +import lightning.pytorch as pl +from lightning.fabric.accelerators.xpu import _check_xpu_math_precision, num_xpu_devices +from lightning.fabric.utilities.device_parser import _parse_gpu_ids +from lightning.fabric.utilities.types import _DEVICE +from lightning.pytorch.accelerators.accelerator import Accelerator +from lightning.pytorch.utilities.exceptions import MisconfigurationException + +_log = logging.getLogger(__name__) + + +class XPUAccelerator(Accelerator): + """Accelerator for Intel XPU devices.""" + + def setup_device(self, device: torch.device) -> None: + """ + Raises: + MisconfigurationException: + If the selected device is not GPU. + """ + if device.type != "xpu": + raise MisconfigurationException(f"Device should be GPU, got {device} instead") + _check_xpu_math_precision(device) + torch.xpu.set_device(device) + + def setup(self, trainer: "pl.Trainer") -> None: + # TODO refactor input from trainer to local_rank @four4fish + # self.set_intel_flags(trainer.local_rank) + # clear cache before training + torch.xpu.empty_cache() + + # @staticmethod + # def set_intel_flags(local_rank: int) -> None: + # # set the correct xpu visible devices (using pci order) + # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + # all_gpu_ids = ",".join(str(x) for x in range(num_xpu_devices())) + # devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) + # _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") + + def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + """Gets stats for the given GPU device. + + Args: + device: GPU device for which to get stats + + Returns: + A dictionary mapping the metrics to their values. + + Raises: + FileNotFoundError: + If xpum-smi installation not found + """ + return torch.xpu.memory_stats(device) + + def teardown(self) -> None: + # clean up memory + torch.xpu.empty_cache() + + @staticmethod + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + """Accelerator device parsing logic.""" + return _parse_gpu_ids(devices, include_xpu=True) + + @staticmethod + def get_parallel_devices(devices: List[int]) -> List[torch.device]: + """Gets parallel devices for the Accelerator.""" + return [torch.device("xpu", i) for i in devices] + + @staticmethod + def auto_device_count() -> int: + """Get the devices when set to auto.""" + return num_xpu_devices() + + @staticmethod + def is_available() -> bool: + return num_xpu_devices() > 0 + + @classmethod + def register_accelerators(cls, accelerator_registry: Dict) -> None: + accelerator_registry.register( + "xpu", + cls, + description=f"{cls.__class__.__name__}", + ) + + +def get_intel_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover + """Get GPU stats including memory, fan speed, and temperature from xpum-smi. + + Args: + device: GPU device for which to get stats + + Returns: + A dictionary mapping the metrics to their values. + + Raises: + FileNotFoundError: + If xpum-smi installation not found + """ + xpum_smi_path = shutil.which("xpum-smi") + if xpum_smi_path is None: + raise FileNotFoundError("xpum-smi: command not found") + + gpu_stat_metrics = [ + ("utilization.gpu", "%"), + ("memory.used", "MB"), + ("memory.free", "MB"), + ("utilization.memory", "%"), + ("fan.speed", "%"), + ("temperature.gpu", "°C"), + ("temperature.memory", "°C"), + ] + gpu_stat_keys = [k for k, _ in gpu_stat_metrics] + gpu_query = ",".join(gpu_stat_keys) + + index = torch._utils._get_device_index(device) + gpu_id = _get_gpu_id(index) + result = subprocess.run( + [xpum_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"], + encoding="utf-8", + capture_output=True, + check=True, + ) + + def _to_float(x: str) -> float: + try: + return float(x) + except ValueError: + return 0.0 + + s = result.stdout.strip() + stats = [_to_float(x) for x in s.split(", ")] + gpu_stats = {f"{x} ({unit})": stat for (x, unit), stat in zip(gpu_stat_metrics, stats)} + return gpu_stats + + +def _get_gpu_id(device_id: int) -> str: + """Get the unmasked real GPU IDs.""" + # All devices + default = ",".join(str(i) for i in range(num_xpu_devices())) + # cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",") + xpu_visible_devices = default.split(",") + return xpu_visible_devices[device_id].strip() diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index 77fa9c4171a2b2..6f5985abe01cb9 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn from lightning_utilities.core.apply_func import apply_to_collection -from torch import FloatTensor, Tensor +from torch import Tensor from torch.optim import Optimizer import lightning.pytorch as pl @@ -91,8 +91,8 @@ def connect( def forward_context(self) -> Generator[None, None, None]: """A context manager to change the default tensor type. - See: :meth:`torch.set_default_tensor_type` + See: :meth:`torch.set_default_dtype` """ - torch.set_default_tensor_type(torch.DoubleTensor) + torch.set_default_dtype(torch.float64) yield - torch.set_default_tensor_type(FloatTensor) + torch.set_default_dtype(torch.float32) diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 96654aac0f5017..989cfa66c61675 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -38,6 +38,7 @@ from lightning.pytorch.accelerators.ipu import IPUAccelerator from lightning.pytorch.accelerators.mps import MPSAccelerator from lightning.pytorch.accelerators.tpu import TPUAccelerator +from lightning.pytorch.accelerators.xpu import XPUAccelerator from lightning.pytorch.plugins import ( CheckpointIO, DeepSpeedPrecisionPlugin, @@ -358,6 +359,8 @@ def _choose_auto_accelerator(self) -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if XPUAccelerator.is_available(): + return "xpu" return "cpu" @staticmethod @@ -366,6 +369,9 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if XPUAccelerator.is_available(): + return "xpu" + raise MisconfigurationException("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: @@ -429,8 +435,8 @@ def _choose_strategy(self) -> Union[Strategy, str]: return DDPStrategy.strategy_name if len(self._parallel_devices) <= 1: # TODO: Change this once gpu accelerator was renamed to cuda accelerator - if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( - isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") + if isinstance(self._accelerator_flag, (XPUAccelerator, CUDAAccelerator, MPSAccelerator)) or ( + isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("xpu", "cuda", "gpu", "mps") ): device = _determine_root_gpu_device(self._parallel_devices) else: @@ -506,10 +512,11 @@ def _check_and_init_precision(self) -> PrecisionPlugin: if self._precision_flag == "64-true": return DoublePrecisionPlugin() - if self._precision_flag == "16-mixed" and self._accelerator_flag == "cpu": + if self._precision_flag == "16-mixed" and self._accelerator_flag == "cpu" or self._accelerator_flag == "xpu": + msg_keyword = "supported" if self._accelerator_flag == "cpu" else "recommended" rank_zero_warn( - "You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on " - "CPU. Using `precision='bf16-mixed'` instead." + f"You passed `Trainer(accelerator='{self._accelerator_flag}', precision='16-mixed')` but AMP with fp16 " + f"is not {msg_keyword} on {self._accelerator_flag.upper()}. Using `precision='bf16-mixed'` instead." ) self._precision_flag = "bf16-mixed" @@ -517,7 +524,12 @@ def _check_and_init_precision(self) -> PrecisionPlugin: rank_zero_info( f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)" ) - device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + if self._accelerator_flag == "cpu": + device = "cpu" + elif self._accelerator_flag == "xpu": + device = "xpu" + else: + device = "cuda" if isinstance(self.strategy, FSDPStrategy): return FSDPMixedPrecisionPlugin(self._precision_flag, device) diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index 9a30b006c18689..1c821726792d70 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -23,6 +23,7 @@ IPUAccelerator, MPSAccelerator, TPUAccelerator, + XPUAccelerator, ) from lightning.pytorch.accelerators.hpu import _HPU_AVAILABLE from lightning.pytorch.accelerators.ipu import _IPU_AVAILABLE @@ -155,11 +156,14 @@ def _log_device_info(trainer: "pl.Trainer") -> None: elif MPSAccelerator.is_available(): gpu_available = True gpu_type = " (mps)" + elif XPUAccelerator.is_available(): + gpu_available = True + gpu_type = " (xpu)" else: gpu_available = False gpu_type = "" - gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator)) + gpu_used = isinstance(trainer.accelerator, (XPUAccelerator, CUDAAccelerator, MPSAccelerator)) rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}") num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, TPUAccelerator) else 0 @@ -179,6 +183,13 @@ def _log_device_info(trainer: "pl.Trainer") -> None: category=PossibleUserWarning, ) + if XPUAccelerator.is_available() and not isinstance(trainer.accelerator, XPUAccelerator): + rank_zero_warn( + "GPU available but not used. Set `accelerator` and `devices` using" + f" `Trainer(accelerator='gpu', devices={XPUAccelerator.auto_device_count()})`.", + category=PossibleUserWarning, + ) + if TPUAccelerator.is_available() and not isinstance(trainer.accelerator, TPUAccelerator): rank_zero_warn( "TPU available but not used. Set `accelerator` and `devices` using" diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py index a733337894bec9..e9b2b9b55285a9 100644 --- a/src/lightning/pytorch/utilities/imports.py +++ b/src/lightning/pytorch/utilities/imports.py @@ -22,7 +22,7 @@ _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) _PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11) # duplicated from fabric because HPU is patching it below -_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0") +_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True) _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1") _KINETO_AVAILABLE = torch.profiler.kineto_available()