
Commit

[AIR] Fix ResourceChangingScheduler not working with AIR (ray-project#26307)

This PR ensures that the new trial resources set by `ResourceChangingScheduler` are respected by the train loop logic, by modifying the scaling config to match them. Previously, even though trials had their resources updated, the scaling config was not modified, which led to, e.g., new workers not being spawned in the `DataParallelTrainer` even though resources were available.
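
For context, a minimal sketch (not part of this PR) of how such a scheduler is typically constructed for a Tune run; the allocation function and its arguments are illustrative:

```python
# Illustrative only: the scheduler whose mid-run resource updates this PR makes
# the train loop respect. Assumes Ray Tune is installed.
from ray.tune.schedulers.resource_changing_scheduler import (
    DistributeResources,
    ResourceChangingScheduler,
)

scheduler = ResourceChangingScheduler(
    resources_allocation_function=DistributeResources(add_bundles=True)
)
# When the scheduler reassigns a trial's PlacementGroupFactory, the trainer's
# scaling config is now updated to match, so e.g. DataParallelTrainer can
# actually start the additional workers.
```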

To accomplish this, `ScalingConfigDataClass` is updated to support equality comparisons with other `ScalingConfigDataClass`es (by comparing the underlying `PlacementGroupFactory`, or PGF) and to allow creating a `ScalingConfigDataClass` from a PGF.
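
As a rough illustration of the resulting round trip (mirroring the test added in `test_api.py` below; the `ray.air.config` import path is assumed):

```python
# Sketch of the PGF <-> ScalingConfigDataClass round trip added in this PR.
from ray.air.config import ScalingConfigDataClass

scaling_config = ScalingConfigDataClass(num_workers=2, use_gpu=True)
pgf = scaling_config.as_placement_group_factory()  # Tune-side resource request
scaling_config_from_pgf = ScalingConfigDataClass.from_placement_group_factory(pgf)

assert scaling_config == scaling_config_from_pgf  # __eq__ compares the underlying PGFs
assert scaling_config_from_pgf.as_placement_group_factory() == pgf
```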

Please note that this is an internal-only change intended to make `ResourceChangingScheduler` actually work. In the future, `ResourceChangingScheduler` should be updated to operate on `ScalingConfigDataClass`es instead of PGFs, as it does now; that will require a deprecation cycle.
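
For completeness, a hedged sketch of how train-loop code could observe a rescaled allocation via the session API touched below; the training function itself is a placeholder:

```python
# Placeholder training function; only the session call reflects the API change in
# this PR (ray.air.session.get_trial_resources now returns a PlacementGroupFactory).
from ray.air import session
from ray.air.config import ScalingConfigDataClass


def train_loop_per_worker(config):
    pgf = session.get_trial_resources()  # resources currently assigned to this trial
    scaling = ScalingConfigDataClass.from_placement_group_factory(pgf)
    # After a resource change these reflect the new allocation rather than the
    # scaling config the trainer was originally constructed with.
    print(scaling.num_workers, scaling.num_gpus_per_worker)
```
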
Yard1 committed Jul 12, 2022
1 parent f5c5215 commit b3878e2
Showing 13 changed files with 347 additions and 20 deletions.
8 changes: 8 additions & 0 deletions python/ray/air/BUILD
@@ -219,6 +219,14 @@ py_test(
deps = [":ml_lib"]
)

py_test(
name = "test_resource_changing",
size = "medium",
srcs = ["tests/test_resource_changing.py"],
tags = ["team:ml", "exclusive"],
deps = [":ml_lib"]
)

py_test(
name = "test_tensor_extension",
size = "small",
7 changes: 5 additions & 2 deletions python/ray/air/_internal/session.py
@@ -1,9 +1,12 @@
import abc
import logging
from typing import Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional

from ray.air.checkpoint import Checkpoint

if TYPE_CHECKING:
from ray.tune.execution.placement_groups import PlacementGroupFactory

logger = logging.getLogger(__name__)


@@ -67,7 +70,7 @@ def trial_id(self) -> str:
raise NotImplementedError

@property
def trial_resources(self) -> Dict[str, float]:
def trial_resources(self) -> "PlacementGroupFactory":
"""Trial resources for the corresponding trial."""
raise NotImplementedError

89 changes: 79 additions & 10 deletions python/ray/air/config.py
@@ -26,7 +26,7 @@ class ScalingConfigDataClass:
This is the schema for the scaling_config dict, and after beta, this will be the
actual representation for Scaling config objects.
trainer_resources: Resources to allocate for the trainer. If none is provided,
trainer_resources: Resources to allocate for the trainer. If None is provided,
will default to 1 CPU.
num_workers: The number of workers (Ray actors) to launch.
Each worker will reserve 1 CPU by default. The number of CPUs
@@ -52,9 +52,6 @@
placement_strategy: str = "PACK"

def __post_init__(self):
self.resources_per_worker = (
self.resources_per_worker if self.resources_per_worker else {}
)
if self.resources_per_worker:
if not self.use_gpu and self.num_gpus_per_worker > 0:
raise ValueError(
@@ -71,32 +68,51 @@ def __post_init__(self):
"`resources_per_worker."
)

def __eq__(self, o: "ScalingConfigDataClass") -> bool:
if not isinstance(o, type(self)):
return False
return self.as_placement_group_factory() == o.as_placement_group_factory()

@property
def _resources_per_worker_not_none(self):
if self.resources_per_worker is None:
return {"CPU": 1, "GPU": int(self.use_gpu)}
resources_per_worker = {
k: v for k, v in self.resources_per_worker.items() if v != 0
}
resources_per_worker.setdefault("GPU", int(self.use_gpu))
return resources_per_worker

@property
def _trainer_resources_not_none(self):
if self.trainer_resources is None:
return {"CPU": 1}
return {k: v for k, v in self.trainer_resources.items() if v != 0}

@property
def num_cpus_per_worker(self):
"""The number of CPUs to set per worker."""
return self.resources_per_worker.get("CPU", 1)
return self._resources_per_worker_not_none.get("CPU", 0)

@property
def num_gpus_per_worker(self):
"""The number of GPUs to set per worker."""
return self.resources_per_worker.get("GPU", int(self.use_gpu))
return self._resources_per_worker_not_none.get("GPU", 0)

@property
def additional_resources_per_worker(self):
"""Resources per worker, not including CPU or GPU resources."""
return {
k: v
for k, v in self.resources_per_worker.items()
for k, v in self._resources_per_worker_not_none.items()
if k not in ["CPU", "GPU"]
}

def as_placement_group_factory(self) -> "PlacementGroupFactory":
"""Returns a PlacementGroupFactory to specify resources for Tune."""
from ray.tune.execution.placement_groups import PlacementGroupFactory

trainer_resources = (
self.trainer_resources if self.trainer_resources else {"CPU": 1}
)
trainer_resources = self._trainer_resources_not_none
trainer_bundle = [trainer_resources]
worker_resources = {
"CPU": self.num_cpus_per_worker,
@@ -112,6 +128,42 @@ def as_placement_group_factory(self) -> "PlacementGroupFactory":
bundles = trainer_bundle + worker_bundles
return PlacementGroupFactory(bundles, strategy=self.placement_strategy)

@classmethod
def from_placement_group_factory(
cls, pgf: "PlacementGroupFactory"
) -> "ScalingConfigDataClass":
"""Create a ScalingConfig from a Tune's PlacementGroupFactory"""
if pgf.head_bundle_is_empty:
trainer_resources = {}
worker_bundles = pgf.bundles
else:
trainer_resources = pgf.bundles[0]
worker_bundles = pgf.bundles[1:]

use_gpu = False
placement_strategy = pgf.strategy
resources_per_worker = None
num_workers = None

if worker_bundles:
first_bundle = worker_bundles[0]
if not all(bundle == first_bundle for bundle in worker_bundles[1:]):
raise ValueError(
"All worker bundles (any other than the first one) "
"must be equal to each other."
)
use_gpu = bool(first_bundle.get("GPU"))
num_workers = len(worker_bundles)
resources_per_worker = first_bundle

return ScalingConfigDataClass(
trainer_resources=trainer_resources,
num_workers=num_workers,
use_gpu=use_gpu,
resources_per_worker=resources_per_worker,
placement_strategy=placement_strategy,
)


@dataclass
@PublicAPI(stability="alpha")
@@ -283,9 +335,26 @@ class FailureConfig:
Will recover from the latest checkpoint if present.
Setting to -1 will lead to infinite recovery retries.
Setting to 0 will disable retries. Defaults to 0.
fail_fast: Whether to fail upon the first error.
If fail_fast='raise' provided, Tune will automatically
raise the exception received by the Trainable. fail_fast='raise'
can easily leak resources and should be used with caution (it
is best used with `ray.init(local_mode=True)`).
"""

max_failures: int = 0
fail_fast: Union[bool, str] = False

def __post_init__(self):
# Same check as in tune.run
if self.fail_fast and self.max_failures != 0:
raise ValueError("max_failures must be 0 if fail_fast=True.")

# Same check as in TrialRunner
if not (isinstance(self.fail_fast, bool) or self.fail_fast.upper() == "RAISE"):
raise ValueError(
"fail_fast must be one of {bool, 'raise'}. " f"Got {self.fail_fast}."
)


@dataclass
3 changes: 2 additions & 1 deletion python/ray/air/session.py
Expand Up @@ -6,6 +6,7 @@

if TYPE_CHECKING:
from ray.data import Dataset, DatasetPipeline
from ray.tune.execution.placement_groups import PlacementGroupFactory


def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
@@ -118,7 +119,7 @@ def get_trial_id() -> str:
return _get_session().trial_id


def get_trial_resources() -> Dict[str, float]:
def get_trial_resources() -> "PlacementGroupFactory":
"""Trial resources for the corresponding trial."""
return _get_session().trial_resources

33 changes: 33 additions & 0 deletions python/ray/air/tests/test_api.py
Expand Up @@ -110,6 +110,39 @@ def test_scaling_config_validate_config_bad_allowed_keys():
assert "are not present in" in str(exc_info.value)


@pytest.mark.parametrize(
"trainer_resources", [None, {}, {"CPU": 1}, {"CPU": 2, "GPU": 1}, {"CPU": 0}]
)
@pytest.mark.parametrize("num_workers", [None, 1, 2])
@pytest.mark.parametrize(
"resources_per_worker_and_use_gpu",
[
(None, False),
(None, True),
({}, False),
({"CPU": 1}, False),
({"CPU": 2, "GPU": 1}, True),
({"CPU": 0}, False),
],
)
@pytest.mark.parametrize("placement_strategy", ["PACK", "SPREAD"])
def test_scaling_config_pgf_equivalance(
trainer_resources, resources_per_worker_and_use_gpu, num_workers, placement_strategy
):
resources_per_worker, use_gpu = resources_per_worker_and_use_gpu
scaling_config = ScalingConfigDataClass(
trainer_resources=trainer_resources,
num_workers=num_workers,
resources_per_worker=resources_per_worker,
use_gpu=use_gpu,
placement_strategy=placement_strategy,
)
pgf = scaling_config.as_placement_group_factory()
scaling_config_from_pgf = ScalingConfigDataClass.from_placement_group_factory(pgf)
assert scaling_config == scaling_config_from_pgf
assert scaling_config_from_pgf.as_placement_group_factory() == pgf


def test_datasets():
with pytest.raises(ValueError):
DummyTrainer(datasets="invalid")
3 changes: 3 additions & 0 deletions python/ray/air/tests/test_dataset_config.py
Expand Up @@ -44,6 +44,7 @@ def train_loop_per_worker():
else:
assert shard.count() == v, shard

kwargs.pop("scaling_config", None)
super().__init__(
train_loop_per_worker=train_loop_per_worker,
scaling_config={"num_workers": num_workers},
@@ -204,6 +205,7 @@ def train_loop_per_worker():
results.append(epoch.take())
check_results_fn(data_shard, results)

kwargs.pop("scaling_config", None)
super().__init__(
train_loop_per_worker=train_loop_per_worker,
scaling_config={"num_workers": 1},
@@ -223,6 +225,7 @@ def train_loop_per_worker():
results = data_shard.take()
check_results_fn(data_shard, results)

kwargs.pop("scaling_config", None)
super().__init__(
train_loop_per_worker=train_loop_per_worker,
scaling_config={"num_workers": 1},
