[RLlib] Don't add a cpu to bundle for learner when using gpu (ray-project#35529)

solves ray-project#35409

Prevent resource fragmentation by not placing GPUs together
with CPUs in the learner workers' placement group bundles.
This way, an actor that only needs a CPU can no longer claim
a bundle that reserves both a CPU and a GPU.

The long-term fix is to allow specifying the placement group
bundle index via Tune and Ray Train.

Signed-off-by: avnishn <[email protected]>
avnishn authored and scv119 committed Jun 11, 2023
1 parent 8d25749 commit 02db24a
Showing 8 changed files with 188 additions and 16 deletions.
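Not part of the commit: the sketch below illustrates the fragmentation issue described in the commit message, using Ray's placement group API. The bundle shapes and the CpuOnlyActor are made up for illustration; placement_group, PlacementGroupSchedulingStrategy, and its placement_group_bundle_index argument are the real Ray APIs involved, and the logical num_gpus=1 passed to ray.init() only fakes a GPU so the sketch runs on any machine.

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

# Advertise 1 logical GPU so this sketch runs even on a CPU-only machine.
ray.init(num_cpus=4, num_gpus=1)

# Before this change, every learner bundle reserved a CPU next to its GPU:
#     [{"CPU": 1, "GPU": 1}, ...]
# A CPU-only actor scheduled into the placement group can land in that bundle
# and occupy the CPU slot the GPU learner was supposed to sit next to.
pg = placement_group([{"CPU": 1, "GPU": 1}, {"CPU": 2}], strategy="PACK")
ray.get(pg.ready())


@ray.remote(num_cpus=1)
class CpuOnlyActor:
    def ping(self):
        return "ok"


# Without a bundle index, Ray may pick either bundle for this CPU-only actor.
unpinned = CpuOnlyActor.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg)
).remote()

# The long-term fix mentioned above: pin CPU-only work to the CPU-only bundle.
pinned = CpuOnlyActor.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg, placement_group_bundle_index=1
    )
).remote()

print(ray.get([unpinned.ping.remote(), pinned.ping.remote()]))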
31 changes: 31 additions & 0 deletions release/release_tests.yaml
@@ -3465,6 +3465,37 @@
        cluster_compute: multi_node_checkpointing_compute_config_gce.yaml


- name: rllib_multi_node_e2e_training_smoke_test
  group: RLlib tests
  working_dir: rllib_tests

  frequency: nightly
  team: rllib

  cluster:
    cluster_env: app_config.yaml
    cluster_compute: multi_node_checkpointing_compute_config.yaml

  run:
    timeout: 3600
    script: pytest smoke_tests/smoke_test_basic_multi_node_training_learner.py

    wait_for_nodes:
      num_nodes: 3

  alert: default

  variations:
    - __suffix__: aws
    - __suffix__: gce
      env: gce
      frequency: manual
      cluster:
        cluster_env: app_config.yaml
        cluster_compute: multi_node_checkpointing_compute_config_gce.yaml


- name: rllib_learning_tests_a2c_tf
  group: RLlib tests
  working_dir: rllib_tests

2 changes: 1 addition & 1 deletion release/rllib_tests/multi_node_checkpointing_compute_config.yaml
@@ -5,7 +5,7 @@ max_workers: 3

head_node_type:
  name: head_node
  instance_type: m5.2xlarge
  instance_type: m5.xlarge

worker_node_types:
  - name: worker_node

2 changes: 1 addition & 1 deletion release/rllib_tests/multi_node_checkpointing_compute_config_gce.yaml
@@ -7,7 +7,7 @@ max_workers: 3

head_node_type:
  name: head_node
  instance_type: n2-standard-8 # m5.2xlarge
  instance_type: n2-standard-4 # m5.xlarge

worker_node_types:
  - name: worker_node

115 changes: 115 additions & 0 deletions release/rllib_tests/smoke_tests/smoke_test_basic_multi_node_training_learner.py
@@ -0,0 +1,115 @@
import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig


def run_with_tuner_n_rollout_worker_2_gpu(config):
    """Run training with n rollout workers and 2 learner workers with gpu."""
    config = config.rollouts(num_rollout_workers=5)
    tuner = tune.Tuner(
        "PPO",
        param_space=config,
        run_config=air.RunConfig(
            stop={"timesteps_total": 128},
            failure_config=air.FailureConfig(fail_fast=True),
        ),
    )
    tuner.fit()


def run_with_tuner_0_rollout_worker_2_gpu(config):
    """Run training with 0 rollout workers with 2 learner workers with gpu."""
    config = config.rollouts(num_rollout_workers=0)
    tuner = tune.Tuner(
        "PPO",
        param_space=config,
        run_config=air.RunConfig(
            stop={"timesteps_total": 128},
            failure_config=air.FailureConfig(fail_fast=True),
        ),
    )
    tuner.fit()


def run_tuner_n_rollout_workers_0_gpu(config):
    """Run training with n rollout workers, multiple learner workers, and no gpu."""
    config = config.rollouts(num_rollout_workers=5)
    config = config.resources(
        num_cpus_per_learner_worker=1,
        num_learner_workers=2,
    )

    tuner = tune.Tuner(
        "PPO",
        param_space=config,
        run_config=air.RunConfig(
            stop={"timesteps_total": 128},
            failure_config=air.FailureConfig(fail_fast=True),
        ),
    )
    tuner.fit()


def run_tuner_n_rollout_workers_1_gpu_local(config):
    """Run training with n rollout workers, local learner, and 1 gpu."""
    config = config.rollouts(num_rollout_workers=5)
    config = config.resources(
        num_gpus_per_learner_worker=1,
        num_learner_workers=0,
    )

    tuner = tune.Tuner(
        "PPO",
        param_space=config,
        run_config=air.RunConfig(
            stop={"timesteps_total": 128},
            failure_config=air.FailureConfig(fail_fast=True),
        ),
    )
    tuner.fit()


def test_multi_node_training_smoke():
    """A smoke test to see if we can run multi node training without pg problems.
    This test is run on a 3 node cluster. The head node is a m5.xlarge (4 cpu),
    the worker nodes are 2 g4dn.xlarge (1 gpu, 4 cpu) machines.
    """

    ray.init()

    config = (
        PPOConfig()
        .training(
            _enable_learner_api=True,
            model={
                "fcnet_hiddens": [256, 256, 256],
                "fcnet_activation": "relu",
                "vf_share_layers": True,
            },
            train_batch_size=128,
        )
        .rl_module(_enable_rl_module_api=True)
        .environment("CartPole-v1")
        .resources(
            num_gpus_per_learner_worker=1,
            num_learner_workers=2,
        )
        .rollouts(num_rollout_workers=2)
        .reporting(min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=10)
    )
    for fw in ["tf2", "torch"]:
        config = config.framework(fw, eager_tracing=True)

        run_with_tuner_0_rollout_worker_2_gpu(config)
        run_with_tuner_n_rollout_worker_2_gpu(config)
        run_tuner_n_rollout_workers_0_gpu(config)
        run_tuner_n_rollout_workers_1_gpu_local(config)


if __name__ == "__main__":
    import sys
    import pytest

    sys.exit(pytest.main(["-v", __file__]))
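Not part of the new test file: assuming the cluster shape described in the docstring above (one m5.xlarge head node plus two g4dn.xlarge workers), a rough pre-flight check before running the smoke test could look like the sketch below; the CPU/GPU totals are inferred from those instance types, not taken from the commit.

import ray

# Attach to the already running release-test cluster.
ray.init(address="auto")

resources = ray.cluster_resources()
# Two learner workers x 1 GPU each -> at least 2 GPUs overall.
assert resources.get("GPU", 0) >= 2, resources
# Head node (4 CPUs) + two workers (4 CPUs each) -> roughly 12 CPUs overall.
assert resources.get("CPU", 0) >= 12, resources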
23 changes: 16 additions & 7 deletions rllib/algorithms/algorithm.py
@@ -2272,13 +2272,22 @@ def default_resource_request(
        # resources for remote learner workers
        learner_bundles = []
        if cf._enable_learner_api and cf.num_learner_workers > 0:
            learner_bundles = [
                {
                    "CPU": cf.num_cpus_per_learner_worker,
                    "GPU": cf.num_gpus_per_learner_worker,
                }
                for _ in range(cf.num_learner_workers)
            ]
            # can't specify cpus for learner workers at the same
            # time as gpus
            if cf.num_gpus_per_learner_worker:
                learner_bundles = [
                    {
                        "GPU": cf.num_gpus_per_learner_worker,
                    }
                    for _ in range(cf.num_learner_workers)
                ]
            elif cf.num_cpus_per_learner_worker:
                learner_bundles = [
                    {
                        "CPU": cf.num_cpus_per_learner_worker,
                    }
                    for _ in range(cf.num_learner_workers)
                ]

        bundles = [driver] + rollout_bundles + evaluation_bundles + learner_bundles
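Not part of the diff: one way to see the effect of this change is to inspect the bundles that default_resource_request produces. This is a sketch; it assumes the method returns a Tune PlacementGroupFactory whose bundles attribute lists the requested bundles, and the config values are illustrative.

from ray.rllib.algorithms.ppo import PPO, PPOConfig

config = (
    PPOConfig()
    .framework("torch")
    .training(_enable_learner_api=True)
    .rl_module(_enable_rl_module_api=True)
    .rollouts(num_rollout_workers=2)
    .resources(num_learner_workers=2, num_gpus_per_learner_worker=1)
)

pg_factory = PPO.default_resource_request(config.to_dict())
# With this change, the two learner bundles should be GPU-only, e.g. end in
# {"GPU": 1}, {"GPU": 1} rather than containing {"CPU": 1, "GPU": 1} entries.
print(pg_factory.bundles)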

23 changes: 21 additions & 2 deletions rllib/algorithms/algorithm_config.py
@@ -887,6 +887,20 @@ def validate(self) -> None:
"via `config.training(_enable_learner_api=True)` (or set both to "
"False)."
)
# TODO @Avnishn: This is a short-term work around due to
# https://github.com/ray-project/ray/issues/35409
# Remove this once we are able to specify placement group bundle index in RLlib
if (
self.num_cpus_per_learner_worker > 1
and self.num_gpus_per_learner_worker > 0
):
raise ValueError(
"Cannot set both `num_cpus_per_learner_worker` and "
" `num_gpus_per_learner_worker` > 0! Users must set one"
" or the other due to issues with placement group"
" fragmentation. See "
"https://github.com/ray-project/ray/issues/35409 for more details."
)

        if bool(os.environ.get("RLLIB_ENABLE_RL_MODULE", False)):
            # Enable RLModule API and connectors if env variable is set
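Not part of the diff: a minimal sketch of a configuration that the new check added above rejects (the exact numbers are illustrative):

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().resources(
    num_learner_workers=2,
    num_cpus_per_learner_worker=2,  # more than 1 CPU per learner worker ...
    num_gpus_per_learner_worker=1,  # ... combined with a GPU request
)

# Should raise the ValueError above, pointing at ray-project#35409.
config.validate()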
@@ -1149,7 +1163,8 @@ def resources(
            num_gpus_per_learner_worker: Number of GPUs allocated per worker. If
                `num_learner_workers = 0`, any value greater than 0 will run the
                training on a single GPU on the head node, while a value of 0 will run
                the training on head node CPU cores. If num_gpus_per_learner_worker is
                set, then num_cpus_per_learner_worker cannot be set.
            local_gpu_idx: if num_gpus_per_worker > 0, and num_workers<2, then this gpu
                index will be used for training. This is an index into the available
                cuda devices. For example if os.environ["CUDA_VISIBLE_DEVICES"] = "1"
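Not part of the diff: given the updated docstring, the mutually exclusive ways to configure learner resources would look roughly as follows (PPOConfig is used only as an example):

from ray.rllib.algorithms.ppo import PPOConfig

# GPU learners: request GPUs only; the learner bundles get no explicit CPUs.
gpu_learners = PPOConfig().resources(
    num_learner_workers=2,
    num_gpus_per_learner_worker=1,
)

# CPU learners: request CPUs only.
cpu_learners = PPOConfig().resources(
    num_learner_workers=2,
    num_cpus_per_learner_worker=1,
)

# Local learner on a single GPU on the head node.
local_gpu_learner = PPOConfig().resources(
    num_learner_workers=0,
    num_gpus_per_learner_worker=1,
)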
@@ -3227,7 +3242,11 @@ def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfig
            )
            .resources(
                num_learner_workers=self.num_learner_workers,
                num_cpus_per_learner_worker=self.num_cpus_per_learner_worker,
                num_cpus_per_learner_worker=(
                    self.num_cpus_per_learner_worker
                    if not self.num_gpus_per_learner_worker
                    else 0
                ),
                num_gpus_per_learner_worker=self.num_gpus_per_learner_worker,
                local_gpu_idx=self.local_gpu_idx,
            )
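Not part of the diff: the conditional expression above simply drops the CPU request whenever GPUs are requested for the learner workers. Sketched in isolation (hypothetical helper, not an RLlib API):

def learner_cpus(num_cpus_per_learner_worker: int, num_gpus_per_learner_worker: int) -> int:
    # Mirrors the expression passed to .resources() above: any GPU request
    # zeroes out the CPU request for the learner workers.
    return num_cpus_per_learner_worker if not num_gpus_per_learner_worker else 0


assert learner_cpus(1, 1) == 0  # GPU learners end up with GPU-only bundles
assert learner_cpus(1, 0) == 1  # CPU-only learners keep their CPU request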
2 changes: 0 additions & 2 deletions rllib/core/learner/learner.py
@@ -588,8 +588,6 @@ def compile_results(
            loss_per_module: A dict mapping module IDs (including ALL_MODULES) to the
                individual loss tensors as returned by calls to
                `compute_loss_for_module(module_id=...)`.
            postprocessed_gradients: The postprocessed gradients dict, (flat) mapping
                gradient tensor refs to the already postprocessed gradient tensors.
            metrics_per_module: The collected metrics defaultdict mapping ModuleIDs to
                metrics dicts. These metrics are collected during loss- and
                gradient computation, gradient postprocessing, and gradient application.
6 changes: 3 additions & 3 deletions rllib/examples/learner/multi_agent_cartpole_ppo.py
@@ -95,10 +95,10 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs):

    config = (
        PPOConfig()
        .rollouts(rollout_fragment_length=500)
        .rollouts(rollout_fragment_length="auto", num_rollout_workers=3)
        .environment(MultiAgentCartPole, env_config={"num_agents": args.num_agents})
        .framework(args.framework)
        .training(num_sgd_iter=10)
        .training(num_sgd_iter=10, sgd_minibatch_size=2**9, train_batch_size=2**12)
        .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
        .rl_module(_enable_rl_module_api=True)
        .training(_enable_learner_api=True)
@@ -115,7 +115,7 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    results = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(stop=stop, verbose=1),
        run_config=air.RunConfig(stop=stop, verbose=3),
    ).fit()

    if args.as_test:
