Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix spot tpu bug #1717

Merged
merged 21 commits into from
Mar 16, 2023
Merged

Fix spot tpu bug #1717

merged 21 commits into from
Mar 16, 2023

Conversation

infwinston
Copy link
Member

@infwinston infwinston commented Feb 23, 2023

Fixes

Tested (run the relevant ones):

  • Any manual or new tests for this PR (please specify below)
  • All smoke tests: pytest tests/test_smoke.py
  • Relevant individual smoke tests: pytest tests/test_smoke.py::test_fill_in_the_name
  • Backward compatibility tests: bash tests/backward_comaptibility_tests.sh

@infwinston infwinston changed the title Fix spot tpu Fix spot tpu bug Feb 23, 2023
Copy link
Member

@concretevitamin concretevitamin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @infwinston, some comments. Also:

For removing colorama.init(), should we call this func somewhere https://github.com/tartley/colorama#usage? After that, perhaps ask @romilbhardwaj to check if colors still look OK on Windows?

Comment on lines 122 to 123
logger.info('wait for 30 seconds and retry...')
time.sleep(30)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remnant?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

Comment on lines 119 to 121
import traceback # pylint: disable=import-outside-toplevel
logger.error(f' Detailed exception: {e}')
logger.info(f' Traceback: {traceback.format_exc()}')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, the existing usage of format_exc in this module didn't print out the stacktrace. Could we double check?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we do need format_exc to print out the full trace. I tried one simple example below

def func():
    raise ValueError("test")
    
try:
    func()
except Exception as e:
    print(f'{type(e)}: {str(e)}')
    print('===full trace===')
    import traceback
    print(traceback.format_exc())

and got output

<class 'ValueError'>: test
===full trace===
Traceback (most recent call last):
  File "sky/test_err.py", line 5, in <module>
    func()
  File "sky/test_err.py", line 2, in func
    raise ValueError("test")
ValueError: test

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we try a real spot launch example? IIRC, that didn't produce a usable trace.

Copy link
Member Author

@infwinston infwinston Feb 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I comment out this line and the entire trace will then show

sys.tracebacklimit = 0

import sky

try:
    t = sky.Task(run='ls')
    t.set_resources(sky.Resources(sky.GCP()))
    sky.launch(task=t, cluster_name='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
except Exception as e:
    print(f'{type(e)}: {str(e)}')
    print('===full trace===')
    import traceback
    print(traceback.format_exc())

Output:

...
<class 'sky.exceptions.InvalidClusterNameError'>: Cluster name 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx' has 60 chars; maximum length is 35 chars on GCP.
===full trace===
Traceback (most recent call last):
  File "test_err.py", line 11, in <module>
    sky.launch(task=t, cluster_name='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
  File "/home/gcpuser/sky_workdir/skypilot/sky/utils/common_utils.py", line 241, in _record
    return f(*args, **kwargs)
  File "/home/gcpuser/sky_workdir/skypilot/sky/utils/common_utils.py", line 241, in _record
    return f(*args, **kwargs)
  File "/home/gcpuser/sky_workdir/skypilot/sky/execution.py", line 411, in launch
    _execute(
  File "/home/gcpuser/sky_workdir/skypilot/sky/execution.py", line 260, in _execute
    handle = backend.provision(task,
  File "/home/gcpuser/sky_workdir/skypilot/sky/utils/common_utils.py", line 241, in _record
    return f(*args, **kwargs)
  File "/home/gcpuser/sky_workdir/skypilot/sky/utils/common_utils.py", line 220, in _record
    return f(*args, **kwargs)
  File "/home/gcpuser/sky_workdir/skypilot/sky/backends/backend.py", line 49, in provision
    return self._provision(task, to_provision, dryrun, stream_logs,
  File "/home/gcpuser/sky_workdir/skypilot/sky/backends/cloud_vm_ray_backend.py", line 2129, in _provision
    to_provision_config = self._check_existing_cluster(
  File "/home/gcpuser/sky_workdir/skypilot/sky/utils/common_utils.py", line 241, in _record
    return f(*args, **kwargs)
  File "/home/gcpuser/sky_workdir/skypilot/sky/backends/cloud_vm_ray_backend.py", line 3231, in _check_existing_cluster
    task_cloud.check_cluster_name_is_valid(cluster_name)
  File "/home/gcpuser/sky_workdir/skypilot/sky/clouds/cloud.py", line 408, in check_cluster_name_is_valid
    raise exceptions.InvalidClusterNameError(
sky.exceptions.InvalidClusterNameError: Cluster name 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx' has 60 chars; maximum length is 35 chars on GCP.

Was it because of sys.tracebacklimit = 0 so it didn't produce trace?

Comment on lines 1411 to 1412
ip = (endpoint.get('ipAddress', None) if get_internal_ips else
endpoint['accessConfig'].get('externalIp', None))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on why we do this safeguard?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added!

returncode = runner.run('ray status', stream_logs=False)
if returncode:
use_spot = handle.launched_resources.use_spot
# if cluster is not spot, we can determine its health by "ray status".
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change makes sense to me. Cc @Michaelvll to take another look.

Can we expand the comment here? E.g., mention including the following

  • For non-spot clusters, this is an optimization: we call external_ips() and/or SSH into the cluster to run 'ray status' in various cases to determine cluster health, because these may be faster than querying the true node statuses from the cloud provider.
  • For spot clusters, the above can be unsafe. Therefore we directly query the cloud provider.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I've updated the comments.

@Michaelvll Michaelvll self-requested a review March 3, 2023 02:27
# the true node statuses from the cloud provider.
# For spot clusters, the above can be unsafe.
# Therefore we directly query the cloud provider.
if not use_spot:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check the ray status on the spot clusters as well. It is possible that the ray cluster is not running, but the cloud provider shows UP, when the user call ray stop manually. Our job queue depends on the healthiness of the ray cluster, if the ray cluster is not running, we should consider the cluster as INIT.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is to say, we previously assume ray cluster is healthy indicates cloud provider shows UP, but based on your assumption both should be checked for the spot cluster.

Copy link
Member Author

@infwinston infwinston Mar 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a good point. I'm looking at the behavior of the current codebase. It seems to me if ray cluster is not running. we will also determine the cluster status by just querying the cloud provider. and if it's running, we will set the state to UP.

returncode = runner.run('ray status', stream_logs=False)
if returncode:
raise exceptions.FetchIPError(
reason=exceptions.FetchIPError.Reason.HEAD)

Do you mean we should also set INIT for on-demand clusters if ray is not running?

Copy link
Collaborator

@Michaelvll Michaelvll Mar 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we will never set the state to UP after failing to get the IPs. The code will go through the following code path and set the cluster to INIT. We should probably maintain that behavior.

if is_abnormal:
backend = get_backend_from_handle(handle)
if isinstance(backend,
backends.CloudVmRayBackend) and record['autostop'] >= 0:
if not backend.is_definitely_autostopping(handle,
stream_logs=False):
# Reset the autostopping as the cluster is abnormal, and may
# not correctly autostop. Resetting the autostop will let
# the user know that the autostop may not happen to avoid
# leakages from the assumption that the cluster will autostop.
try:
backend.set_autostop(handle, -1, stream_logs=False)
except (Exception, SystemExit) as e: # pylint: disable=broad-except
logger.debug(f'Failed to reset autostop. Due to '
f'{common_utils.format_exception(e)}')
global_user_state.set_cluster_autostop_value(
handle.cluster_name, -1, to_down=False)
else:
ux_utils.console_newline()
operation_str = 'autodowning' if record[
'to_down'] else 'autostopping'
logger.info(
f'Cluster {cluster_name!r} is {operation_str}. Setting to '
'INIT status; try refresh again in a while.')
# If the user starts part of a STOPPED cluster, we still need a status
# to represent the abnormal status. For spot cluster, it can also
# represent that the cluster is partially preempted.
# TODO(zhwu): the definition of INIT should be audited/changed.
# Adding a new status UNHEALTHY for abnormal status can be a choice.
global_user_state.add_or_update_cluster(cluster_name,
handle,
requested_resources=None,
ready=False,
is_launch=False)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah sorry I misread the code. then this makes sense. let me add this case to the change.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I modified the code to also check ray cluster healthiness for Spot clusters. Now only when ray_cluster_up and all clusters are running, we set the state UP. Does this look okay?

Copy link
Collaborator

@Michaelvll Michaelvll left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the fix @infwinston! The main concern is that we should check the healthiness of the ray cluster for the spot cluster as well. Please see the comment below for the detail.

# the true node statuses from the cloud provider.
# For spot clusters, the above can be unsafe.
# Therefore we directly query the cloud provider.
if not use_spot:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is to say, we previously assume ray cluster is healthy indicates cloud provider shows UP, but based on your assumption both should be checked for the spot cluster.

sky/backends/cloud_vm_ray_backend.py Outdated Show resolved Hide resolved
sky/spot/recovery_strategy.py Outdated Show resolved Hide resolved
sky/spot/recovery_strategy.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@Michaelvll Michaelvll left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the fix @infwinston! Left several comments.

Comment on lines 1457 to 1458
ip = (endpoint.get('ipAddress', None) if get_internal_ips else
endpoint['accessConfig'].get('externalIp', None))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: having the expanded if...else... may be easier to read?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

sky/backends/backend_utils.py Outdated Show resolved Hide resolved
sky/backends/backend_utils.py Outdated Show resolved Hide resolved
Comment on lines 1791 to 1793
# If we get node ips correctly, the cluster is UP. It is safe to
# set the status to UP, as the `handle.external_ips` function uses ray
# to fetch IPs and starting ray is the final step of sky launch.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it may not be true for multi-node either, as the ray get-node-ips only use the tag on the cloud without actually checking the ray cluster status. We may want to make L1778 check status for all the nodes instead.

Copy link
Member Author

@infwinston infwinston Mar 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. I remove the if and now it also run ray status for multi-node case. just want to make sure that's what we want right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about something similar as the following code to check the number of workers is the same as the expected number with ray status.

We can refactor this part of code out and use it?

rc, output, stderr = runner.run('ray status',
log_path=log_path,
stream_logs=False,
require_outputs=True,
separate_stderr=True)
subprocess_utils.handle_returncode(
rc, 'ray status', 'Failed to run ray status on head node.',
stderr)
logger.debug(output)
# Workers that are ready
ready_workers = 0
# On-prem/local case is handled differently.
# `ray status` produces different output for local case, and
# we poll for number of nodes launched instead of counting for
# head and number of worker nodes separately (it is impossible
# to distinguish between head and worker node for local case).
if is_local_cloud:
result = _LAUNCHED_LOCAL_WORKER_PATTERN.findall(output)
# In the local case, ready_workers mean the total number
# of nodes launched, including head.
ready_workers = len(result)
else:
result = _LAUNCHED_WORKER_PATTERN.findall(output)
if len(result) == 0:
ready_workers = 0
else:
assert len(result) == 1, result
ready_workers = int(result[0])
result = _LAUNCHED_HEAD_PATTERN.findall(output)
ready_head = 0
if result:
assert len(result) == 1, result
ready_head = int(result[0])
assert ready_head <= 1, ready_head
worker_status.update('[bold cyan]'
f'{ready_workers} out of {num_nodes - 1} '
'workers ready')
# In the local case, ready_head=0 and ready_workers=num_nodes. This
# is because there is no matching regex for _LAUNCHED_HEAD_PATTERN.
if ready_head + ready_workers == num_nodes:
# All nodes are up.
break

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh got it. Parsing the output from ray status is needed. I just updated the code. PTAL.

Comment on lines 1791 to 1793
# If we get node ips correctly, the cluster is UP. It is safe to
# set the status to UP, as the `handle.external_ips` function uses ray
# to fetch IPs and starting ray is the final step of sky launch.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about something similar as the following code to check the number of workers is the same as the expected number with ray status.

We can refactor this part of code out and use it?

rc, output, stderr = runner.run('ray status',
log_path=log_path,
stream_logs=False,
require_outputs=True,
separate_stderr=True)
subprocess_utils.handle_returncode(
rc, 'ray status', 'Failed to run ray status on head node.',
stderr)
logger.debug(output)
# Workers that are ready
ready_workers = 0
# On-prem/local case is handled differently.
# `ray status` produces different output for local case, and
# we poll for number of nodes launched instead of counting for
# head and number of worker nodes separately (it is impossible
# to distinguish between head and worker node for local case).
if is_local_cloud:
result = _LAUNCHED_LOCAL_WORKER_PATTERN.findall(output)
# In the local case, ready_workers mean the total number
# of nodes launched, including head.
ready_workers = len(result)
else:
result = _LAUNCHED_WORKER_PATTERN.findall(output)
if len(result) == 0:
ready_workers = 0
else:
assert len(result) == 1, result
ready_workers = int(result[0])
result = _LAUNCHED_HEAD_PATTERN.findall(output)
ready_head = 0
if result:
assert len(result) == 1, result
ready_head = int(result[0])
assert ready_head <= 1, ready_head
worker_status.update('[bold cyan]'
f'{ready_workers} out of {num_nodes - 1} '
'workers ready')
# In the local case, ready_head=0 and ready_workers=num_nodes. This
# is because there is no matching regex for _LAUNCHED_HEAD_PATTERN.
if ready_head + ready_workers == num_nodes:
# All nodes are up.
break

Comment on lines 1793 to 1796
# For spot clusters, the above can be unsafe because the Ray cluster
# may remain healty for a while before the cloud completely
# terminates the VMs.
# Additionally, we query the VM state from the cloud provider.
Copy link
Collaborator

@Michaelvll Michaelvll Mar 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the fix for checking the ray status for multiple nodes already enough? The previous problem might be because the worker VM is preempted, but the IP can still be got with ray get-node-ips?

Did the problem happen for a user with a single node or multiple nodes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem happened to a user with a single node (tpu-v2-8) I believe

Comment on lines 1780 to 1804
# Check if ray cluster status is healthy.
ssh_credentials = ssh_credential_from_yaml(handle.cluster_yaml)
runner = command_runner.SSHCommandRunner(external_ips[0],
**ssh_credentials)
rc, output, _ = runner.run('ray status',
stream_logs=False,
require_outputs=True,
separate_stderr=True)
if rc:
raise exceptions.FetchIPError(
reason=exceptions.FetchIPError.Reason.HEAD)

def get_ready_nodes(pattern, output):
result = pattern.findall(output)
if len(result) == 0:
return 0
assert len(result) == 1, result
return int(result[0])

ready_workers = get_ready_nodes(_LAUNCHED_WORKER_PATTERN, output)
ready_head = get_ready_nodes(_LAUNCHED_HEAD_PATTERN, output)
assert ready_head <= 1, f'#head node should be <=1 (Got {ready_head}).'

if ready_head + ready_workers == handle.launched_nodes:
ray_cluster_up = True
Copy link
Collaborator

@Michaelvll Michaelvll Mar 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of directly copying the code here. can we refactor it out as a function so it can be reused in both places? For example,

def count_healthy_nodes_with_ray(runner, is_local: bool) -> Tuple[int, int]:
    rc, output, _ = runner.run('ray status',
                               stream_logs=False,
                               require_outputs=True,
                               separate_stderr=True)
    if rc:
        raise exceptions.FetchIPError(
            reason=exceptions.FetchIPError.Reason.HEAD)

    def get_ready_nodes(pattern, output):
        result = pattern.findall(output)
        if len(result) == 0:
            return 0
        assert len(result) == 1, result
        return int(result[0])

    ready_workers = get_ready_nodes(_LAUNCHED_WORKER_PATTERN, output)
    ready_head = get_ready_nodes(_LAUNCHED_HEAD_PATTERN, output)
    return ready_head, ready_workers

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually refactor a bit as they differ in some places but yeah I can try taking it out as a function.

Copy link
Member Author

@infwinston infwinston Mar 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just refactored a bit. I tried to incorporate runner into the function but looks like the output from ray status will be used here.

if '(no pending nodes)' in output and '(no failures)' in output:

so I ended up not taking runner but the output as the argument. PTAL, thanks!

Copy link
Collaborator

@Michaelvll Michaelvll left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the quick fix and refactoring @infwinston! The changes look good to me now. Thanks for the great effort.

Comment on lines 1009 to 1033
def get_ready_nodes(pattern, output, local=False):
result = pattern.findall(output)
# On-prem/local case is handled differently.
# `ray status` produces different output for local case, and
# we poll for number of nodes launched instead of counting for
# head and number of worker nodes separately (it is impossible
# to distinguish between head and worker node for local case).
if local:
# In the local case, ready_workers mean the total number
# of nodes launched, including head.
return len(result)
if len(result) == 0:
return 0
assert len(result) == 1, result
return int(result[0])

if is_local_cloud:
ready_workers = get_ready_nodes(_LAUNCHED_LOCAL_WORKER_PATTERN,
output,
local=True)
else:
ready_workers = get_ready_nodes(_LAUNCHED_WORKER_PATTERN,
output,
local=False)
ready_head = get_ready_nodes(_LAUNCHED_HEAD_PATTERN, output)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

Suggested change
def get_ready_nodes(pattern, output, local=False):
result = pattern.findall(output)
# On-prem/local case is handled differently.
# `ray status` produces different output for local case, and
# we poll for number of nodes launched instead of counting for
# head and number of worker nodes separately (it is impossible
# to distinguish between head and worker node for local case).
if local:
# In the local case, ready_workers mean the total number
# of nodes launched, including head.
return len(result)
if len(result) == 0:
return 0
assert len(result) == 1, result
return int(result[0])
if is_local_cloud:
ready_workers = get_ready_nodes(_LAUNCHED_LOCAL_WORKER_PATTERN,
output,
local=True)
else:
ready_workers = get_ready_nodes(_LAUNCHED_WORKER_PATTERN,
output,
local=False)
ready_head = get_ready_nodes(_LAUNCHED_HEAD_PATTERN, output)
def get_ready_nodes(pattern, output):
result = pattern.findall(output)
# On-prem/local case is handled differently.
# `ray status` produces different output for local case, and
# we poll for number of nodes launched instead of counting for
# head and number of worker nodes separately (it is impossible
# to distinguish between head and worker node for local case).
if is_local_cloud:
# In the local case, ready_workers mean the total number
# of nodes launched, including head.
return len(result)
if len(result) == 0:
return 0
assert len(result) == 1, result
return int(result[0])
if is_local_cloud:
ready_head = 0
ready_workers = get_ready_nodes(_LAUNCHED_LOCAL_WORKER_PATTERN, output)
else:
ready_head = get_ready_nodes(_LAUNCHED_HEAD_PATTERN, output)
ready_workers = get_ready_nodes(_LAUNCHED_WORKER_PATTERN, output)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added!

@@ -1001,6 +1001,40 @@ def get_timestamp_from_run_timestamp(run_timestamp: str) -> float:
run_timestamp.partition('-')[2], '%Y-%m-%d-%H-%M-%S-%f').timestamp()


def count_healthy_nodes_from_ray(output: str,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: make this function private?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

sky/spot/recovery_strategy.py Outdated Show resolved Hide resolved
@infwinston
Copy link
Member Author

infwinston commented Mar 16, 2023

I just re-run the smoke tests and it passed. merging now. Thanks a lot for the reviews!

@infwinston infwinston merged commit 322ffad into master Mar 16, 2023
@infwinston infwinston deleted the fix-spot-tpu branch March 16, 2023 07:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants