Skip to content

Commit

Permalink
[RLlib] Checkpoint and restore connectors. (ray-project#26253)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jun Gong committed Jul 9, 2022
1 parent 7fcf0ad commit 0c469e4
Show file tree
Hide file tree
Showing 22 changed files with 784 additions and 139 deletions.
39 changes: 39 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3272,6 +3272,45 @@ py_test(
srcs = ["examples/bandit/tune_lin_ucb_train_recsim_env.py"],
)

# Runs the run_connector_policy example script. The APPO CartPole-v0
# checkpoint is declared as a data dependency and handed to the script
# through --checkpoint_file.
py_test(
    name = "examples/connectors/run_connector_policy",
    size = "small",
    srcs = ["examples/connectors/run_connector_policy.py"],
    args = [
        "--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6",
    ],
    data = ["tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6"],
    main = "examples/connectors/run_connector_policy.py",
    tags = [
        "team:rllib",
        "exclusive",
        "examples",
    ],
)

# Runs the adapt_connector_policy example script. The APPO CartPole-v0
# checkpoint is declared as a data dependency and handed to the script
# through --checkpoint_file.
py_test(
    name = "examples/connectors/adapt_connector_policy",
    size = "small",
    srcs = ["examples/connectors/adapt_connector_policy.py"],
    args = [
        "--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6",
    ],
    data = ["tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6"],
    main = "examples/connectors/adapt_connector_policy.py",
    tags = [
        "team:rllib",
        "exclusive",
        "examples",
    ],
)

# Runs the self_play_with_policy_checkpoint example script against the
# PPO open_spiel checkpoint, which is declared as a data dependency and
# handed to the script through --checkpoint_file.
py_test(
    name = "examples/connectors/self_play_with_policy_checkpoint",
    size = "small",
    srcs = ["examples/connectors/self_play_with_policy_checkpoint.py"],
    args = [
        "--checkpoint_file=tests/data/checkpoints/PPO_open_spiel_checkpoint-6",
        "--train_iteration=1",  # Smoke test.
    ],
    data = ["tests/data/checkpoints/PPO_open_spiel_checkpoint-6"],
    main = "examples/connectors/self_play_with_policy_checkpoint.py",
    tags = [
        "team:rllib",
        "exclusive",
        "examples",
    ],
)

# --------------------------------------------------------------------
# examples/documentation directory
#
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/action/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
self.connectors = connectors

def is_training(self, is_training: bool):
self.is_training = is_training
self._is_training = is_training
for c in self.connectors:
c.is_training(is_training)

Expand Down
16 changes: 9 additions & 7 deletions rllib/connectors/agent/lambdas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Type
from typing import Any, Callable, List, Type

import numpy as np
import tree # dm_tree
Expand All @@ -12,7 +12,6 @@
from ray.rllib.utils.typing import (
AgentConnectorDataType,
AgentConnectorsOutput,
TensorStructType,
)
from ray.util.annotations import PublicAPI

Expand Down Expand Up @@ -56,13 +55,16 @@ def from_config(ctx: ConnectorContext, params: List[Any]):


@PublicAPI(stability="alpha")
def flatten_data(data: Dict[str, TensorStructType]):
def flatten_data(data: AgentConnectorsOutput):
assert isinstance(
data, dict
), "Single agent data must be of type Dict[str, TensorStructType]"
data, AgentConnectorsOutput
), "Single agent data must be of type AgentConnectorsOutput"

for_training = data.for_training
for_action = data.for_action

flattened = {}
for k, v in data.items():
for k, v in for_action.items():
if k in [SampleBatch.INFOS, SampleBatch.ACTIONS] or k.startswith("state_out_"):
# Do not flatten infos, actions, and state_out_ columns.
flattened[k] = v
Expand All @@ -74,7 +76,7 @@ def flatten_data(data: Dict[str, TensorStructType]):
flattened[k] = np.array(tree.flatten(v))
flattened = SampleBatch(flattened, is_training=False)

return AgentConnectorsOutput(data, flattened)
return AgentConnectorsOutput(for_training, flattened)


# Agent connector to build and return a flattened observation SampleBatch
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
self.connectors = connectors

def is_training(self, is_training: bool):
self.is_training = is_training
self._is_training = is_training
for c in self.connectors:
c.is_training(is_training)

Expand Down
1 change: 1 addition & 0 deletions rllib/connectors/agent/state_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:

action, states, fetches = self._states[env_id][agent_id]

# TODO(jungong): Support buffering more than 1 prev actions.
if action is not None:
d[SampleBatch.ACTIONS] = action # Last action
else:
Expand Down
42 changes: 18 additions & 24 deletions rllib/connectors/agent/view_requirement.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from collections import defaultdict
from typing import Any, List

import numpy as np

from ray.rllib.connectors.connector import (
AgentConnector,
ConnectorContext,
Expand Down Expand Up @@ -45,25 +43,18 @@ def _get_sample_batch_for_action(
) -> SampleBatch:
# TODO(jungong) : actually support building input sample batch with all the
# view shift requirements, etc.
# For now, we use some simple logics for demo purpose.
# For now, we only support the last element (no shift).
input_dict = {}
for k, v in view_requirements.items():
if not v.used_for_compute_actions:
for col, req in view_requirements.items():
if not req.used_for_compute_actions:
continue
data_col = v.data_col or k
if data_col not in agent_batch:
if col not in agent_batch:
continue
input_dict[k] = agent_batch[data_col][-1:]
input_dict[col] = agent_batch[col][-1]
return SampleBatch(input_dict, is_training=False)

def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
assert isinstance(ac_data.data, AgentConnectorsOutput), (
"ViewRequirementAgentConnector operates on raw input dict and its"
"flattened SampleBatch."
)

d = ac_data.data.for_training
f = ac_data.data.for_action
d = ac_data.data
assert (
type(d) == dict
), "Single agent data must be of type Dict[str, TensorStructType]"
Expand All @@ -79,8 +70,7 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
assert vr, "ViewRequirements required by ViewRequirementConnector"

training_dict = None
# We construct a proper per-timeslice dict in training mode,
# for env runner to construct a complete episode.
# Return full training_dict for env runner to construct episodes.
if self.is_training:
# Note(jungong) : we need to keep the entire input dict here.
# A column may be used by postprocessing (GAE) even if its
Expand All @@ -106,14 +96,18 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
if data_col not in d:
continue

if col in agent_batch:
# Stack along batch dim.
agent_batch[col] = np.vstack((agent_batch[col], f[data_col]))
else:
agent_batch[col] = f[data_col]
if col not in agent_batch:
agent_batch[col] = []
# Stack along batch dim.
agent_batch[col].append(d[data_col])

# Only keep the useful part of the history.
h = req.shift_from if req.shift_from else -1
assert h <= 0, "Can use future data to compute action"
h = -1
if req.shift_from is not None:
h = req.shift_from
elif type(req.shift) == int:
h = req.shift
assert h <= 0, "Cannot use future data to compute action"
agent_batch[col] = agent_batch[col][h:]

sample_batch = self._get_sample_batch_for_action(vr, agent_batch)
Expand Down
26 changes: 24 additions & 2 deletions rllib/connectors/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ class Connector(abc.ABC):

def __init__(self, ctx: ConnectorContext):
# This gets flipped to False for inference.
self.is_training = True
self._is_training = True

def is_training(self, is_training: bool):
self.is_training = is_training
self._is_training = is_training

def __str__(self, indentation: int = 0):
return " " * indentation + self.__class__.__name__
Expand Down Expand Up @@ -325,6 +325,8 @@ def remove(self, name: str):
raise ValueError(f"Can not find connector {name}")
del self.connectors[idx]

logger.info(f"Removed connector {name} from {self.__class__.__name__}.")

def insert_before(self, name: str, connector: Connector):
"""Insert a new connector before connector <name>
Expand All @@ -341,6 +343,11 @@ def insert_before(self, name: str, connector: Connector):
raise ValueError(f"Can not find connector {name}")
self.connectors.insert(idx, connector)

logger.info(
f"Inserted {connector.__class__.__name__} before {name} "
f"to {self.__class__.__name__}."
)

def insert_after(self, name: str, connector: Connector):
"""Insert a new connector after connector <name>
Expand All @@ -357,6 +364,11 @@ def insert_after(self, name: str, connector: Connector):
raise ValueError(f"Can not find connector {name}")
self.connectors.insert(idx + 1, connector)

logger.info(
f"Inserted {connector.__class__.__name__} after {name} "
f"to {self.__class__.__name__}."
)

def prepend(self, connector: Connector):
"""Prepend a new connector at the beginning of a connector pipeline.
Expand All @@ -365,6 +377,11 @@ def prepend(self, connector: Connector):
"""
self.connectors.insert(0, connector)

logger.info(
f"Added {connector.__class__.__name__} to the beginning of "
f"{self.__class__.__name__}."
)

def append(self, connector: Connector):
"""Append a new connector at the end of a connector pipeline.
Expand All @@ -373,6 +390,11 @@ def append(self, connector: Connector):
"""
self.connectors.append(connector)

logger.info(
f"Added {connector.__class__.__name__} to the end of "
f"{self.__class__.__name__}."
)

def __str__(self, indentation: int = 0):
return "\n".join(
[" " * indentation + self.__class__.__name__]
Expand Down
60 changes: 50 additions & 10 deletions rllib/connectors/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
from ray.rllib.connectors.connector import ConnectorContext, get_connector
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType
from ray.rllib.utils.typing import AgentConnectorDataType, AgentConnectorsOutput


class TestAgentConnector(unittest.TestCase):
Expand Down Expand Up @@ -90,18 +92,22 @@ def test_flatten_data_connector(self):
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, FlattenDataAgentConnector))

