Source code for homura.utils.backends

""" Helper functions to convert between PyTorch tensors and CuPy/NumPy arrays. These functions are useful for writing device-agnostic extensions.
"""

import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

from .environment import is_cupy_available, is_opteinsum_available

has_cupy = is_cupy_available()
if has_cupy:
    import cupy

has_opt_einsum = is_opteinsum_available()
if has_opt_einsum:
    import opt_einsum


def torch_to_xp(input: torch.Tensor
                ) -> np.ndarray:
    """ Convert a PyTorch tensor to a Cupy/Numpy array.

    CUDA tensors are exported to CuPy through DLPack when CuPy is available.
    All other tensors are converted with ``Tensor.numpy()``. Neither path
    copies data, so the returned array shares memory with ``input``. Tensors
    that require grad are detached first, because an array cannot carry
    autograd history.

    Args:
        input: tensor to convert.

    Returns:
        A ``cupy.ndarray`` for CUDA tensors when CuPy is installed, otherwise
        a ``numpy.ndarray``.

    Raises:
        RuntimeError: if ``input`` is not a ``torch.Tensor``.
    """
    if not isinstance(input, torch.Tensor):
        raise RuntimeError(f'torch_to_xp expects torch.Tensor as input, but got {type(input)}')

    # Tensor.numpy() refuses tensors that require grad; detach() returns a
    # view on the same storage, so no data is copied.
    tensor = input.detach()
    # Check the device first: CPU tensors never need the CuPy branch.
    if tensor.is_cuda and has_cupy:
        # NOTE(review): newer CuPy releases deprecate cupy.fromDlpack in favour
        # of cupy.from_dlpack -- confirm against the supported CuPy version.
        return cupy.fromDlpack(to_dlpack(tensor))
    # A CUDA tensor without CuPy also lands here; Tensor.numpy() then raises
    # because the tensor is not on the CPU.
    return tensor.numpy()
def xp_to_torch(input: np.ndarray
                ) -> torch.Tensor:
    """ Convert a Cupy/Numpy array to a PyTorch tensor.

    NumPy arrays are wrapped with ``torch.from_numpy``. CuPy arrays are
    exported through DLPack. Neither path copies data, so the returned tensor
    shares memory with ``input``.

    Args:
        input: ``numpy.ndarray``, or ``cupy.ndarray`` when CuPy is installed.

    Returns:
        A tensor viewing the same data as ``input``.

    Raises:
        RuntimeError: if ``input`` is neither a NumPy nor a CuPy array.
    """
    if isinstance(input, np.ndarray):
        return torch.from_numpy(input)
    if has_cupy and isinstance(input, cupy.ndarray):
        # DLPack export is the ndarray method ``toDlpack()``; CuPy has no
        # module-level ``ToDlpack`` function.
        return from_dlpack(input.toDlpack())
    raise RuntimeError(f'xp_to_torch expects numpy/cupy.ndarray as input, but got {type(input)}')
def einsum(expr: str, *xs):
    """ Evaluate an Einstein-summation expression over PyTorch tensors.

    Uses ``opt_einsum.contract`` with the ``'torch'`` backend when opt_einsum
    is installed. Otherwise it falls back to ``torch.einsum``.

    Args:
        expr: subscript string, e.g. ``'ij,jk->ik'``.
        *xs: operands, in the order they appear in ``expr``.

    Returns:
        The result of the contraction.
    """
    if not has_opt_einsum:
        return torch.einsum(expr, *xs)
    return opt_einsum.contract(expr, *xs, backend='torch')