[Data] Cooperatively exit producer threads for iter_batches (#34819)
This fixes a bug in `iter_batches` where producer daemon threads keep hanging around and holding batches in memory when the caller breaks early during iteration. Example of such a caller:

```py
for batch in ds.iter_batches():
  if ... :
    break
```

Change from using Python's `queue.Queue` to a combination of `Semaphore`, `Lock`, and a plain `deque`, to allow producer threads to exit cooperatively. I couldn't find a way to achieve the same behavior with any of Python's thread-safe `Queue` classes, so I rolled my own producer-consumer queue here.
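
For reference, here is a minimal sketch (not part of the commit) of how the new queue's cooperative-exit semantics are intended to be used, assuming the `Queue` class added below is importable from the internal module `ray.data._internal.block_batching.util`; the `producer` function here is hypothetical:

```py
import threading

from ray.data._internal.block_batching.util import Queue

queue = Queue(2)  # Buffer at most 2 items at a time.

def producer():
    for item in range(100):
        # put() blocks while the queue is full and returns True once
        # release() has been called, telling this thread to exit.
        if queue.put(item):
            return

thread = threading.Thread(target=producer, daemon=True)
thread.start()

# Consume only a few items, then stop early.
for _ in range(3):
    print(queue.get())

# Cooperatively unblock the producer so it can exit instead of
# blocking forever on a full queue.
queue.release(1)
thread.join()
```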

Also verified with the user that this PR fixes the GRAM OOM issue, by rerunning the workload with this PR.
c21 committed Apr 28, 2023
1 parent 1274f86 commit c4b67f4
Showing 2 changed files with 174 additions and 20 deletions.
112 changes: 92 additions & 20 deletions python/ray/data/_internal/block_batching/util.py
@@ -1,7 +1,7 @@
import logging
import queue
import threading
from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
from collections import deque
from contextlib import nullcontext

import ray
@@ -230,19 +230,22 @@ class Sentinel:
def __init__(self, thread_index: int):
self.thread_index = thread_index

output_queue = queue.Queue(1)
output_queue = Queue(1)

# Because pulling from the base iterator cannot happen concurrently,
# we must execute the expensive computation in a separate step which
# can be parallelized via a threadpool.
def execute_computation(thread_index: int):
try:
for item in fn(thread_safe_generator):
output_queue.put(item, block=True)
output_queue.put(Sentinel(thread_index), block=True)
if output_queue.put(item):
# Return early when it's instructed to do so.
return
output_queue.put(Sentinel(thread_index))
except Exception as e:
output_queue.put(e, block=True)
output_queue.put(e)

# Use separate threads to produce output batches.
threads = [
threading.Thread(target=execute_computation, args=(i,), daemon=True)
for i in range(num_workers)
@@ -251,22 +254,28 @@ def execute_computation(thread_index: int):
for thread in threads:
thread.start()

# Use main thread to consume output batches.
num_threads_finished = 0
while True:
next_item = output_queue.get(block=True)
if isinstance(next_item, Exception):
output_queue.task_done()
raise next_item
if isinstance(next_item, Sentinel):
output_queue.task_done()
logger.debug(f"Thread {next_item.thread_index} finished.")
num_threads_finished += 1
threads[next_item.thread_index].join()
else:
yield next_item
output_queue.task_done()
if num_threads_finished >= num_workers:
break
try:
while True:
next_item = output_queue.get()
if isinstance(next_item, Exception):
raise next_item
if isinstance(next_item, Sentinel):
logger.debug(f"Thread {next_item.thread_index} finished.")
num_threads_finished += 1
else:
yield next_item
if num_threads_finished >= num_workers:
break
finally:
# Cooperatively exit all producer threads.
# This avoids daemon threads hanging around and holding batches in memory,
# which can easily cause GRAM OOM. This can happen when the caller breaks
# out in the middle of iteration.
num_threads_alive = num_workers - num_threads_finished
if num_threads_alive > 0:
output_queue.release(num_threads_alive)


PREFETCHER_ACTOR_NAMESPACE = "ray.datastream"
@@ -309,3 +318,66 @@ class _BlockPretcher:

def prefetch(self, *blocks) -> None:
pass


class Queue:
"""A thread-safe queue implementation for multiple producers and consumers.
Provides `release()` so producer threads can exit cooperatively and release resources.
"""

def __init__(self, queue_size: int):
# The queue shared across multiple producer threads.
self._queue = deque()
# The boolean variable indicating whether producer threads should exit.
self._threads_exit = False
# The semaphore for producer threads to put items into the queue.
self._producer_semaphore = threading.Semaphore(queue_size)
# The semaphore for consumer threads to get items from the queue.
self._consumer_semaphore = threading.Semaphore(0)
# The mutex lock guarding access to `self._queue` and `self._threads_exit`.
self._mutex = threading.Lock()

def put(self, item: Any) -> bool:
"""Put an item into the queue.
Block if necessary until a free slot is available in the queue.
This method is called by producer threads.
Returns:
True if the caller thread should exit immediately.
"""
self._producer_semaphore.acquire()
with self._mutex:
if self._threads_exit:
return True
else:
self._queue.append(item)
self._consumer_semaphore.release()
return False

def get(self) -> Any:
"""Remove and return an item from the queue.
Block if necessary until an item is available in the queue.
This method is called by consumer threads.
"""
self._consumer_semaphore.acquire()
with self._mutex:
next_item = self._queue.popleft()
self._producer_semaphore.release()
return next_item

def release(self, num_threads: int):
"""Release `num_threads` of producers so they would exit cooperatively."""
with self._mutex:
self._threads_exit = True
for _ in range(num_threads):
# NOTE: In Python 3.9+, Semaphore.release(n) can be used to
# release all threads at once.
self._producer_semaphore.release()

def qsize(self):
"""Return the size of the queue."""
with self._mutex:
return len(self._queue)
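
With the `release()` call wired into the `finally` block above, a caller of `make_async_gen` can now break out of iteration early without leaving producer threads blocked on the output queue. A rough usage sketch (the transform `fn` and the numbers are illustrative, not from the commit):

```py
from ray.data._internal.block_batching.util import make_async_gen

def fn(base_iterator):
    # Expensive per-item work that runs in the producer threads.
    for item in base_iterator:
        yield item * 2

iterator = make_async_gen(base_iterator=iter(range(1000)), fn=fn, num_workers=4)

for i, item in enumerate(iterator):
    if i >= 10:
        # Breaking early previously left daemon producer threads blocked on
        # queue.put(); the finally block above now releases them when the
        # generator is closed or garbage-collected.
        break
```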
82 changes: 82 additions & 0 deletions python/ray/data/tests/block_batching/test_util.py
@@ -1,3 +1,4 @@
import threading
import pytest
import time

@@ -7,6 +8,7 @@

import ray
from ray.data._internal.block_batching.util import (
Queue,
_calculate_ref_hits,
make_async_gen,
blocks_to_batches,
@@ -173,6 +175,86 @@ def sleep_udf(item):
assert end_time - start_time < 9.5


def test_make_async_gen_multiple_threads_unfinished():
"""Tests that using multiple threads can overlap compute even more.
Iteration is not finished; the consumer breaks out in the middle.
"""

num_items = 5

def gen(base_iterator):
for i in base_iterator:
time.sleep(4)
yield i

def sleep_udf(item):
time.sleep(5)
return item

# All 5 items should be fetched concurrently.
iterator = make_async_gen(
base_iterator=iter(range(num_items)), fn=gen, num_workers=5
)

start_time = time.time()

# Only sleep for first item.
sleep_udf(next(iterator))

# All subsequent items should already be prefetched and should be ready.
for i, _ in enumerate(iterator):
if i > 2:
break
end_time = time.time()

# 4 seconds for first item, 5 seconds for udf, 0.5 seconds buffer
assert end_time - start_time < 9.5


def test_queue():
queue = Queue(5)
num_producers = 10
num_producers_finished = 0
num_items = 20

def execute_computation():
for item in range(num_items):
if queue.put(item):
# Return early when it's instructed to do so.
break
# Put -1 as an indicator that the thread has finished.
queue.put(-1)

# Use separate threads as producers.
threads = [
threading.Thread(target=execute_computation, daemon=True)
for _ in range(num_producers)
]

for thread in threads:
thread.start()

for i in range(num_producers * num_items):
item = queue.get()
if item == -1:
num_producers_finished += 1
if i > num_producers * num_items / 2:
num_producers_alive = num_producers - num_producers_finished
# Check there are some alive producers.
assert num_producers_alive > 0, num_producers_alive
# Release the alive producers.
queue.release(num_producers_alive)
# Consume the remaining items in queue.
while queue.qsize() > 0:
queue.get()
break

# Sleep 5 seconds to allow producer threads to exit.
time.sleep(5)
# Then check the queue is still empty.
assert queue.qsize() == 0


def test_calculate_ref_hits(ray_start_regular_shared):
refs = [ray.put(0), ray.put(1)]
hits, misses, unknowns = _calculate_ref_hits(refs)
