[AIR] Discard returns of train loops in Trainers (ray-project#26448)
Discards returns of user-defined train loop functions to prevent deserialization issues with e.g. torch models. Those returns are not used anywhere in AIR, so there is no loss of functionality.
Yard1 committed Jul 12, 2022
1 parent 781c2a7 commit 8bb6742
Showing 6 changed files with 51 additions and 2 deletions.
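For context, this is the kind of failure the change guards against: a ``train_loop_per_worker`` that returns a hard-to-serialize object, such as a torch model, instead of checkpointing it. The sketch below is illustrative only (the toy model and the ``ray.air.session`` import are assumptions, not part of this diff).

import torch
from ray.air import session

def train_loop_per_worker(config):
    model = torch.nn.Linear(4, 1)
    # ... training would happen here ...
    session.report({"loss": 0.0})
    # Before this commit the return value was handed back through the
    # BackendExecutor, where objects like torch models could fail to
    # (de)serialize; after this commit the return value is simply dropped.
    return model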
20 changes: 18 additions & 2 deletions python/ray/train/_internal/utils.py
@@ -1,4 +1,5 @@
import abc
import functools
import inspect
import os
import logging
@@ -111,6 +112,7 @@ def construct_train_func(
train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
config: Optional[Dict[str, Any]],
fn_arg_name: Optional[str] = "train_func",
discard_returns: bool = False,
) -> Callable[[], T]:
"""Validates and constructs the training function to execute.
Args:
@@ -120,13 +122,27 @@ def construct_train_func(
``train_func``. If None then an empty Dict will be created.
fn_arg_name (Optional[str]): The name of training function to use for error
messages.
discard_returns: Whether to discard any returns from train_func or not.
Returns:
A valid training function.
Raises:
ValueError: if the input ``train_func`` is invalid.
"""
signature = inspect.signature(train_func)
num_params = len(signature.parameters)

if discard_returns:
# Discard any returns from the function so that
# BackendExecutor doesn't try to deserialize them.
# Those returns are inaccessible with AIR anyway.
@functools.wraps(train_func)
def discard_return_wrapper(*args, **kwargs):
train_func(*args, **kwargs)

wrapped_train_func = discard_return_wrapper
else:
wrapped_train_func = train_func

if num_params > 1:
err_msg = (
f"{fn_arg_name} should take in 0 or 1 arguments, but it accepts "
@@ -135,9 +151,9 @@ def construct_train_func(
raise ValueError(err_msg)
elif num_params == 1:
config = {} if config is None else config
return lambda: train_func(config)
return lambda: wrapped_train_func(config)
else: # num_params == 0
return train_func
return wrapped_train_func


class Singleton(abc.ABCMeta):
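The wrapping pattern added to ``construct_train_func`` above can be reproduced in isolation. A standalone sketch (the ``discard_returns`` decorator name is illustrative, not a Ray internal):

import functools

def discard_returns(fn):
    # Preserve the wrapped function's metadata via functools.wraps,
    # mirroring the diff above, but drop whatever fn returns.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        fn(*args, **kwargs)  # called only for its side effects
    return wrapper

@discard_returns
def train_loop(config):
    return {"some": "large or non-picklable object"}

assert train_loop({}) is None  # the return value has been discarded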
4 changes: 4 additions & 0 deletions python/ray/train/data_parallel_trainer.py
@@ -106,6 +106,9 @@ def train_loop_per_worker():
# Returns the rank of the worker on the current node.
session.get_local_rank()
Any returns from the ``train_loop_per_worker`` will be discarded and not
used or persisted anywhere.
**How do I use ``DataParallelTrainer`` or any of its subclasses?**
Example:
@@ -317,6 +320,7 @@ def training_loop(self) -> None:
self._train_loop_per_worker,
self._train_loop_config,
fn_arg_name="train_loop_per_worker",
discard_returns=True,
)

additional_resources_per_worker = (
3 changes: 3 additions & 0 deletions python/ray/train/horovod/horovod_trainer.py
@@ -68,6 +68,9 @@ def train_loop_per_worker():
# Returns the rank of the worker on the current node.
session.get_local_rank()
Any returns from the ``train_loop_per_worker`` will be discarded and not
used or persisted anywhere.
You could use ``TensorflowPredictor`` or ``TorchPredictor`` in conjunction with
HorovodTrainer. You must save the model under the "model" kwarg in the
``Checkpoint`` passed to ``session.report()``, so that it can be used by
3 changes: 3 additions & 0 deletions python/ray/train/tensorflow/tensorflow_trainer.py
@@ -80,6 +80,9 @@ def train_loop_per_worker():
# as the data will be already sharded.
train.tensorflow.prepare_dataset_shard(...)
Any returns from the ``train_loop_per_worker`` will be discarded and not
used or persisted anywhere.
To save a model to use for the ``TensorflowPredictor``, you must save it under the
"model" kwarg in ``Checkpoint`` passed to ``session.report()``.
20 changes: 20 additions & 0 deletions python/ray/train/tests/test_data_parallel_trainer.py
@@ -150,6 +150,26 @@ def train_loop(config, extra_arg):
DataParallelTrainer(train_loop_per_worker=train_loop)


def test_bad_return_in_train_loop(ray_start_4_cpus):
"""Test to check if returns from train loop are discarded."""

# Simulates what happens with e.g. torch models
class FailOnUnpickle:
def __reduce__(self):
raise RuntimeError("Failing")

def train_loop(config):
session.report({"loss": 1})
return FailOnUnpickle()

trainer = DataParallelTrainer(
train_loop_per_worker=train_loop, scaling_config=scale_config
)

# No exception should happen here
trainer.fit()


def test_tune(ray_start_4_cpus):
def train_func(config):
session.report({"loss": config["x"]})
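For clarity, ``FailOnUnpickle.__reduce__`` is invoked when the object is pickled, so serialization of the worker's return value would fail if it were still shipped back to the driver; discarding the return is what lets ``trainer.fit()`` succeed. A quick standalone check of that assumption (``pytest`` is assumed available, as in the surrounding suite):

import pickle
import pytest

class FailOnUnpickle:
    def __reduce__(self):
        raise RuntimeError("Failing")

# Pickling the object raises, which is exactly what the discarded-return
# behavior in the trainer now avoids triggering.
with pytest.raises(RuntimeError, match="Failing"):
    pickle.dumps(FailOnUnpickle())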
3 changes: 3 additions & 0 deletions python/ray/train/torch/torch_trainer.py
@@ -87,6 +87,9 @@ def train_loop_per_worker():
# Returns the current torch device.
train.torch.get_device()
Any returns from the ``train_loop_per_worker`` will be discarded and not
used or persisted anywhere.
To save a model to use for the ``TorchPredictor``, you must save it under the
"model" kwarg in ``Checkpoint`` passed to ``session.report()``.
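Since returns are discarded, the supported way to get a model out of the loop is the checkpoint path the docstring describes. A minimal sketch, assuming the ``ray.air`` ``Checkpoint`` and ``session`` APIs of this release (the toy model is an assumption):

import torch
from ray.air import session
from ray.air.checkpoint import Checkpoint

def train_loop_per_worker(config):
    model = torch.nn.Linear(4, 1)
    # ... training ...
    # Persist the model under the "model" key of a Checkpoint instead of
    # returning it; the function's return value is now discarded.
    session.report(
        {"loss": 0.0},
        checkpoint=Checkpoint.from_dict({"model": model}),
    )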
