Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UX: Allow inferring cloud from region or zone. #2632

Merged
merged 9 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,17 @@ def instance_type_exists(self, instance_type):
"""Returns whether the instance type exists for this cloud."""
raise NotImplementedError

def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
"""Validates the region and zone."""
def validate_region_zone(
self, region: Optional[str],
zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
"""Validates whether region and zone exist in the catalog.

Returns:
A tuple of region and zone, if validated.

Raises:
ValueError: If region or zone is invalid or not supported.
"""
return service_catalog.validate_region_zone(region,
zone,
clouds=self._REPR.lower())
Expand Down
9 changes: 9 additions & 0 deletions sky/clouds/cloud_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ class _CloudRegistry(dict):
def from_str(self, name: Optional[str]) -> Optional['cloud.Cloud']:
if name is None:
return None
if name.lower() == 'local':
# Backward compatibility. global_user_state's DB may have recorded
# Local cloud, and we've just removed it from the registry, and
# global_user_state.get_enabled_clouds() would call into this func
# and fail.
#
# TODO(skypilot): have a better way to handle clouds removed from
# registry if needed.
return None
if name.lower() not in self:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Cloud {name!r} is not a valid cloud among '
Expand Down
11 changes: 9 additions & 2 deletions sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class Kubernetes(clouds.Cloud):
_DEFAULT_MEMORY_CPU_RATIO = 1
_DEFAULT_MEMORY_CPU_RATIO_WITH_GPU = 4 # Allocate more memory for GPU tasks
_REPR = 'Kubernetes'
_regions: List[clouds.Region] = [clouds.Region('kubernetes')]
_SINGLETON_REGION = 'kubernetes'
_regions: List[clouds.Region] = [clouds.Region(_SINGLETON_REGION)]
_CLOUD_UNSUPPORTED_FEATURES = {
# TODO(romilb): Stopping might be possible to implement with
# container checkpointing introduced in Kubernetes v1.25. See:
Expand Down Expand Up @@ -329,7 +330,13 @@ def instance_type_exists(self, instance_type: str) -> bool:
instance_type)

def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
# Kubernetes doesn't have regions or zones, so we don't need to validate
if region != self._SINGLETON_REGION:
raise ValueError(
'Kubernetes support does not support setting region.'
' Cluster used is determined by the kubeconfig.')
if zone is not None:
raise ValueError('Kubernetes support does not support setting zone.'
' Cluster used is determined by the kubeconfig.')
return region, zone

def accelerator_in_region_or_zone(self,
Expand Down
7 changes: 4 additions & 3 deletions sky/clouds/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sky import resources as resources_lib


@clouds.CLOUD_REGISTRY.register
# TODO(skypilot): remove Local now that we're using Kubernetes.
class Local(clouds.Cloud):
"""Local/on-premise cloud.

Expand Down Expand Up @@ -191,10 +191,11 @@ def instance_type_exists(self, instance_type: str) -> bool:
def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
# Returns true if the region name is same as Local cloud's
# one and only region: 'Local'.
assert zone is None
if zone is not None:
raise ValueError('Local cloud does not support zones.')
if region is None or region != Local.LOCAL_REGION.name:
raise ValueError(f'Region {region!r} does not match the Local'
' cloud region {Local.LOCAL_REGION.name!r}.')
f' cloud region {Local.LOCAL_REGION.name!r}.')
return region, zone

@classmethod
Expand Down
9 changes: 8 additions & 1 deletion sky/clouds/service_catalog/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,14 @@ def instance_type_exists_impl(df: pd.DataFrame, instance_type: str) -> bool:
def validate_region_zone_impl(
cloud_name: str, df: pd.DataFrame, region: Optional[str],
zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
"""Validates whether region and zone exist in the catalog."""
"""Validates whether region and zone exist in the catalog.

Returns:
A tuple of region and zone, if validated.

Raises:
ValueError: If region or zone is invalid or not supported.
"""

def _get_candidate_str(loc: str, all_loc: List[str]) -> str:
candidate_loc = difflib.get_close_matches(loc, all_loc, n=5, cutoff=0.9)
Expand Down
60 changes: 51 additions & 9 deletions sky/resources.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Resources: compute requirements of Tasks."""
import functools
import textwrap
from typing import Dict, List, Optional, Set, Tuple, Union

import colorama
Expand All @@ -15,6 +16,7 @@
from sky.provision import docker_utils
from sky.skylet import constants
from sky.utils import accelerator_registry
from sky.utils import log_utils
from sky.utils import resources_utils
from sky.utils import schemas
from sky.utils import tpu_utils
Expand Down Expand Up @@ -134,7 +136,7 @@ def __init__(
self._cloud = cloud
self._region: Optional[str] = None
self._zone: Optional[str] = None
self._set_region_zone(region, zone)
self._validate_and_set_region_zone(region, zone)

self._instance_type = instance_type

Expand Down Expand Up @@ -537,22 +539,62 @@ def is_launchable(self) -> bool:
return self.cloud is not None and self._instance_type is not None

def need_cleanup_after_preemption(self) -> bool:
"""Returns whether a spot resource needs cleanup after preeemption."""
"""Returns whether a spot resource needs cleanup after preemption."""
assert self.is_launchable(), self
return self.cloud.need_cleanup_after_preemption(self)

def _set_region_zone(self, region: Optional[str],
zone: Optional[str]) -> None:
def _validate_and_set_region_zone(self, region: Optional[str],
zone: Optional[str]) -> None:
if region is None and zone is None:
return

if self._cloud is None:
with ux_utils.print_exception_no_traceback():
raise ValueError(
'Cloud must be specified when region/zone are specified.')
# Try to infer the cloud from region/zone, if unique. If 0 or >1
# cloud corresponds to region/zone, errors out.
valid_clouds = []
enabled_clouds = global_user_state.get_enabled_clouds()
cloud_to_errors = {}
for cloud in enabled_clouds:
try:
cloud.validate_region_zone(region, zone)
except ValueError as e:
cloud_to_errors[repr(cloud)] = e
continue
valid_clouds.append(cloud)

if len(valid_clouds) == 0:
if len(enabled_clouds) == 1:
cloud_str = f'for cloud {enabled_clouds[0]}'
else:
cloud_str = f'for any cloud among {enabled_clouds}'
with ux_utils.print_exception_no_traceback():
if len(cloud_to_errors) == 1:
# UX: if 1 cloud, don't print a table.
hint = list(cloud_to_errors.items())[0][-1]
else:
table = log_utils.create_table(['Cloud', 'Hint'])
table.add_row(['-----', '----'])
for cloud, error in cloud_to_errors.items():
reason_str = '\n'.join(textwrap.wrap(
str(error), 80))
table.add_row([str(cloud), reason_str])
hint = table.get_string()
raise ValueError(
f'Invalid (region {region!r}, zone {zone!r}) '
f'{cloud_str}. Details:\n{hint}')
elif len(valid_clouds) > 1:
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Cannot infer cloud from (region {region!r}, zone '
f'{zone!r}). Multiple enabled clouds have region/zone '
f'of the same names: {valid_clouds}. '
f'To fix: explicitly specify `cloud`.')
logger.debug(f'Cloud is not specified, using {valid_clouds[0]} '
f'inferred from region {region!r} and zone {zone!r}')
self._cloud = valid_clouds[0]

# Validate whether region and zone exist in the catalog, and set the
# region if zone is specified.
# Validate if region and zone exist in the catalog, and set the region
# if zone is specified.
self._region, self._zone = self._cloud.validate_region_zone(
region, zone)

Expand Down
62 changes: 62 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import tempfile
from typing import List, Optional

import pandas as pd
import pytest

from sky import clouds
from sky.utils import kubernetes_utils


def enable_all_clouds_in_monkeypatch(
monkeypatch: pytest.MonkeyPatch,
enabled_clouds: Optional[List[str]] = None,
) -> None:
# Monkey-patching is required because in the test environment, no cloud is
# enabled. The optimizer checks the environment to find enabled clouds, and
# only generates plans within these clouds. The tests assume that all three
# clouds are enabled, so we monkeypatch the `sky.global_user_state` module
# to return all three clouds. We also monkeypatch `sky.check.check` so that
# when the optimizer tries calling it to update enabled_clouds, it does not
# raise exceptions.
if enabled_clouds is None:
enabled_clouds = list(clouds.CLOUD_REGISTRY.values())
monkeypatch.setattr(
'sky.global_user_state.get_enabled_clouds',
lambda: enabled_clouds,
)
monkeypatch.setattr('sky.check.check', lambda *_args, **_kwargs: None)
config_file_backup = tempfile.NamedTemporaryFile(
prefix='tmp_backup_config_default', delete=False)
monkeypatch.setattr('sky.clouds.gcp.GCP_CONFIG_SKY_BACKUP_PATH',
config_file_backup.name)
monkeypatch.setattr(
'sky.clouds.gcp.DEFAULT_GCP_APPLICATION_CREDENTIAL_PATH',
config_file_backup.name)
monkeypatch.setenv('OCI_CONFIG', config_file_backup.name)

az_mappings = pd.read_csv('tests/default_aws_az_mappings.csv')

def _get_az_mappings(_):
return az_mappings

monkeypatch.setattr(
'sky.clouds.service_catalog.aws_catalog._get_az_mappings',
_get_az_mappings)

monkeypatch.setattr('sky.backends.backend_utils.check_owner_identity',
lambda _: None)

monkeypatch.setattr(
'sky.clouds.gcp.GCP._list_reservations_for_instance_type',
lambda *_args, **_kwargs: [])

# Monkey patch Kubernetes resource detection since it queries
# the cluster to detect available cluster resources.
monkeypatch.setattr(
'sky.utils.kubernetes_utils.detect_gpu_label_formatter',
lambda *_args, **_kwargs: [kubernetes_utils.SkyPilotLabelFormatter, []])
monkeypatch.setattr('sky.utils.kubernetes_utils.detect_gpu_resource',
lambda *_args, **_kwargs: [True, []])
monkeypatch.setattr('sky.utils.kubernetes_utils.check_instance_fits',
lambda *_args, **_kwargs: [True, ''])
62 changes: 6 additions & 56 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import tempfile
from typing import List
from unittest.mock import patch

import pandas as pd
import common # TODO(zongheng): for some reason isort places it here.
import pytest

# Usage: use
Expand All @@ -20,7 +18,8 @@
# To only run tests for a specific cloud (as well as generic tests), use
# --aws, --gcp, --azure, or --lambda.
#
# To only run tests for managed spot (without generic tests), use --managed-spot.
# To only run tests for managed spot (without generic tests), use
# --managed-spot.
all_clouds_in_smoke_tests = [
'aws', 'gcp', 'azure', 'lambda', 'cloudflare', 'ibm', 'scp', 'oci',
'kubernetes'
Expand Down Expand Up @@ -180,61 +179,12 @@ def generic_cloud(request) -> str:


@pytest.fixture
def enable_all_clouds(monkeypatch):
from sky import clouds
from sky.utils import kubernetes_utils

# Monkey-patching is required because in the test environment, no cloud is
# enabled. The optimizer checks the environment to find enabled clouds, and
# only generates plans within these clouds. The tests assume that all three
# clouds are enabled, so we monkeypatch the `sky.global_user_state` module
# to return all three clouds. We also monkeypatch `sky.check.check` so that
# when the optimizer tries calling it to update enabled_clouds, it does not
# raise exceptions.
enabled_clouds = list(clouds.CLOUD_REGISTRY.values())
monkeypatch.setattr(
'sky.global_user_state.get_enabled_clouds',
lambda: enabled_clouds,
)
monkeypatch.setattr('sky.check.check', lambda *_args, **_kwargs: None)
config_file_backup = tempfile.NamedTemporaryFile(
prefix='tmp_backup_config_default', delete=False)
monkeypatch.setattr('sky.clouds.gcp.GCP_CONFIG_SKY_BACKUP_PATH',
config_file_backup.name)
monkeypatch.setattr(
'sky.clouds.gcp.DEFAULT_GCP_APPLICATION_CREDENTIAL_PATH',
config_file_backup.name)
monkeypatch.setenv('OCI_CONFIG', config_file_backup.name)

az_mappings = pd.read_csv('tests/default_aws_az_mappings.csv')

def _get_az_mappings(_):
return az_mappings

monkeypatch.setattr(
'sky.clouds.service_catalog.aws_catalog._get_az_mappings',
_get_az_mappings)

monkeypatch.setattr('sky.backends.backend_utils.check_owner_identity',
lambda _: None)

monkeypatch.setattr(
'sky.clouds.gcp.GCP._list_reservations_for_instance_type',
lambda *_args, **_kwargs: [])

# Monkey patch Kubernetes resource detection since it queries
# the cluster to detect available cluster resources.
monkeypatch.setattr(
'sky.utils.kubernetes_utils.detect_gpu_label_formatter',
lambda *_args, **_kwargs: [kubernetes_utils.SkyPilotLabelFormatter, []])
monkeypatch.setattr('sky.utils.kubernetes_utils.detect_gpu_resource',
lambda *_args, **_kwargs: [True, []])
monkeypatch.setattr('sky.utils.kubernetes_utils.check_instance_fits',
lambda *_args, **_kwargs: [True, ''])
def enable_all_clouds(monkeypatch: pytest.MonkeyPatch):
common.enable_all_clouds_in_monkeypatch(monkeypatch)


@pytest.fixture
def aws_config_region(monkeypatch) -> str:
def aws_config_region(monkeypatch: pytest.MonkeyPatch) -> str:
from sky import skypilot_config
region = 'us-west-2'
if skypilot_config.loaded():
Expand Down
Loading