[RLlib] Don't add a cpu to bundle for learner when using gpu #35529

Merged

Changes from 8 commits
31 changes: 31 additions & 0 deletions release/release_tests.yaml
@@ -3437,6 +3437,37 @@
cluster_compute: multi_node_checkpointing_compute_config_gce.yaml


- name: rllib_multi_node_e2e_training_smoke_test
group: RLlib tests
working_dir: rllib_tests

frequency: nightly
team: rllib

cluster:
cluster_env: app_config.yaml
cluster_compute: multi_node_checkpointing_compute_config.yaml

run:
timeout: 3600
script: pytest smoke_tests/smoke_test_basic_multi_node_training_learner.py

wait_for_nodes:
num_nodes: 3

alert: default

variations:
- __suffix__: aws
- __suffix__: gce
env: gce
frequency: manual
cluster:
cluster_env: app_config.yaml
cluster_compute: multi_node_checkpointing_compute_config_gce.yaml



- name: rllib_learning_tests_a2c_tf
group: RLlib tests
working_dir: rllib_tests
multi_node_checkpointing_compute_config.yaml
@@ -5,7 +5,7 @@ max_workers: 3

head_node_type:
name: head_node
instance_type: m5.2xlarge
instance_type: m5.xlarge

worker_node_types:
- name: worker_node
multi_node_checkpointing_compute_config_gce.yaml
@@ -7,7 +7,7 @@ max_workers: 3

head_node_type:
name: head_node
instance_type: n2-standard-8 # m5.2xlarge
instance_type: n2-standard-4 # m5.xlarge

worker_node_types:
- name: worker_node
smoke_tests/smoke_test_basic_multi_node_training_learner.py
@@ -0,0 +1,115 @@
import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig


def run_with_tuner_n_rollout_worker_2_gpu(config):
"""Run training with n rollout workers and 2 learner workers with gpu."""
config = config.rollouts(num_rollout_workers=5)
tuner = tune.Tuner(
"PPO",
param_space=config,
run_config=air.RunConfig(
stop={"timesteps_total": 128},
failure_config=air.FailureConfig(fail_fast=True),
),
)
tuner.fit()


def run_with_tuner_0_rollout_worker_2_gpu(config):
"""Run training with 0 rollout workers with 2 learner workers with gpu."""
config = config.rollouts(num_rollout_workers=0)
tuner = tune.Tuner(
"PPO",
param_space=config,
run_config=air.RunConfig(
stop={"timesteps_total": 128},
failure_config=air.FailureConfig(fail_fast=True),
),
)
tuner.fit()


def run_tuner_n_rollout_workers_0_gpu(config):
"""Run training with n rollout workers, multiple learner workers, and no gpu."""
config = config.rollouts(num_rollout_workers=5)
config = config.resources(
num_cpus_per_learner_worker=1,
num_learner_workers=2,
)

tuner = tune.Tuner(
"PPO",
param_space=config,
run_config=air.RunConfig(
stop={"timesteps_total": 128},
failure_config=air.FailureConfig(fail_fast=True),
),
)
tuner.fit()


def run_tuner_n_rollout_workers_1_gpu_local(config):
"""Run training with n rollout workers, local learner, and 1 gpu."""
config = config.rollouts(num_rollout_workers=5)
config = config.resources(
num_gpus_per_learner_worker=1,
num_learner_workers=0,
)

tuner = tune.Tuner(
"PPO",
param_space=config,
run_config=air.RunConfig(
stop={"timesteps_total": 128},
failure_config=air.FailureConfig(fail_fast=True),
),
)
tuner.fit()


def test_multi_node_training_smoke():
"""A smoke test to see if we can run multi node training without pg problems.
This test is run on a 3 node cluster. The head node is a m5.xlarge (4 cpu),
the worker nodes are 2 g4dn.xlarge (1 gpu, 4 cpu) machines.
"""

ray.init()

config = (
PPOConfig()
.training(
_enable_learner_api=True,
model={
"fcnet_hiddens": [256, 256, 256],
"fcnet_activation": "relu",
"vf_share_layers": True,
},
train_batch_size=128,
)
.rl_module(_enable_rl_module_api=True)
.environment("CartPole-v1")
.resources(
num_gpus_per_learner_worker=1,
num_learner_workers=2,
)
.rollouts(num_rollout_workers=2)
.reporting(min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=10)
)
for fw in ["tf2", "torch"]:
config = config.framework(fw, eager_tracing=True)

run_with_tuner_0_rollout_worker_2_gpu(config)
run_with_tuner_n_rollout_worker_2_gpu(config)
run_tuner_n_rollout_workers_0_gpu(config)
run_tuner_n_rollout_workers_1_gpu_local(config)


if __name__ == "__main__":
import sys
import pytest

sys.exit(pytest.main(["-v", __file__]))
23 changes: 16 additions & 7 deletions rllib/algorithms/algorithm.py
@@ -2272,13 +2272,22 @@ def default_resource_request(
# resources for remote learner workers
learner_bundles = []
if cf._enable_learner_api and cf.num_learner_workers > 0:
learner_bundles = [
{
"CPU": cf.num_cpus_per_learner_worker,
"GPU": cf.num_gpus_per_learner_worker,
}
for _ in range(cf.num_learner_workers)
]
# can't specify cpus for learner workers at the same
# time as gpus
if cf.num_gpus_per_learner_worker:
learner_bundles = [
{
"GPU": cf.num_gpus_per_learner_worker,
}
for _ in range(cf.num_learner_workers)
]
elif cf.num_cpus_per_learner_worker:
learner_bundles = [
{
"CPU": cf.num_cpus_per_learner_worker,
}
for _ in range(cf.num_learner_workers)
]

bundles = [driver] + rollout_bundles + evaluation_bundles + learner_bundles

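For illustration only (not part of the diff), a minimal sketch of the learner bundles the new branching logic produces, assuming two learner workers:

    from ray.rllib.algorithms.ppo import PPOConfig

    # GPU case: each learner bundle now requests only the GPU, so it fits onto a
    # 1-GPU worker node without also reserving one of that node's CPUs.
    gpu_config = PPOConfig().resources(
        num_learner_workers=2, num_gpus_per_learner_worker=1
    )
    # resulting learner bundles: [{"GPU": 1}, {"GPU": 1}]

    # CPU-only case: no GPUs requested, so each learner bundle requests CPUs only.
    cpu_config = PPOConfig().resources(
        num_learner_workers=2, num_cpus_per_learner_worker=1
    )
    # resulting learner bundles: [{"CPU": 1}, {"CPU": 1}]
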
23 changes: 21 additions & 2 deletions rllib/algorithms/algorithm_config.py
@@ -887,6 +887,20 @@ def validate(self) -> None:
"via `config.training(_enable_learner_api=True)` (or set both to "
"False)."
)
# TODO @Avnishn: This is a short-term work around due to
# https://github.com/ray-project/ray/issues/35409
# Remove this once we are able to specify placement group bundle index in RLlib
if (
self.num_cpus_per_learner_worker > 1
and self.num_gpus_per_learner_worker > 0
):
raise ValueError(
"Cannot set both `num_cpus_per_learner_worker` and "
" `num_gpus_per_learner_worker` > 0! Users must set one"
" or the other due to issues with placement group"
" fragmentation. See "
"https://github.com/ray-project/ray/issues/35409 for more details."
)

if bool(os.environ.get("RLLIB_ENABLE_RL_MODULE", False)):
# Enable RLModule API and connectors if env variable is set
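
As a hedged usage sketch (not part of the diff), here is a config the new check in validate() rejects, next to one it accepts:

    from ray.rllib.algorithms.ppo import PPOConfig

    # Rejected: more than one CPU *and* at least one GPU per learner worker can no
    # longer be combined (placement group fragmentation, see ray issue #35409).
    bad = PPOConfig().resources(
        num_learner_workers=2,
        num_cpus_per_learner_worker=2,
        num_gpus_per_learner_worker=1,
    )

    # Accepted: request only GPUs (or only CPUs) for the learner workers.
    ok = PPOConfig().resources(
        num_learner_workers=2,
        num_gpus_per_learner_worker=1,
    )
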
@@ -1149,7 +1163,8 @@ def resources(
num_gpus_per_learner_worker: Number of GPUs allocated per worker. If
`num_learner_workers = 0`, any value greater than 0 will run the
training on a single GPU on the head node, while a value of 0 will run
the training on head node CPU cores.
the training on head node CPU cores. If num_gpus_per_learner_worker is
set, then num_cpus_per_learner_worker cannot be set.
local_gpu_idx: if num_gpus_per_worker > 0, and num_workers<2, then this gpu
index will be used for training. This is an index into the available
cuda devices. For example if os.environ["CUDA_VISIBLE_DEVICES"] = "1"
@@ -3227,7 +3242,11 @@ def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfig
)
.resources(
num_learner_workers=self.num_learner_workers,
num_cpus_per_learner_worker=self.num_cpus_per_learner_worker,
num_cpus_per_learner_worker=(
self.num_cpus_per_learner_worker
if not self.num_gpus_per_learner_worker
else 0
),
Comment on lines +3245 to +3249

Contributor:
qq: This is not needed, right? Because self.num_cpus_per_learner_worker will be zero if self.num_gpus_per_learner_worker > 0. If that is not the case, we already raise an error.

Member Author:
Yeah, I think I added this here before I added the error. We don't technically need it; I'll remove it if I need to do anything else to get the release tests to run.

num_gpus_per_learner_worker=self.num_gpus_per_learner_worker,
local_gpu_idx=self.local_gpu_idx,
)
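A minimal sketch (not part of the diff) of the num_learner_workers=0 case described in the resources() docstring above, which the new smoke test exercises in run_tuner_n_rollout_workers_1_gpu_local:

    from ray.rllib.algorithms.ppo import PPOConfig

    # No remote learner workers: with one GPU per (local) learner, training runs
    # on a single GPU on the head node; with zero GPUs it falls back to the head
    # node's CPU cores.
    local_gpu_config = (
        PPOConfig()
        .environment("CartPole-v1")
        .resources(num_learner_workers=0, num_gpus_per_learner_worker=1)
    )
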
34 changes: 33 additions & 1 deletion rllib/algorithms/ppo/ppo_learner.py
@@ -1,10 +1,13 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import logging
import math
from typing import Any, Dict, List, Mapping, Optional, Union

from ray.rllib.core.learner.learner import LearnerHyperparameters
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.rl_module.rl_module import ModuleID
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.schedules.scheduler import Scheduler

@@ -16,6 +19,9 @@
LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY = "curr_entropy_coeff"


logger = logging.getLogger(__name__)


@dataclass
class PPOLearnerHyperparameters(LearnerHyperparameters):
"""Hyperparameters for the PPOLearner sub-classes (framework specific).
@@ -73,3 +79,29 @@ def additional_update_per_module(
results.update({LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY: new_entropy_coeff})

return results

@override(Learner)
def compile_results(
self,
batch: MultiAgentBatch,
fwd_out: Mapping[str, Any],
postprocessed_loss: Mapping[str, Any],
postprocessed_gradients: Mapping[str, Any],
) -> Mapping[str, Any]:
for module_id, module_loss_results in postprocessed_loss.items():
if module_id == self.TOTAL_LOSS_KEY:
continue
if math.isinf(module_loss_results[LEARNER_RESULTS_KL_KEY]):
logger.warning(
"KL divergence is non-finite, this will likely destabilize "
"your model and the training process. Action(s) in a "
"specific state have near-zero probability. "
"This can happen naturally in deterministic "
"environments where the optimal policy has zero mass "
"for a specific action. To fix this issue, consider "
"setting `kl_coeff` to 0.0 or increasing `entropy_coeff` in your "
"config."
)
return super().compile_results(
batch, fwd_out, postprocessed_loss, postprocessed_gradients
)
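With this override in place, the non-finite-KL warning is emitted from a single framework-agnostic spot; the framework-specific copies of the same warning are removed from the TF and Torch learners below.
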
13 changes: 0 additions & 13 deletions rllib/algorithms/ppo/tf/ppo_tf_learner.py
@@ -1,4 +1,3 @@
import logging
from typing import Any, Dict, Mapping

from ray.rllib.algorithms.ppo.ppo_learner import (
@@ -20,7 +19,6 @@


_, tf, _ = try_import_tf()
logger = logging.getLogger(__name__)


class PPOTfLearner(PPOLearner, TfLearner):
@@ -59,17 +57,6 @@ def compute_loss_per_module(
if self.hps.kl_coeff > 0.0:
action_kl = prev_action_dist.kl(curr_action_dist)
mean_kl_loss = tf.reduce_mean(action_kl)
if tf.math.is_inf(mean_kl_loss):
logger.warning(
"KL divergence is non-finite, this will likely destabilize "
"your model and the training process. Action(s) in a "
"specific state have near-zero probability. "
"This can happen naturally in deterministic "
"environments where the optimal policy has zero mass "
"for a specific action. To fix this issue, consider "
"setting `kl_coeff` to 0.0 or increasing `entropy_coeff` in your "
"config."
)
else:
mean_kl_loss = tf.constant(0.0, dtype=logp_ratio.dtype)

14 changes: 0 additions & 14 deletions rllib/algorithms/ppo/torch/ppo_torch_learner.py
@@ -1,4 +1,3 @@
import logging
from typing import Any, Dict, Mapping

from ray.rllib.algorithms.ppo.ppo_learner import (
@@ -20,8 +19,6 @@

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class PPOTorchLearner(PPOLearner, TorchLearner):
"""Implements torch-specific PPO loss logic on top of PPOLearner.
@@ -62,17 +59,6 @@ def compute_loss_per_module(
if self.hps.kl_coeff > 0.0:
action_kl = prev_action_dist.kl(curr_action_dist)
mean_kl_loss = torch.mean(action_kl)
if mean_kl_loss.isinf():
logger.warning(
"KL divergence is non-finite, this will likely destabilize "
"your model and the training process. Action(s) in a "
"specific state have near-zero probability. "
"This can happen naturally in deterministic "
"environments where the optimal policy has zero mass "
"for a specific action. To fix this issue, consider "
"setting `kl_coeff` to 0.0 or increasing `entropy_coeff` in your "
"config."
)
else:
mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device)
