Skip to content

Commit

Permalink
[core] handle unserializable user exception (ray-project#44878)
Browse files Browse the repository at this point in the history
Signed-off-by: hongchaodeng <[email protected]>
Signed-off-by: Hongchao Deng <[email protected]>
Co-authored-by: angelinalg <[email protected]>
  • Loading branch information
hongchaodeng and angelinalg authored May 2, 2024
1 parent 4671784 commit dcbf195
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 5 deletions.
26 changes: 26 additions & 0 deletions doc/source/ray-core/doc_code/task_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,32 @@ def g(x):
# Exception: the real error

# __task_exceptions_end__
# __unserializable_exceptions_begin__

import threading

class UnserializableException(Exception):
def __init__(self):
self.lock = threading.Lock()

@ray.remote
def raise_unserializable_error():
raise UnserializableException

try:
ray.get(raise_unserializable_error.remote())
except ray.exceptions.RayTaskError as e:
print(e)
# ray::raise_unserializable_error() (pid=328577, ip=172.31.5.154)
# File "/home/ubuntu/ray/tmp~/main.py", line 25, in raise_unserializable_error
# raise UnserializableException
# UnserializableException
print(type(e.cause))
# <class 'ray.exceptions.RayError'>
print(e.cause)
# The original cause of the RayTaskError (<class '__main__.UnserializableException'>) isn't serializable: cannot pickle '_thread.lock' object. Overwriting the cause to a RayError.

# __unserializable_exceptions_end__
# __catch_user_exceptions_begin__

class MyException(Exception):
Expand Down
7 changes: 7 additions & 0 deletions doc/source/ray-core/fault_tolerance/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ Example code of accessing the user exception type when the exception type can *n
:start-after: __catch_user_final_exceptions_begin__
:end-before: __catch_user_final_exceptions_end__

If Ray can't serialize the user's exception, it converts the exception to a ``RayError``.

.. literalinclude:: ../doc_code/task_exceptions.py
:language: python
:start-after: __unserializable_exceptions_begin__
:end-before: __unserializable_exceptions_end__

Use `ray list tasks` from :ref:`State API CLI <state-api-overview-ref>` to query task exit details:

.. code-block:: bash
Expand Down
22 changes: 17 additions & 5 deletions python/ray/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,6 @@ def __init__(
"""Initialize a RayTaskError."""
import ray

# BaseException implements a __reduce__ method that returns
# a tuple with the type and the value of self.args.
# https://stackoverflow.com/a/49715949/2213289
self.args = (function_name, traceback_str, cause, proctitle, pid, ip)
if proctitle:
self.proctitle = proctitle
else:
Expand All @@ -130,8 +126,24 @@ def __init__(
self.traceback_str = traceback_str
self.actor_repr = actor_repr
self._actor_id = actor_id
# TODO(edoakes): should we handle non-serializable exception objects?
self.cause = cause

try:
pickle.dumps(cause)
except (pickle.PicklingError, TypeError) as e:
err_msg = (
"The original cause of the RayTaskError"
f" ({self.cause.__class__}) isn't serializable: {e}."
" Overwriting the cause to a RayError."
)
logger.warning(err_msg)
self.cause = RayError(err_msg)

# BaseException implements a __reduce__ method that returns
# a tuple with the type and the value of self.args.
# https://stackoverflow.com/a/49715949/2213289
self.args = (function_name, traceback_str, self.cause, proctitle, pid, ip)

assert traceback_str is not None

def make_dual_exception_type(self) -> Type:
Expand Down
20 changes: 20 additions & 0 deletions python/ray/tests/test_failure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import time
import logging
import threading

import numpy as np
import pytest
Expand Down Expand Up @@ -600,6 +601,25 @@ def check_actor_restart():
ray.get(obj2)


# Previously when threading.Lock is in the exception, it causes
# the serialization to fail. This test case is to cover that scenario.
def test_unserializable_exception(ray_start_regular, propagate_logs):
class UnserializableException(Exception):
def __init__(self):
self.lock = threading.Lock()

@ray.remote
def func():
raise UnserializableException

with pytest.raises(ray.exceptions.RayTaskError) as exc_info:
ray.get(func.remote())

assert isinstance(exc_info.value, ray.exceptions.RayTaskError)
assert isinstance(exc_info.value.cause, ray.exceptions.RayError)
assert "isn't serializable" in str(exc_info.value.cause)


def test_final_user_exception(ray_start_regular, propagate_logs, caplog):
class MyFinalException(Exception):
def __init_subclass__(cls, /, *args, **kwargs):
Expand Down

0 comments on commit dcbf195

Please sign in to comment.