[RLlib] Learner group checkpointing (ray-project#34379)
Implement multinode learner group checkpointing and tests.

---------

Signed-off-by: Avnish <[email protected]>
Signed-off-by: avnishn <[email protected]>
avnishn committed Apr 18, 2023
1 parent 6843408 commit 4995e14
Showing 12 changed files with 462 additions and 30 deletions.
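The core addition is a save_state() / load_state() pair on LearnerGroup that works whether or not the learner workers share a node with the driver. Below is a minimal round-trip sketch distilled from the test added in this commit (get_learner_group and LearnerGroupScalingConfig come from the diff below; the "torch" framework choice and the checkpoint path are illustrative):

import tempfile

import gymnasium as gym
import ray

from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
from ray.rllib.core.testing.utils import get_learner_group

ray.init()
env = gym.make("CartPole-v1")

# Two remote learner workers; on a multinode cluster these can land on
# different nodes than the driver.
scaling_config = LearnerGroupScalingConfig(num_workers=2)
learner_group = get_learner_group("torch", env, scaling_config, eager_tracing=True)

# If the learner worker doing the saving lives on another node, its
# checkpoint is written to a temp dir there and synced back to this path.
ckpt_dir = tempfile.mkdtemp()
learner_group.save_state(ckpt_dir)
learner_group.shutdown()

# A fresh group restores the state; workers on other nodes pull the
# checkpoint from the driver's node before loading it.
restored = get_learner_group("torch", env, scaling_config, eager_tracing=True)
restored.load_state(ckpt_dir)
restored.shutdown()
ray.shutdown()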
21 changes: 21 additions & 0 deletions release/release_tests.yaml
@@ -2995,6 +2995,27 @@
# RLlib tests
########################

- name: rllib_learner_group_checkpointing_multinode
  group: RLlib tests
  working_dir: rllib_tests

  frequency: nightly
  team: rllib

  cluster:
    cluster_env: app_config.yaml
    cluster_compute: multi_node_checkpointing_compute_config.yaml

  run:
    timeout: 3600
    script: pytest checkpointing_tests/test_learner_group_checkpointing.py

    wait_for_nodes:
      num_nodes: 3

  alert: default


- name: rllib_learning_tests_a2c_tf
  group: RLlib tests
  working_dir: rllib_tests
4 changes: 4 additions & 0 deletions release/rllib_tests/app_config.yaml
@@ -23,6 +23,7 @@ python:
    # so we built it for py3 and use that instead. This wheel was tested for python 3.7, 3.8,
    # and 3.9.
    - https://ray-ci-deps-wheels.s3.us-west-2.amazonaws.com/AutoROM.accept_rom_license-0.5.4-py3-none-any.whl
    - pytest
  conda_packages: []

post_build_cmds:
@@ -41,3 +42,6 @@ post_build_cmds:
  - mv mujoco210-linux-x86_64.tar.gz ~/.mujoco/.
  - cd ~/.mujoco
  - tar -xf ~/.mujoco/mujoco210-linux-x86_64.tar.gz

  # not strictly necessary, but makes debugging easier
  - git clone https://github.com/ray-project/ray.git
126 changes: 126 additions & 0 deletions release/rllib_tests/checkpointing_tests/test_learner_group_checkpointing.py
@@ -0,0 +1,126 @@
import gymnasium as gym
import itertools
import numpy as np
import tempfile
import unittest

import ray
from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
from ray.rllib.core.testing.utils import get_learner_group
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import check


FAKE_BATCH = {
    SampleBatch.OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.NEXT_OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.ACTIONS: np.array([0, 1, 1]),
    SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
    SampleBatch.REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.TERMINATEDS: np.array([False, False, True]),
    SampleBatch.TRUNCATEDS: np.array([False, False, False]),
    SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
    SampleBatch.ACTION_DIST_INPUTS: np.array(
        [[-2.0, 0.5], [-3.0, -0.3], [-0.1, 2.5]], dtype=np.float32
    ),
    SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
    SampleBatch.EPS_ID: np.array([0, 0, 0]),
    SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
}


REMOTE_SCALING_CONFIGS = {
    "remote-cpu": LearnerGroupScalingConfig(num_workers=1),
    "remote-gpu": LearnerGroupScalingConfig(num_workers=1, num_gpus_per_worker=1),
    "multi-gpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_gpus_per_worker=1),
    "multi-cpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_cpus_per_worker=2),
    # "multi-gpu-ddp-pipeline": LearnerGroupScalingConfig(
    #     num_workers=2, num_gpus_per_worker=2
    # ),
}


class TestLearnerGroupCheckpointing(unittest.TestCase):
    def setUp(self) -> None:
        ray.init()

    def tearDown(self) -> None:
        ray.shutdown()

    def test_save_load_state(self):
        fws = ["tf", "torch"]
        scaling_modes = REMOTE_SCALING_CONFIGS.keys()
        test_iterator = itertools.product(fws, scaling_modes)

        batch = SampleBatch(FAKE_BATCH)
        for fw, scaling_mode in test_iterator:
            print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
            env = gym.make("CartPole-v1")

            scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
            initial_learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )

            # checkpoint the initial learner state for later comparison
            initial_learner_checkpoint_dir = tempfile.TemporaryDirectory().name
            initial_learner_group.save_state(initial_learner_checkpoint_dir)
            initial_learner_group_weights = initial_learner_group.get_weights()

            # do a single update
            initial_learner_group.update(batch.as_multi_agent(), reduce_fn=None)

            # checkpoint the learner state after 1 update for later comparison
            learner_after_1_update_checkpoint_dir = tempfile.TemporaryDirectory().name
            initial_learner_group.save_state(learner_after_1_update_checkpoint_dir)

            # remove that learner, construct a new one, and load the state of the
            # old learner into the new one
            initial_learner_group.shutdown()
            del initial_learner_group
            new_learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )
            new_learner_group.load_state(learner_after_1_update_checkpoint_dir)

            # do another update
            results_with_break = new_learner_group.update(
                batch.as_multi_agent(), reduce_fn=None
            )
            weights_after_1_update_with_break = new_learner_group.get_weights()
            new_learner_group.shutdown()
            del new_learner_group

            # construct a new learner group and load the initial state of the learner
            learner_group = get_learner_group(
                fw, env, scaling_config, eager_tracing=True
            )
            learner_group.load_state(initial_learner_checkpoint_dir)
            check(learner_group.get_weights(), initial_learner_group_weights)
            learner_group.update(batch.as_multi_agent(), reduce_fn=None)
            results_without_break = learner_group.update(
                batch.as_multi_agent(), reduce_fn=None
            )
            weights_after_1_update_without_break = learner_group.get_weights()
            learner_group.shutdown()
            del learner_group

            # compare the results of the two updates
            check(results_with_break, results_without_break)
            check(
                weights_after_1_update_with_break, weights_after_1_update_without_break
            )


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
22 changes: 22 additions & 0 deletions release/rllib_tests/multi_node_checkpointing_compute_config.yaml
@@ -0,0 +1,22 @@
cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
region: us-west-2

max_workers: 3

head_node_type:
  name: head_node
  instance_type: m5.2xlarge

worker_node_types:
  - name: worker_node
    instance_type: g4dn.xlarge
    min_workers: 2
    max_workers: 2
    use_spot: false

aws:
  BlockDeviceMappings:
    - DeviceName: /dev/sda1
      Ebs:
        DeleteOnTermination: true
        VolumeSize: 150
132 changes: 132 additions & 0 deletions rllib/core/learner/learner_group.py
@@ -1,4 +1,6 @@
from collections import deque
import pathlib
import socket
from typing import Any, List, Mapping, Type, Optional, Callable, Set, TYPE_CHECKING

import ray
@@ -17,6 +19,8 @@
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.numpy import convert_to_numpy
from ray.train._internal.backend_executor import BackendExecutor
from ray.tune.utils.file_transfer import sync_dir_between_nodes


if TYPE_CHECKING:
from ray.rllib.core.learner.learner import Learner
@@ -404,6 +408,134 @@ def set_is_module_trainable(
        if is_module_trainable is not None:
            self._is_module_trainable = is_module_trainable

    def save_state(self, path: str) -> None:
        """Saves the state of the LearnerGroup.

        Args:
            path: The path to save the state to.
        """
        if self.is_local:
            self._learner.save_state(path)
        else:
            worker = self._worker_manager.healthy_actor_ids()[0]
            worker_ip_addr = self._worker_manager.foreach_actor(
                self._get_ip_address, remote_actor_ids=[worker]
            )
            worker_ip_addr = self._get_results(worker_ip_addr)[0]
            self_ip_addr = self._get_ip_address()

            if worker_ip_addr == self_ip_addr:
                self._worker_manager.foreach_actor(
                    lambda w: w.save_state(path), remote_actor_ids=[worker]
                )
            else:
                # save the checkpoint to a temporary location on the worker

                # create a temporary directory on the worker
                worker_temp_dir = self._worker_manager.foreach_actor(
                    self._create_temporary_dir, remote_actor_ids=[worker]
                )
                worker_temp_dir = self._get_results(worker_temp_dir)[0]

                # save the checkpoint to the temporary directory on the worker
                self._worker_manager.foreach_actor(
                    lambda w: w.save_state(worker_temp_dir), remote_actor_ids=[worker]
                )

                # sync the temporary directory on the worker to the local directory
                sync_dir_between_nodes(
                    worker_ip_addr, worker_temp_dir, self_ip_addr, path
                )

                # creating this function here instead of making it a member function
                # because it uses the worker_temp_dir variable, and this can't
                # be passed in as an argument to foreach_actor
                def remove_dir(w):
                    import shutil

                    shutil.rmtree(worker_temp_dir)

                # remove the temporary directory on the worker
                self._worker_manager.foreach_actor(
                    remove_dir, remote_actor_ids=[worker]
                )

    def load_state(self, path: str) -> None:
        """Loads the state of the LearnerGroup.

        Args:
            path: The path to load the state from.
        """
        path = pathlib.Path(path)
        if not path.exists():
            raise ValueError(f"Path {path} does not exist.")
        if not path.is_dir():
            raise ValueError(
                f"Path {path} is not a directory. "
                "Please specify a directory containing the checkpoint files."
            )
        path = str(path.absolute())
        assert len(self._workers) == self._worker_manager.num_healthy_actors()
        if self.is_local:
            self._learner.load_state(path)
        else:
            head_node_ip = socket.gethostbyname(socket.gethostname())
            workers = self._worker_manager.healthy_actor_ids()

            def _load_state(w):
                # doing imports here since they might not be imported on the worker
                import socket
                import tempfile

                hostname = socket.gethostname()
                worker_node_ip = socket.gethostbyname(hostname)
                # if the worker is on the same node as the head, load the checkpoint
                # directly from the path; otherwise sync the checkpoint from the head
                # to the worker and load it from there
                if worker_node_ip == head_node_ip:
                    w.load_state(path)
                else:
                    with tempfile.TemporaryDirectory() as temp_dir:
                        sync_dir_between_nodes(
                            head_node_ip, path, worker_node_ip, temp_dir
                        )
                        w.load_state(temp_dir)

            self._worker_manager.foreach_actor(_load_state, remote_actor_ids=workers)

    @staticmethod
    def _create_temporary_dir(_=None) -> str:
        """Creates a temporary directory.

        Args:
            _: Unused arg. Exists to make this function compatible with
                foreach_actor calls.

        Returns:
            The path to the temporary directory.
        """
        import tempfile

        return tempfile.mkdtemp()

    @staticmethod
    def _get_ip_address(_=None) -> str:
        """Returns this process's address.

        Args:
            _: Unused arg. Exists to make this function compatible with
                foreach_actor calls.

        Returns:
            The address of this process.
        """
        import socket

        hostname = socket.gethostname()

        return socket.gethostbyname(hostname)

    def shutdown(self):
        """Shuts down the LearnerGroup."""
        if not self._is_local:
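The multinode mechanics above lean on ray.tune.utils.file_transfer.sync_dir_between_nodes, which copies a directory tree between two nodes identified by IP, and it is the same call the diff makes in both save_state() and load_state(). A minimal sketch of the pattern, with an assumed worker IP and checkpoint path for illustration:

import socket
import tempfile

from ray.tune.utils.file_transfer import sync_dir_between_nodes

# The diff resolves node identity the same way: hostname -> IP.
head_ip = socket.gethostbyname(socket.gethostname())
worker_ip = "10.0.0.12"  # assumed; the diff asks the learner actor for its IP

# Pull a checkpoint that a remote learner wrote into a temp dir on its node
# back to a local directory on the head node.
local_dir = tempfile.mkdtemp()
sync_dir_between_nodes(worker_ip, "/tmp/worker_ckpt", head_ip, local_dir)

When the two IPs refer to the same node, save_state() and load_state() skip the sync and read or write the path directly, which avoids a needless copy.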
