Skip to content

Commit

Permalink
[Data] Require batch_size for GPU map_batches in strict mode (ray…
Browse files Browse the repository at this point in the history
…-project#34588)

In strict mode, require a batch_size for map_batches when requesting GPUs. This makes batching more explicit to users.

Reduce the CPU default batch size in strict mode from 4096 to 1024.

---------

Signed-off-by: amogkam <[email protected]>
  • Loading branch information
amogkam committed Apr 25, 2023
1 parent 4bb6013 commit b14795d
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 19 deletions.
29 changes: 27 additions & 2 deletions python/ray/data/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
resource = None

if sys.version_info >= (3, 8):
from typing import Protocol
from typing import Literal, Protocol
else:
from typing_extensions import Protocol
from typing_extensions import Literal, Protocol

if TYPE_CHECKING:
import pandas
Expand Down Expand Up @@ -173,6 +173,31 @@ def _apply_strict_mode_batch_format(given_batch_format: Optional[str]) -> str:
return given_batch_format


def _apply_strict_mode_batch_size(
    given_batch_size: Optional[Union[int, Literal["default"]]], use_gpu: bool
) -> Optional[int]:
    """Resolve the effective ``batch_size`` for ``map_batches``.

    Args:
        given_batch_size: The value supplied by the caller: an integer,
            ``None``, or the string ``"default"``.
        use_gpu: Whether the caller requested GPUs for the transformation.

    Returns:
        ``given_batch_size`` unchanged unless it is ``"default"``, in which
        case the strict-mode or regular default batch size is returned
        depending on whether strict mode is enabled on the current context.

    Raises:
        StrictModeError: If strict mode is enabled, GPUs are requested, and
            ``given_batch_size`` is falsy or ``"default"``.
    """
    strict = ray.data.DatasetContext.get_current().strict_mode
    batch_size_missing = not given_batch_size or given_batch_size == "default"

    # Strict mode refuses to guess a batch size for GPU workloads.
    if strict and use_gpu and batch_size_missing:
        raise StrictModeError(
            "`batch_size` must be provided to `map_batches` when requesting GPUs. "
            "The optimal batch size depends on the model, data, and GPU used. "
            "It is recommended to use the largest batch size that doesn't result "
            "in your GPU device running out of memory. You can view the GPU memory "
            "usage via the Ray dashboard."
        )

    # Anything other than the "default" placeholder passes through untouched.
    if given_batch_size != "default":
        return given_batch_size

    if strict:
        return ray.data.context.STRICT_MODE_DEFAULT_BATCH_SIZE
    return ray.data.context.DEFAULT_BATCH_SIZE


@DeveloperAPI
class BlockExecStats:
"""Execution stats for this block.
Expand Down
3 changes: 3 additions & 0 deletions python/ray/data/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@
# Default batch size for batch transformations.
DEFAULT_BATCH_SIZE = 4096

# Default batch size for batch transformations in strict mode.
STRICT_MODE_DEFAULT_BATCH_SIZE = 1024

# Whether to enable progress bars.
DEFAULT_ENABLE_PROGRESS_BARS = not bool(
env_integer("RAY_DATA_DISABLE_PROGRESS_BARS", 0)
Expand Down
10 changes: 6 additions & 4 deletions python/ray/data/datastream.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
from ray.data.block import (
VALID_BATCH_FORMATS,
_apply_strict_mode_batch_format,
_apply_strict_mode_batch_size,
BatchUDF,
Block,
BlockAccessor,
Expand All @@ -110,7 +111,6 @@
WARN_PREFIX,
OK_PREFIX,
ESTIMATED_SAFE_MEMORY_FRACTION,
DEFAULT_BATCH_SIZE,
)
from ray.data.datasource import (
BlockWritePathProvider,
Expand Down Expand Up @@ -597,14 +597,16 @@ def map_batches(
logger.warning("The 'native' batch format has been renamed 'default'.")

target_block_size = None
if batch_size == "default":
batch_size = DEFAULT_BATCH_SIZE
elif batch_size is not None:
if batch_size is not None and batch_size != "default":
if batch_size < 1:
raise ValueError("Batch size cannot be negative or 0")
# Enable blocks bundling when batch_size is specified by caller.
target_block_size = batch_size

batch_size = _apply_strict_mode_batch_size(
batch_size, use_gpu="num_gpus" in ray_remote_args
)

if batch_format not in VALID_BATCH_FORMATS:
raise ValueError(
f"The batch format must be one of {VALID_BATCH_FORMATS}, got: "
Expand Down
8 changes: 8 additions & 0 deletions python/ray/data/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def ray_start_10_cpus_shared(request):
yield res


@pytest.fixture(scope="module")
def enable_strict_mode():
    """Enable strict mode on the current data context for the test module.

    The value of ``strict_mode`` that was set before the fixture ran is
    put back on teardown, instead of unconditionally forcing ``False``.
    """
    ctx = ray.data.DataContext.get_current()
    previous_strict_mode = ctx.strict_mode
    ctx.strict_mode = True
    try:
        yield
    finally:
        # Restore the prior setting so that a run which started with strict
        # mode already enabled is not silently switched off afterwards.
        ctx.strict_mode = previous_strict_mode


@pytest.fixture(scope="function")
def aws_credentials():
import os
Expand Down
30 changes: 17 additions & 13 deletions python/ray/data/tests/test_strict_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,8 @@
from ray.data.tests.conftest import * # noqa
from ray.tests.conftest import * # noqa

# Force strict mode.
ctx = ray.data.DataContext.get_current()
ctx.strict_mode = True


def test_strict_read_schemas(ray_start_regular_shared):
def test_strict_read_schemas(ray_start_regular_shared, enable_strict_mode):
ds = ray.data.range(1)
assert ds.take()[0] == {"id": 0}

Expand Down Expand Up @@ -47,7 +43,7 @@ def test_strict_read_schemas(ray_start_regular_shared):
assert "text" in ds.take()[0]


def test_strict_map_output(ray_start_regular_shared):
def test_strict_map_output(ray_start_regular_shared, enable_strict_mode):
ds = ray.data.range(1)

with pytest.raises(StrictModeError):
Expand Down Expand Up @@ -84,7 +80,7 @@ def test_strict_map_output(ray_start_regular_shared):
ds.map(lambda x: UserDict({"x": object()})).materialize()


def test_strict_default_batch_format(ray_start_regular_shared):
def test_strict_default_batch_format(ray_start_regular_shared, enable_strict_mode):
ds = ray.data.range(1)

@ray.remote
Expand All @@ -111,7 +107,7 @@ def f(x):
assert isinstance(batch["id"], np.ndarray), batch


def test_strict_tensor_support(ray_start_regular_shared):
def test_strict_tensor_support(ray_start_regular_shared, enable_strict_mode):
ds = ray.data.from_items([np.ones(10), np.ones(10)])
assert np.array_equal(ds.take()[0]["item"], np.ones(10))

Expand All @@ -122,7 +118,7 @@ def test_strict_tensor_support(ray_start_regular_shared):
assert np.array_equal(ds.take()[0]["item"], 4 * np.ones(10))


def test_strict_value_repr(ray_start_regular_shared):
def test_strict_value_repr(ray_start_regular_shared, enable_strict_mode):
ds = ray.data.from_items([{"__value__": np.ones(10)}])

ds = ds.map_batches(lambda x: {"__value__": x["__value__"] * 2})
Expand All @@ -131,12 +127,12 @@ def test_strict_value_repr(ray_start_regular_shared):
assert np.array_equal(ds.take_batch()["x"][0], 4 * np.ones(10))


def test_strict_object_support(ray_start_regular_shared):
def test_strict_object_support(ray_start_regular_shared, enable_strict_mode):
ds = ray.data.from_items([{"x": 2}, {"x": object()}])
ds.map_batches(lambda x: x, batch_format="numpy").materialize()


def test_strict_compute(ray_start_regular_shared):
def test_strict_compute(ray_start_regular_shared, enable_strict_mode):
with pytest.raises(StrictModeError):
ray.data.range(10).map(lambda x: x, compute="actors").show()
with pytest.raises(StrictModeError):
Expand All @@ -147,7 +143,7 @@ def test_strict_compute(ray_start_regular_shared):
ray.data.range(10).map(lambda x: x, compute="tasks").show()


def test_strict_schema(ray_start_regular_shared):
def test_strict_schema(ray_start_regular_shared, enable_strict_mode):
import pyarrow
from ray.data._internal.pandas_block import PandasBlockSchema

Expand Down Expand Up @@ -182,7 +178,7 @@ def test_strict_schema(ray_start_regular_shared):
assert isinstance(schema.base_schema, PandasBlockSchema)


def test_use_raw_dicts(ray_start_regular_shared):
def test_use_raw_dicts(ray_start_regular_shared, enable_strict_mode):
assert type(ray.data.range(10).take(1)[0]) is dict
assert type(ray.data.from_items([1]).take(1)[0]) is dict

Expand All @@ -193,6 +189,14 @@ def checker(x):
ray.data.range(10).map(checker).show()


def test_strict_require_batch_size_for_gpu(enable_strict_mode):
    """``map_batches`` with GPUs and no ``batch_size`` raises in strict mode."""
    # Restart Ray with an explicit GPU resource for this test.
    ray.shutdown()
    ray.init(num_cpus=4, num_gpus=1)
    try:
        ds = ray.data.range(1)
        with pytest.raises(StrictModeError):
            ds.map_batches(lambda x: x, num_gpus=1)
    finally:
        # Shut down the cluster started above so it does not leak into
        # whatever runs after this test.
        ray.shutdown()


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit b14795d

Please sign in to comment.