Skip to content

Commit

Permalink
fixed non-CUDA torch on post-update reinstall octimot#170
Browse files Browse the repository at this point in the history
  • Loading branch information
octimot committed Jun 3, 2024
1 parent f7fddcb commit 25eb44b
Showing 1 changed file with 55 additions and 6 deletions.
61 changes: 55 additions & 6 deletions storytoolkitai/core/post_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,69 @@
import sys


def reinstall_requirements():
def cuda_is_available():
try:
# check if nvcc (NVIDIA's CUDA compiler) is installed
subprocess.check_output(['nvcc', '--version'])
return True
except subprocess.CalledProcessError:
return False
except FileNotFoundError:
return False


def reinstall_requirements():
# get the absolute path to requirements.txt,
# considering it should be relative to the current file
requirements_file_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), '..', '..', 'requirements.txt'
)

logger.info('Re-installing requirements.txt...')
logger.info('This may take a few minutes.')
try:
# get the absolute path to requirements.txt,
# considering it should be relative to the current file
requirements_file_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), '..', '..', 'requirements.txt'
)

# but before we reinstall, detect if we have CUDA available:
if cuda_is_available():

logger.info('CUDA is available on this system. '
'Trying to install CUDA compatible versions of torch, torchaudio, and torchvision.')

# if we do, we will use torch 2.0.0+cu117, torchaudio 2.0.1+cu117, and torchvision 0.15.1+cu117
# so we need to replace the torch, torchaudio, and torchvision lines in the requirements.txt file
# copy the requirements.txt file to a temporary file
with open(requirements_file_path, 'r') as f:
requirements = f.readlines()
original_requirements = requirements.copy()

# replace the torch, torchaudio, and torchvision lines
for i, line in enumerate(requirements):
if 'torchaudio' in line:
requirements[i] = \
'torchaudio==2.0.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117\n'
elif 'torch' in line:
requirements[i] = \
'torch==2.0.0+cu117 --extra-index-url https://download.pytorch.org/whl/cu117\n'

# write the requirements back to the file
with open(requirements_file_path, 'w') as f:
f.writelines(requirements)

# don't use cache dir
subprocess.check_call(
[sys.executable, '-m', 'pip', 'install', '-r', requirements_file_path, '--no-cache-dir'])

# restore the original requirements
with open(requirements_file_path, 'w') as f:
f.writelines(original_requirements)
except Exception as e:
logger.error('Failed to install requirements.txt: {}'.format(e))
logger.warning('Please install the requirements.txt manually.')

# restore the original requirements
with open(requirements_file_path, 'w') as f:
f.writelines(original_requirements)

return False

return True
Expand Down Expand Up @@ -286,3 +333,5 @@ def post_update_0_24_0(is_standalone=False):
'0.23.0': post_update_0_23_0,
'0.24.0': post_update_0_24_0,
}

reinstall_requirements()

0 comments on commit 25eb44b

Please sign in to comment.