forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RLlib] Learner group checkpointing (ray-project#34379)
Implement multinode learner group checkpointing and tests. --------- Signed-off-by: Avnish <[email protected]> Signed-off-by: avnishn <[email protected]>
- Loading branch information
Showing 12 changed files with 462 additions and 30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
126 changes: 126 additions & 0 deletions
126
release/rllib_tests/checkpointing_tests/test_learner_group_checkpointing.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import gymnasium as gym | ||
import itertools | ||
import numpy as np | ||
import tempfile | ||
import unittest | ||
|
||
import ray | ||
from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig | ||
from ray.rllib.core.testing.utils import get_learner_group | ||
from ray.rllib.policy.sample_batch import SampleBatch | ||
from ray.rllib.utils.test_utils import check | ||
|
||
|
||
# A tiny, hand-crafted 3-timestep sample batch (4-dim observations, discrete
# actions) used to drive deterministic learner updates in the tests below.
# The values are arbitrary but fixed so repeated updates are reproducible.
FAKE_BATCH = {
    SampleBatch.OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.NEXT_OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.ACTIONS: np.array([0, 1, 1]),
    SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
    SampleBatch.REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    # Third timestep terminates the (single) episode; nothing is truncated.
    SampleBatch.TERMINATEDS: np.array([False, False, True]),
    SampleBatch.TRUNCATEDS: np.array([False, False, False]),
    SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
    SampleBatch.ACTION_DIST_INPUTS: np.array(
        [[-2.0, 0.5], [-3.0, -0.3], [-0.1, 2.5]], dtype=np.float32
    ),
    SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
    # All three timesteps belong to episode 0 / agent 0.
    SampleBatch.EPS_ID: np.array([0, 0, 0]),
    SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
}
|
||
|
||
# Scaling configurations under which checkpointing is exercised. Keys are
# human-readable mode names iterated by the test; values configure how many
# remote learner workers run and the CPU/GPU resources each one gets.
REMOTE_SCALING_CONFIGS = {
    "remote-cpu": LearnerGroupScalingConfig(num_workers=1),
    "remote-gpu": LearnerGroupScalingConfig(num_workers=1, num_gpus_per_worker=1),
    "multi-gpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_gpus_per_worker=1),
    "multi-cpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_cpus_per_worker=2),
    # Disabled: would require 2 workers x 2 GPUs each.
    # "multi-gpu-ddp-pipeline": LearnerGroupScalingConfig(
    #     num_workers=2, num_gpus_per_worker=2
    # ),
}
|
||
|
||
class TestLearnerGroupCheckpointing(unittest.TestCase):
    """Tests that LearnerGroup state survives a save/load round trip."""

    def setUp(self) -> None:
        ray.init()

    def tearDown(self) -> None:
        ray.shutdown()

    def test_save_load_state(self):
        """Checks that checkpoint/restore does not perturb the update stream.

        For every framework x scaling-mode combination:
          1. Save the freshly initialized learner state.
          2. Run one update and save again.
          3. Restore the post-update state into a brand-new learner group and
             run a second update ("with break").
          4. Restore the initial state into another fresh group and run two
             consecutive updates ("without break").
          5. Assert both paths yield identical results and weights.
        """
        fws = ["tf", "torch"]
        scaling_modes = REMOTE_SCALING_CONFIGS.keys()
        test_iterator = itertools.product(fws, scaling_modes)

        batch = SampleBatch(FAKE_BATCH)
        for fw, scaling_mode in test_iterator:
            print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
            env = gym.make("CartPole-v1")

            scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
            initial_learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )

            # Checkpoint the initial learner state for later comparison.
            # BUGFIX: use tempfile.mkdtemp() instead of
            # tempfile.TemporaryDirectory().name — the unreferenced
            # TemporaryDirectory object is garbage-collected immediately, and
            # its finalizer deletes the directory, so the checkpoint path
            # could vanish before save_state()/load_state() ever touch it.
            initial_learner_checkpoint_dir = tempfile.mkdtemp()
            initial_learner_group.save_state(initial_learner_checkpoint_dir)
            initial_learner_group_weights = initial_learner_group.get_weights()

            # Do a single update.
            initial_learner_group.update(batch.as_multi_agent(), reduce_fn=None)

            # Checkpoint the learner state after 1 update for later comparison.
            learner_after_1_update_checkpoint_dir = tempfile.mkdtemp()
            initial_learner_group.save_state(learner_after_1_update_checkpoint_dir)

            # Remove that learner, construct a new one, and load the state of
            # the old learner into the new one.
            initial_learner_group.shutdown()
            del initial_learner_group
            new_learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )
            new_learner_group.load_state(learner_after_1_update_checkpoint_dir)

            # Do another update ("with break": a restore happened in between).
            results_with_break = new_learner_group.update(
                batch.as_multi_agent(), reduce_fn=None
            )
            weights_after_1_update_with_break = new_learner_group.get_weights()
            new_learner_group.shutdown()
            del new_learner_group

            # Construct a new learner group and load the initial state of the
            # learner, then run two uninterrupted updates ("without break").
            learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )
            learner_group.load_state(initial_learner_checkpoint_dir)
            check(learner_group.get_weights(), initial_learner_group_weights)
            learner_group.update(batch.as_multi_agent(), reduce_fn=None)
            results_without_break = learner_group.update(
                batch.as_multi_agent(), reduce_fn=None
            )
            weights_after_1_update_without_break = learner_group.get_weights()
            learner_group.shutdown()
            del learner_group

            # The two paths must agree exactly: checkpoint/restore is a no-op
            # with respect to subsequent training.
            check(results_with_break, results_without_break)
            check(
                weights_after_1_update_with_break, weights_after_1_update_without_break
            )
|
||
|
||
if __name__ == "__main__":
    # Run this file's tests under pytest and propagate its exit code.
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))
22 changes: 22 additions & 0 deletions
22
release/rllib_tests/multi_node_checkpointing_compute_config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Anyscale compute configuration for the multi-node learner-group
# checkpointing release test. Cloud ID is injected from the environment.
cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
region: us-west-2

# Cluster-wide cap on worker nodes.
max_workers: 3

# CPU-only head node.
head_node_type:
  name: head_node
  instance_type: m5.2xlarge

# Exactly two on-demand (non-spot) single-GPU workers, so the multi-node
# DDP scaling modes always have the resources they expect.
worker_node_types:
  - name: worker_node
    instance_type: g4dn.xlarge
    min_workers: 2
    max_workers: 2
    use_spot: false

# Give each node a 150 GB root EBS volume, deleted with the instance.
aws:
  BlockDeviceMappings:
    - DeviceName: /dev/sda1
      Ebs:
        DeleteOnTermination: true
        VolumeSize: 150
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.