[RLlib] checkpoint learner (ray-project#33598)
Signed-off-by: Avnish <[email protected]>
avnishn committed Apr 7, 2023
1 parent 4917e24 commit aa4007d
Showing 15 changed files with 654 additions and 225 deletions.
8 changes: 4 additions & 4 deletions doc/source/rllib/package_ref/rl_modules.rst
@@ -96,8 +96,8 @@ Saving and Loading

~RLModule.get_state
~RLModule.set_state
~RLModule.save_state_to_file
~RLModule.load_state_from_file
~RLModule.save_state
~RLModule.load_state
~RLModule.save_to_checkpoint
~RLModule.from_checkpoint

@@ -132,5 +132,5 @@ Saving and Loading
.. autosummary::
:toctree: doc/

~MultiAgentRLModule.load_state_from_dir
~MultiAgentRLModule.save_state_to_dir
~MultiAgentRLModule.save_state
~MultiAgentRLModule.load_state
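
The renamed RLModule checkpointing API (save_state, load_state, save_to_checkpoint, from_checkpoint) replaces the older *_to_file / *_to_dir methods above. A minimal usage sketch, assuming save_state/load_state take a directory path and from_checkpoint is a classmethod that rebuilds a module from a checkpoint directory; the exact signatures live in rl_module.py, which this page does not render:

```python
import tempfile

from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec


def roundtrip_module(module_spec: SingleAgentRLModuleSpec) -> None:
    # Build a module from its spec (spec.build() is assumed here), then
    # exercise the renamed checkpointing methods.
    module = module_spec.build()

    with tempfile.TemporaryDirectory() as tmpdir:
        # Persist and restore only the module's state (weights).
        module.save_state(tmpdir)
        module.load_state(tmpdir)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Write a self-contained checkpoint and rebuild a module from it.
        module.save_to_checkpoint(tmpdir)
        restored = type(module).from_checkpoint(tmpdir)
        assert restored is not None
```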
311 changes: 245 additions & 66 deletions rllib/core/learner/learner.py

Large diffs are not rendered by default.

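The learner.py diff is not rendered above, but the new tests in test_learner.py below exercise its checkpointing entry points. A minimal sketch of the round-trip they rely on, assuming (as in test_save_load_state) that Learner.save_state and Learner.load_state take a directory path:

```python
import tempfile


def roundtrip_learner_state(build_learner) -> None:
    # `build_learner` stands in for a helper such as get_learner() below that
    # constructs and .build()s a Learner; it is a placeholder, not RLlib API.
    learner1 = build_learner()
    learner2 = build_learner()

    with tempfile.TemporaryDirectory() as tmpdir:
        # Write module weights and optimizer state to disk ...
        learner1.save_state(tmpdir)
        # ... and restore them into a second, independently built Learner.
        learner2.load_state(tmpdir)

    # After loading, weights and optimizer configs should match
    # (see _check_learner_states in test_learner.py below).
```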
16 changes: 0 additions & 16 deletions rllib/core/learner/learner_group.py
@@ -5,14 +5,11 @@

from ray.rllib.core.learner.reduce_result_dict_fn import _reduce_mean_results
from ray.rllib.core.rl_module.rl_module import (
RLModule,
ModuleID,
SingleAgentRLModuleSpec,
)
from ray.rllib.core.learner.learner import (
LearnerSpec,
ParamOptimizerPairs,
Optimizer,
)
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.actor_manager import FaultTolerantActorManager
@@ -305,36 +302,23 @@ def add_module(
*,
module_id: ModuleID,
module_spec: SingleAgentRLModuleSpec,
set_optimizer_fn: Optional[Callable[[RLModule], ParamOptimizerPairs]] = None,
optimizer_cls: Optional[Type[Optimizer]] = None,
) -> None:
"""Add a module to the Learners maintained by this LearnerGroup.
Args:
module_id: The id of the module to add.
module_spec: #TODO (Kourosh) fill in here.
set_optimizer_fn: A function that takes in the module and returns a list of
(param, optimizer) pairs. Each element in the list describes a
parameter group that shares the same optimizer object. If None, the
default optimizer (obtained from the existing optimizer dictionary) will
be used.
optimizer_cls: The optimizer class to use. If None, set_optimizer_fn
should be provided.
"""
if self.is_local:
self._learner.add_module(
module_id=module_id,
module_spec=module_spec,
set_optimizer_fn=set_optimizer_fn,
optimizer_cls=optimizer_cls,
)
else:
results = self._worker_manager.foreach_actor(
lambda w: w.add_module(
module_id=module_id,
module_spec=module_spec,
set_optimizer_fn=set_optimizer_fn,
optimizer_cls=optimizer_cls,
)
)
return self._get_results(results)
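
With set_optimizer_fn and optimizer_cls removed, adding a module to a LearnerGroup only needs an id and a spec; the Learner now derives optimizers from its own optimizer_config. A minimal sketch of the simplified call, where `learner_group` and `module_spec` are assumed to already exist, as in the tests below:

```python
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec


def add_test_module(learner_group, module_spec: SingleAgentRLModuleSpec) -> None:
    # Only an id and a module spec are required now; the optimizer is created
    # by the Learner from its optimizer_config rather than a per-module
    # set_optimizer_fn.
    learner_group.add_module(
        module_id="test_module",
        module_spec=module_spec,
    )
```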
103 changes: 89 additions & 14 deletions rllib/core/learner/tests/test_learner.py
@@ -1,7 +1,8 @@
import gymnasium as gym
import unittest
import tensorflow as tf
import numpy as np
import tensorflow as tf
import tempfile
import unittest

import ray

@@ -15,7 +16,7 @@
from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig


def get_learner() -> Learner:
def get_learner(learning_rate=1e-3) -> Learner:
env = gym.make("CartPole-v1")

learner = BCTfLearner(
@@ -25,7 +26,9 @@ def get_learner() -> Learner:
action_space=env.action_space,
model_config_dict={"fcnet_hiddens": [32]},
),
optimizer_config={"lr": 1e-3},
# The learning rate is a configurable hparam so that tests that need to know
# its exact value don't rely on a hidden default (avoids information leakage).
optimizer_config={"lr": learning_rate},
learner_scaling_config=LearnerGroupScalingConfig(),
framework_hyperparameters=FrameworkHPs(eager_tracing=True),
)
@@ -112,15 +115,8 @@ def test_add_remove_module(self):
all variables the updated parameters follow the SGD update rule.
"""
env = gym.make("CartPole-v1")
learner = get_learner()

# add a test module with SGD optimizer with a known lr
lr = 1e-4

def set_optimizer_fn(module):
return [
(module.trainable_variables, tf.keras.optimizers.SGD(learning_rate=lr))
]
lr = 1e-3
learner = get_learner(lr)

learner.add_module(
module_id="test",
@@ -130,7 +126,6 @@ def set_optimizer_fn(module):
action_space=env.action_space,
model_config_dict={"fcnet_hiddens": [16]},
),
set_optimizer_fn=set_optimizer_fn,
)

learner.remove_module(DEFAULT_POLICY_ID)
@@ -150,6 +145,86 @@

check(params, expected)

def test_save_load_state(self):
env = gym.make("CartPole-v1")

learner1 = BCTfLearner(
module_spec=SingleAgentRLModuleSpec(
module_class=DiscreteBCTFModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config_dict={"fcnet_hiddens": [64]},
),
optimizer_config={"lr": 2e-3},
learner_scaling_config=LearnerGroupScalingConfig(),
framework_hyperparameters=FrameworkHPs(eager_tracing=True),
)

learner1.build()
with tempfile.TemporaryDirectory() as tmpdir:
learner1.save_state(tmpdir)

learner2 = BCTfLearner(
module_spec=SingleAgentRLModuleSpec(
module_class=DiscreteBCTFModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config_dict={"fcnet_hiddens": [32]},
),
optimizer_config={"lr": 1e-3},
learner_scaling_config=LearnerGroupScalingConfig(),
framework_hyperparameters=FrameworkHPs(eager_tracing=True),
)
learner2.build()
learner2.load_state(tmpdir)
self._check_learner_states(learner1, learner2)

# add a module then save/load and check states
with tempfile.TemporaryDirectory() as tmpdir:
learner1.add_module(
module_id="test",
module_spec=SingleAgentRLModuleSpec(
module_class=DiscreteBCTFModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config_dict={"fcnet_hiddens": [32]},
),
)
learner1.save_state(tmpdir)
learner2.load_state(tmpdir)
self._check_learner_states(learner1, learner2)

# remove a module then save/load and check states
with tempfile.TemporaryDirectory() as tmpdir:
learner1.remove_module(module_id=DEFAULT_POLICY_ID)
learner1.save_state(tmpdir)
learner2.load_state(tmpdir)
self._check_learner_states(learner1, learner2)

def _check_learner_states(self, learner1, learner2):
check(learner1.get_weights(), learner2.get_weights())

# check all internal optimizer state dictionaries have been updated
learner_1_optims_serialized = {
name: optim.get_config()
for name, optim in learner1._named_optimizers.items()
}
learner_2_optims_serialized = {
name: optim.get_config()
for name, optim in learner2._named_optimizers.items()
}
check(learner_1_optims_serialized, learner_2_optims_serialized)

learner_1_optims_serialized = [
optim.get_config() for optim in learner1._optimizer_parameters.keys()
]
learner_2_optims_serialized = [
optim.get_config() for optim in learner2._optimizer_parameters.keys()
]
check(learner_1_optims_serialized, learner_2_optims_serialized)

check(learner1._module_optimizers, learner2._module_optimizers)


if __name__ == "__main__":
import pytest
13 changes: 9 additions & 4 deletions rllib/core/learner/tests/test_learner_group.py
@@ -41,17 +41,22 @@ class RemoteTrainingHelper:
def local_training_helper(self, fw, scaling_mode) -> None:
env = gym.make("CartPole-v1")
scaling_config = LOCAL_SCALING_CONFIGS[scaling_mode]
learner_group = get_learner_group(fw, env, scaling_config, eager_tracing=True)
local_learner = get_learner(fw, env)
lr = 1e-3
learner_group = get_learner_group(
fw, env, scaling_config, learning_rate=lr, eager_tracing=True
)
local_learner = get_learner(fw, env, learning_rate=lr)
local_learner.build()

# make the state of the learner and the local learner_group identical
local_learner.set_state(learner_group.get_state())

check(local_learner.get_state(), learner_group.get_state())
reader = get_cartpole_dataset_reader(batch_size=500)
batch = reader.next()
batch = batch.as_multi_agent()
check(local_learner.update(batch), learner_group.update(batch))
learner_update = local_learner.update(batch)
learner_group_update = learner_group.update(batch)
check(learner_update, learner_group_update)

new_module_id = "test_module"

