diff --git a/python/ray/_private/runtime_env/pip.py b/python/ray/_private/runtime_env/pip.py
index b70f165d5b545..199f62f2ac0cf 100644
--- a/python/ray/_private/runtime_env/pip.py
+++ b/python/ray/_private/runtime_env/pip.py
@@ -20,6 +20,9 @@
 
 _WIN32 = os.name == "nt"
 
+INTERNAL_PIP_FILENAME = "ray_runtime_env_internal_pip_requirements.txt"
+MAX_INTERNAL_PIP_FILENAME_TRIES = 100
+
 
 def _get_pip_hash(pip_dict: Dict) -> str:
     serialized_pip_spec = json.dumps(pip_dict, sort_keys=True)
@@ -69,8 +72,43 @@ def get_virtualenv_activate_command(cls, target_dir: str) -> List[str]:
         return cmd + ["1>&2", "&&"]
 
     @staticmethod
-    def get_requirements_file(target_dir: str) -> str:
-        return os.path.join(target_dir, "requirements.txt")
+    def get_requirements_file(target_dir: str, pip_list: Optional[List[str]]) -> str:
+        """Returns the path to the requirements file to use for this runtime env.
+
+        If pip_list is not None, we will check if the internal pip filename is in any of
+        the entries of pip_list. If so, we will append numbers to the end of the
+        filename until we find one that doesn't conflict. This prevents infinite
+        recursion if the user specifies the internal pip filename in their pip list.
+
+        Args:
+            target_dir: The directory to store the requirements file in.
+            pip_list: A list of pip requirements specified by the user.
+
+        Returns:
+            The path to the requirements file to use for this runtime env.
+
+        Raises:
+            RuntimeError: If all MAX_INTERNAL_PIP_FILENAME_TRIES candidate filenames
+                appear in pip_list.
+        """
+        if pip_list is None:
+            return os.path.join(target_dir, INTERNAL_PIP_FILENAME)
+
+        def filename_in_pip_list(filename: str) -> bool:
+            # Substring match, so entries such as "-r <filename>" are detected too.
+            return any(filename in pip_entry for pip_entry in pip_list)
+
+        # Try the bare name first, then "<name>.1", "<name>.2", ...
+        for i in range(MAX_INTERNAL_PIP_FILENAME_TRIES):
+            filename = f"{INTERNAL_PIP_FILENAME}.{i}" if i else INTERNAL_PIP_FILENAME
+            if not filename_in_pip_list(filename):
+                # A free name on the final try is still a valid result.
+                return os.path.join(target_dir, filename)
+        raise RuntimeError(
+            "Could not find a valid filename for the internal "
+            "pip requirements file. Please specify a different "
+            "pip list in your runtime env."
+        )
 
 
 class PipProcessor:
@@ -303,7 +341,7 @@ async def _install_pip_packages(
         virtualenv_path = _PathHelper.get_virtualenv_path(path)
         python = _PathHelper.get_virtualenv_python(path)
         # TODO(fyrestone): Support -i, --no-deps, --no-cache-dir, ...
-        pip_requirements_file = _PathHelper.get_requirements_file(path)
+        pip_requirements_file = _PathHelper.get_requirements_file(path, pip_packages)
 
         def _gen_requirements_txt():
             with open(pip_requirements_file, "w") as file:
diff --git a/python/ray/tests/test_runtime_env_conda_and_pip.py b/python/ray/tests/test_runtime_env_conda_and_pip.py
index ec721a8ac9d7a..62cbdf2c7e759 100644
--- a/python/ray/tests/test_runtime_env_conda_and_pip.py
+++ b/python/ray/tests/test_runtime_env_conda_and_pip.py
@@ -10,6 +10,11 @@
     generate_runtime_env_dict,
 )
 from ray._private.runtime_env.conda import _get_conda_dict_with_ray_inserted
+from ray._private.runtime_env.pip import (
+    INTERNAL_PIP_FILENAME,
+    MAX_INTERNAL_PIP_FILENAME_TRIES,
+    _PathHelper,
+)
 from ray.runtime_env import RuntimeEnv
 import yaml
 
@@ -207,6 +212,64 @@ def f():
     assert ray.get(f.remote()) == 0
 
 
+def test_get_requirements_file():
+    """Unit test for _PathHelper.get_requirements_file."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path_helper = _PathHelper()
+
+        # If pip_list is None, we should return the internal pip filename.
+        assert path_helper.get_requirements_file(tmpdir, pip_list=None) == os.path.join(
+            tmpdir, INTERNAL_PIP_FILENAME
+        )
+
+        # If the internal pip filename is not in pip_list, we should return the internal
+        # pip filename.
+        assert path_helper.get_requirements_file(
+            tmpdir, pip_list=["foo", "bar"]
+        ) == os.path.join(tmpdir, INTERNAL_PIP_FILENAME)
+
+        # If the internal pip filename is in pip_list, we should append numbers to the
+        # end of the filename until we find one that doesn't conflict.
+        assert path_helper.get_requirements_file(
+            tmpdir, pip_list=["foo", "bar", f"-r {INTERNAL_PIP_FILENAME}"]
+        ) == os.path.join(tmpdir, f"{INTERNAL_PIP_FILENAME}.1")
+        assert path_helper.get_requirements_file(
+            tmpdir,
+            pip_list=[
+                "foo",
+                "bar",
+                f"{INTERNAL_PIP_FILENAME}.1",
+                f"{INTERNAL_PIP_FILENAME}.2",
+            ],
+        ) == os.path.join(tmpdir, f"{INTERNAL_PIP_FILENAME}.3")
+
+        # A free filename on the very last try is still a valid result.
+        assert path_helper.get_requirements_file(
+            tmpdir,
+            pip_list=[
+                f"{INTERNAL_PIP_FILENAME}.{i}"
+                for i in range(1, MAX_INTERNAL_PIP_FILENAME_TRIES - 1)
+            ],
+        ) == os.path.join(
+            tmpdir, f"{INTERNAL_PIP_FILENAME}.{MAX_INTERNAL_PIP_FILENAME_TRIES - 1}"
+        )
+
+        # If we can't find a valid filename, we should raise an error.
+        with pytest.raises(RuntimeError) as excinfo:
+            path_helper.get_requirements_file(
+                tmpdir,
+                pip_list=[
+                    "foo",
+                    "bar",
+                    *[
+                        f"{INTERNAL_PIP_FILENAME}.{i}"
+                        for i in range(MAX_INTERNAL_PIP_FILENAME_TRIES)
+                    ],
+                ],
+            )
+        assert "Could not find a valid filename for the internal " in str(excinfo.value)
+
+
 if __name__ == "__main__":
     if os.environ.get("PARALLEL_CI"):
         sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))