
[Data] Support async callable classes in map_batches() #46129

Merged (9 commits) on Jun 25, 2024

Conversation

@scottjlee (Contributor) commented Jun 18, 2024

Why are these changes needed?

Add support for passing a CallableClass with an asynchronous generator __call__ method to the Dataset.map_batches() API. This is useful for streaming outputs from asynchronous generators as they become available, to maximize throughput.
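In outline, the new capability looks like this. The sketch below uses plain asyncio so it is self-contained; with Ray Data, the class would instead be passed to ds.map_batches(AsyncUDF, ...). AsyncUDF, sleep_and_yield, and drive are illustrative names, not the PR's:

```python
import asyncio

# Sketch of the UDF shape this PR enables: a callable class whose
# __call__ is an *async generator*, yielding outputs as they become
# ready instead of waiting for the whole batch to finish.
class AsyncUDF:
    async def __call__(self, batch):
        async def sleep_and_yield(i):
            await asyncio.sleep(0.01 * (3 - i))  # later ids finish first
            return {"id": [i]}

        tasks = [asyncio.create_task(sleep_and_yield(i)) for i in batch["id"]]
        for task in tasks:
            # Awaiting in task order preserves input order while all
            # tasks still run concurrently.
            yield await task

async def drive():
    out = []
    async for out_batch in AsyncUDF()({"id": [0, 1, 2]}):
        out.append(out_batch["id"][0])
    return out

print(asyncio.run(drive()))  # [0, 1, 2]
```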

Related issue number

Closes #46235

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: Scott Lee <[email protected]>
@raulchen (Contributor) left a comment:

Just realized that we also need to set max_concurrency to allow handling multiple batches at a time.
Some additional changes needed for this PR:

  • The event loop and the thread executor should outlive the transform_fn; they should be global singletons.
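A minimal sketch of the singleton pattern being suggested: the loop and its thread are created once at module import, so they outlive any single transform_fn call. Names here are illustrative, not the PR's:

```python
import asyncio
import threading

# Module-level singletons: created once, shared by all transform_fn calls.
_singleton_loop = asyncio.new_event_loop()
_singleton_thread = threading.Thread(target=_singleton_loop.run_forever, daemon=True)
_singleton_thread.start()

def run_on_loop(coro):
    """Submit a coroutine to the long-lived loop from any thread."""
    return asyncio.run_coroutine_threadsafe(coro, _singleton_loop).result()

async def transform(x):
    await asyncio.sleep(0)  # stand-in for real async work
    return x + 1

print(run_on_loop(transform(1)))  # 2
```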

tasks = [sleep_and_yield(i) for i in batch["id"]]
results = await asyncio.gather(*tasks)
for result in results:
    yield result
Contributor:

Avoid using gather, so we can yield results as soon as they become available.

tasks = [asyncio.create_task(sleep_and_yield(i)) for i in batch["id"]]
for task in tasks:
    yield await task


async def process_all_batches():
    tasks = [asyncio.create_task(process_batch(x)) for x in input_iterable]
    for task in asyncio.as_completed(tasks):
Contributor:

as_completed doesn't preserve the order, so we shouldn't use it when preserve_order is set.
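A small stand-alone demonstration of the ordering difference; work and its sleep schedule are made up to make the ordering visible:

```python
import asyncio

async def work(i):
    await asyncio.sleep(0.05 - 0.01 * i)  # later inputs finish sooner
    return i

async def completion_order():
    tasks = [asyncio.create_task(work(i)) for i in range(3)]
    # as_completed yields results as tasks finish -- completion order.
    return [await t for t in asyncio.as_completed(tasks)]

async def input_order():
    tasks = [asyncio.create_task(work(i)) for i in range(3)]
    # Awaiting in creation order preserves input order; tasks still
    # run concurrently in the background.
    return [await t for t in tasks]

print(asyncio.run(completion_order()))  # typically [2, 1, 0]
print(asyncio.run(input_order()))       # [0, 1, 2]
```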

pass

async def __call__(self, batch):
    tasks = [sleep_and_yield(i) for i in batch["id"]]
Contributor:

Remove this line.

# Use the existing event loop to create and run
# Tasks to process each batch
loop = ray.data._cached_loop
loop.run_until_complete(process_all_batches())
Contributor:

This is still running in the main thread, so it cannot run multiple batches at the same time.
We should put loop.run_forever() in a thread, and then call loop.call_soon_threadsafe here.

for task in asyncio.as_completed(tasks):
    await task
# Sentinel to indicate completion.
output_batch_queue.put(None)
Contributor:

Use a special sentinel object, in case the UDF also returns None.
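The problem and the fix in miniature; the queue contents and names are illustrative:

```python
import queue

# None is a legal UDF output, so use a private, unique object as the
# end-of-stream marker instead of None.
class _Sentinel:
    pass

SENTINEL = _Sentinel()

q = queue.Queue()
for item in [1, None, 2]:  # this None is real UDF output, not "done"
    q.put(item)
q.put(SENTINEL)

out = []
while True:
    item = q.get()
    if item is SENTINEL:   # identity check cannot collide with None
        break
    out.append(item)

print(out)  # [1, None, 2]
```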

res = [batch]
if inspect.iscoroutinefunction(fn):
    # UDF is a callable class with async generator `__call__` method.
    def transform_fn(
Contributor:

This inline function is too long; let's define it as a util function.

@@ -104,22 +107,45 @@ def _parse_op_fn(op: AbstractUDFMap):
fn_constructor_args = op._fn_constructor_args or ()
fn_constructor_kwargs = op._fn_constructor_kwargs or {}

op_fn = make_callable_class_concurrent(op_fn)
if inspect.isasyncgenfunction(op._fn.__call__):
Contributor:

Add a TODO that we should 1) support non-generator async functions; 2) make the entire map actor async.


n = 5
ds = ray.data.range(n, override_num_blocks=1)
ds = ds.map_batches(AsyncActor, batch_size=None, concurrency=1)
Contributor:

Add max_concurrency to test concurrently handling many batches.

async def fn(item: Any) -> Any:
    assert ray.data._cached_fn is not None
    assert ray.data._cached_cls == op_fn
    assert ray.data._cached_loop is not None
Contributor:

To make the code cleaner, let's define a class (e.g. MapActorContext) to capture all these variables.

@scottjlee scottjlee assigned scottjlee and unassigned raulchen and c21 Jun 24, 2024
@scottjlee scottjlee added the data Ray Data-related issues label Jun 24, 2024
@scottjlee scottjlee assigned raulchen and unassigned scottjlee Jun 24, 2024
@raulchen (Contributor) left a comment:

LGTM except for some nits

    loop.run_forever()

thread = Thread(target=run_loop)
thread.start()
Contributor:

Nit: maybe move the above into MapActorContext to make the code cleaner here.
We can pass in a boolean flag here, _MapActorContext(..., is_asyncio=True).
Then the code for the sync and async branches can be consolidated.

cached_cls: UserDefinedFunction,
cached_fn: Callable[[Any], Any],
cached_loop: Optional[asyncio.AbstractEventLoop] = None,
cached_asyncio_thread: Optional[Thread] = None,
Contributor:

Remove the cached_ prefixes?

for task in asyncio.as_completed(tasks):
    await task
# Sentinel to indicate completion.
output_batch_queue.put(OutputQueueSentinel())
Contributor:

OutputQueueSentinel isn't needed now, because of the while not future.done() check.

pass

async def __call__(self, batch):
    tasks = [sleep_and_yield(i) for i in batch["id"]]
Contributor:

This line is redundant.

Contributor (author):

Thanks for the catch; I thought I had removed it, but somehow it ended up back again...

@scottjlee scottjlee requested a review from raulchen June 25, 2024 21:35

else:

    def fn(item: Any) -> Any:
Contributor:

These two fns are almost identical, except for the assertion? Better to avoid duplicating the code.

Contributor (author):

Ah, that's what I initially thought, but one of them is async and the other is not. Is there a way to combine the inner implementation but create an async version?

Contributor:

Oh, never mind, I didn't notice the async prefix.

@raulchen raulchen enabled auto-merge (squash) June 25, 2024 22:00
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Jun 25, 2024
@raulchen raulchen merged commit f75ad5d into ray-project:master Jun 25, 2024
7 of 8 checks passed
Labels: data (Ray Data-related issues), go (add ONLY when ready to merge, run all tests)

Successfully merging this pull request may close these issues.

[Data] Support async callable classes in map_batches()
3 participants