sample_batch = {
SampleBatch.NEXT_OBS: {
"sensor1": [[1, 1], [2, 2]],
"sensor2": 8.8,
},
SampleBatch.REWARDS: 5.8,
SampleBatch.ACTIONS: [[1, 1], [2]],
SampleBatch.INFOS: {"random": "info"},
}

d = AgentConnectorDataType(
0,
1,
{
SampleBatch.NEXT_OBS: {
"sensor1": [[1, 1], [2, 2]],
"sensor2": 8.8,
},
SampleBatch.REWARDS: 5.8,
SampleBatch.ACTIONS: [[1, 1], [2]],
SampleBatch.INFOS: {"random": "info"},
},
# FlattenDataAgentConnector does NOT touch for_training dict,
# so simply pass None here.
AgentConnectorsOutput(None, sample_batch),
)

flattened = c([d])
Expand All @@ -114,6 +120,40 @@ def test_flatten_data_connector(self):
self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")

def test_view_requirement_connector(self):
    """Check ViewRequirementAgentConnector output after flattening.

    Builds a connector context with two view requirements ("obs" and
    "prev_actions", the latter reading the "actions" column with
    shift=-1), pushes a single agent data item through
    ViewRequirementAgentConnector followed by FlattenDataAgentConnector,
    and verifies that both required columns are present in the
    ``for_action`` part of the result.
    """
    view_requirements = {
        "obs": ViewRequirement(
            used_for_training=True, used_for_compute_actions=True
        ),
        "prev_actions": ViewRequirement(
            data_col="actions",
            shift=-1,
            used_for_training=True,
            used_for_compute_actions=True,
        ),
    }
    ctx = ConnectorContext(view_requirements=view_requirements)

    c = ViewRequirementAgentConnector(ctx)
    f = FlattenDataAgentConnector(ctx)

    d = AgentConnectorDataType(
        0,
        1,
        {
            SampleBatch.NEXT_OBS: {
                "sensor1": [[1, 1], [2, 2]],
                "sensor2": 8.8,
            },
            SampleBatch.ACTIONS: np.array(0),
        },
    )
    # ViewRequirementAgentConnector first, then FlattenDataAgentConnector.
    processed = f(c([d]))

    for_action = processed[0].data.for_action
    # assertIn prints the container on failure, which a bare
    # assertTrue(x in y) does not.
    self.assertIn("obs", for_action)
    self.assertIn("prev_actions", for_action)


if __name__ == "__main__":
import sys
Expand Down
Loading

0 comments on commit 0c469e4

Please sign in to comment.