diff --git a/tests/test_yamls/test_aws_config.yaml b/tests/test_yamls/test_aws_config.yaml new file mode 100644 index 00000000000..27ac4d4f555 --- /dev/null +++ b/tests/test_yamls/test_aws_config.yaml @@ -0,0 +1,11 @@ +aws: + vpc_name: fake-vpc + remote_identity: + - sky-serve-fake1-*: fake1-skypilot-role + - sky-serve-fake2-*: fake1-skypilot-role + - "*": fake-skypilot-default-role + + security_group_name: + - sky-serve-fake1-*: fake-1-sg + - sky-serve-fake2-*: fake-2-sg + - "*": fake-skypilot-default diff --git a/tests/unit_tests/test_resources.py b/tests/unit_tests/test_resources.py index f9e3ad51630..097c69d6400 100644 --- a/tests/unit_tests/test_resources.py +++ b/tests/unit_tests/test_resources.py @@ -1,10 +1,13 @@ from typing import Dict from unittest.mock import Mock +from unittest.mock import patch import pytest from sky import clouds +from sky import skypilot_config from sky.resources import Resources +from sky.utils import resources_utils GLOBAL_VALID_LABELS = { 'plaintext': 'plainvalue', @@ -86,3 +89,82 @@ def test_kubernetes_labels_resources(): } cloud = clouds.Kubernetes() _run_label_test(allowed_labels, invalid_labels, cloud) + + +@patch.object(skypilot_config, 'CONFIG_PATH', + './tests/test_yamls/test_aws_config.yaml') +@patch.object(skypilot_config, '_dict', None) +@patch.object(skypilot_config, '_loaded_config_path', None) +@patch("sky.clouds.service_catalog.instance_type_exists", return_value=True) +@patch("sky.clouds.service_catalog.get_accelerators_from_instance_type", + return_value={"fake-acc": 2}) +@patch("sky.clouds.service_catalog.get_image_id_from_tag", + return_value="fake-image") +def test_aws_make_deploy_variables(*mocks) -> None: + skypilot_config._try_load_config() + + cloud = clouds.AWS() + cluster_name = resources_utils.ClusterName(display_name='display', + name_on_cloud='cloud') + region = clouds.Region(name='fake-region') + zones = [clouds.Zone(name='fake-zone')] + resource = Resources(cloud=cloud, instance_type="fake-type: 3") + config = resource.make_deploy_variables(cluster_name, + region, + zones, + dryrun=True) + + expected_config_base = { + 'instance_type': resource.instance_type, + 'custom_resources': '{"fake-acc":2}', + 'use_spot': False, + 'region': 'fake-region', + 'image_id': 'fake-image', + 'disk_tier': 'gp3', + 'disk_throughput': 218, + 'disk_iops': 3500, + 'custom_disk_perf': True, + 'docker_image': None, + 'docker_container_name': 'sky_container', + 'docker_login_config': None, + 'zones': 'fake-zone' + } + + # test using defaults + expected_config = expected_config_base.copy() + expected_config.update({ + 'security_group': "fake-skypilot-default", + 'security_group_managed_by_skypilot': 'false' + }) + assert config == expected_config, ('unexpected resource ' + 'variables generated') + + # test using culuster matches regex, top + cluster_name = resources_utils.ClusterName( + display_name='display', name_on_cloud='sky-serve-fake1-1234') + expected_config = expected_config_base.copy() + expected_config.update({ + 'security_group': "fake-1-sg", + 'security_group_managed_by_skypilot': 'false' + }) + config = resource.make_deploy_variables(cluster_name, + region, + zones, + dryrun=True) + assert config == expected_config, ('unexpected resource ' + 'variables generated') + + # test using culuster matches regex, middle + cluster_name = resources_utils.ClusterName( + display_name='display', name_on_cloud='sky-serve-fake2-1234') + expected_config = expected_config_base.copy() + expected_config.update({ + 'security_group': "fake-2-sg", + 'security_group_managed_by_skypilot': 'false' + }) + config = resource.make_deploy_variables(cluster_name, + region, + zones, + dryrun=True) + assert config == expected_config, ('unexpected resource ' + 'variables generated')