""" Helper functions to get information about the environment.
"""
import importlib.util
import os as python_os
import subprocess
import sys as python_sys
from typing import Any
import torch
from homura.liblog import get_logger
# Module-level logger shared by every helper below; obtained from
# homura.liblog.get_logger under the name "homura.environment".
logger = get_logger("homura.environment")
# Utility functions that check whether optional libraries are available
def is_accimage_available() -> bool:
    """Report whether the optional ``accimage`` package can be located.

    Only the import system is queried (``importlib.util.find_spec``); the
    package itself is not imported.
    """
    accimage_spec = importlib.util.find_spec("accimage")
    return accimage_spec is not None
def enable_accimage() -> None:
    """Switch torchvision's image backend to ``accimage`` when it is installed.

    If ``accimage`` cannot be located, a warning is logged and the backend is
    left untouched.
    """
    if not is_accimage_available():
        logger.warning("accimage is not available")
        return

    # Imported lazily so torchvision is only required when accimage is used.
    import torchvision
    torchvision.set_image_backend("accimage")
    logger.info("accimage is activated")
def is_faiss_available() -> bool:
    """Return True when ``faiss`` can be located *and* imported.

    ``importlib.util.find_spec`` only checks that the package can be found.
    The actual import can still fail (e.g. a broken installation), so the
    package is imported here. An ``ImportError`` is logged and reported as
    "not available" instead of propagating out of this probe.

    If the imported module has no ``StandardGpuResources`` attribute, an info
    message notes that this faiss build is not for GPUs. The function still
    returns True in that case.
    """
    if importlib.util.find_spec("faiss") is None:
        return False

    try:
        import faiss
    except ImportError as e:
        # Found on the path but unusable: treat as unavailable rather than crash.
        logger.warning(f"faiss is found but cannot be imported: {e}")
        return False

    if not hasattr(faiss, 'StandardGpuResources'):
        logger.info("faiss is available but is not for GPUs")
    return True
def is_cupy_available() -> bool:
    """Report whether the optional ``cupy`` package can be located.

    The lookup goes through ``importlib.util.find_spec``; cupy is not imported.
    """
    cupy_found = importlib.util.find_spec("cupy") is not None
    return cupy_found
def is_opteinsum_available() -> bool:
    """Report whether the optional ``opt_einsum`` package can be located.

    The lookup goes through ``importlib.util.find_spec``; opt_einsum is not
    imported.
    """
    spec = importlib.util.find_spec("opt_einsum")
    if spec is None:
        return False
    return True
# TF32
def _enable_tf32(mode: bool) -> None:
    """Set the TF32 flags of cuBLAS matmuls and cuDNN to ``mode``.

    Assigns ``torch.backends.cuda.matmul.allow_tf32`` and then
    ``torch.backends.cudnn.allow_tf32``, and logs the new state. Any
    exception raised while doing so is logged with ``logger.exception`` and
    not re-raised.
    """
    message = "TF32 is enabled" if mode else "TF32 is disabled"
    try:
        torch.backends.cuda.matmul.allow_tf32 = mode
        torch.backends.cudnn.allow_tf32 = mode
        logger.info(message)
    except Exception as err:
        logger.exception(err)
def disable_tf32() -> None:
    """Turn TF32 off globally, for both cuBLAS matmuls and cuDNN.

    Delegates to ``_enable_tf32`` with ``mode=False``.
    """
    _enable_tf32(mode=False)
class disable_tf32_locally(object):
    """ Locally disable TF32 and restore the previous setting afterwards.

    As a context manager::

        with disable_tf32_locally():
            ...

    or as a decorator::

        @disable_tf32_locally()
        def function():
            ...

    On exit, the TF32 flags of ``torch.backends.cuda.matmul`` and
    ``torch.backends.cudnn`` are restored to the values they had on entry,
    so a preceding global ``disable_tf32()`` is not silently undone. Saved
    states are kept on a stack, so nested ``with`` blocks or recursive calls
    of a decorated function (which reuse one instance) restore correctly.

    For backward compatibility, calling an instance with no argument
    (``disable_tf32_locally()()``) still just disables TF32 without restoring it.
    """

    def __init__(self):
        # One (matmul_flag, cudnn_flag) pair per active ``with`` entry.
        self._saved_states = []

    @staticmethod
    def _current_state():
        """Return the current ``(matmul, cudnn)`` TF32 flags.

        If the flags cannot be read, ``(True, True)`` is returned, so exiting
        re-enables TF32 exactly as the previous implementation did.
        """
        try:
            return (bool(torch.backends.cuda.matmul.allow_tf32),
                    bool(torch.backends.cudnn.allow_tf32))
        except AttributeError:
            return True, True

    def __call__(self, func=None):
        if func is None:
            # Legacy behaviour of the former zero-argument ``__call__``.
            _enable_tf32(False)
            return None

        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrapper

    def __enter__(self):
        self._saved_states.append(self._current_state())
        _enable_tf32(False)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        matmul_prev, cudnn_prev = self._saved_states.pop()
        if matmul_prev == cudnn_prev:
            # Both flags agree: reuse the shared helper (and its logging).
            _enable_tf32(matmul_prev)
        else:
            try:
                torch.backends.cuda.matmul.allow_tf32 = matmul_prev
                torch.backends.cudnn.allow_tf32 = cudnn_prev
            except Exception as e:
                logger.exception(e)
        # Never suppress an exception raised inside the ``with`` body.
        return False
# get environment information
def get_git_hash() -> str:
    """ Return the short hash of the current ``HEAD`` commit, or ``""``.

    ``""`` is returned when:

    * the ``git`` executable cannot be run;
    * the current working directory is not inside a git work tree (an info
      message is logged);
    * ``git rev-parse --short HEAD`` fails (e.g. a repository with no commits).
    """

    def _run_git(*args: str):
        # Returns git's stripped stdout, or None if git exits with an error.
        # stderr is discarded so git's "fatal: ..." messages do not leak to
        # the console; the return code is checked instead.
        completed = subprocess.run(["git", *args],
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.DEVNULL)
        if completed.returncode != 0:
            return None
        # strip() rather than [:-1]: robust to "\r\n" and to empty output.
        return completed.stdout.decode("ascii").strip()

    try:
        inside_work_tree = _run_git("rev-parse", "--is-inside-work-tree")
    except OSError:
        # git is missing (FileNotFoundError) or not executable.
        return ""

    if inside_work_tree != "true":
        logger.info("No git info available in this directory")
        return ""

    git_hash = _run_git("rev-parse", "--short", "HEAD")
    return git_hash or ""
def get_args() -> list:
    """Return the interpreter's command-line arguments.

    The object returned is ``sys.argv`` itself, not a copy, so mutating it
    mutates the interpreter's argument list.
    """
    argv = python_sys.argv
    return argv
def get_environ(name: str,
                default: Any = None
                ) -> Any:
    """Look up environment variable ``name``.

    Args:
        name: Name of the environment variable.
        default: Value returned when ``name`` is not set. It is passed
            through unchanged, so the result is not necessarily a ``str``.

    Returns:
        The variable's string value if set, otherwise ``default``.
    """
    # Return type is ``Any`` (not ``str``) because ``default`` is returned as-is.
    return python_os.environ.get(name, default)