forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lambdas.py
86 lines (66 loc) · 2.57 KB
/
lambdas.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from typing import Any, Callable, List, Type
import numpy as np
import tree # dm_tree
from ray.rllib.connectors.connector import (
AgentConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import (
AgentConnectorDataType,
AgentConnectorsOutput,
)
from ray.util.annotations import PublicAPI
@PublicAPI(stability="alpha")
def register_lambda_agent_connector(
    name: str, fn: Callable[[Any], Any]
) -> Type[AgentConnector]:
    """Turn a plain transformation function into a registered AgentConnector.

    A new AgentConnector subclass is created whose ``transform`` applies
    ``fn`` to the ``data`` field of the incoming item. The class is renamed
    to ``name`` and registered under ``name`` via ``register_connector``.

    ``fn`` must accept a single data object and return a single data object.

    Args:
        name: Name of the resulting agent connector. Used as the registry
            key, as the class ``__name__`` / ``__qualname__``, and as the
            first element of the tuple returned by ``to_config``.
        fn: The function that transforms env / agent data.

    Returns:
        The newly created AgentConnector subclass that transforms data
        using ``fn``.
    """

    class LambdaAgentConnector(AgentConnector):
        def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
            # The env and agent ids are carried over as they are; only the
            # data payload is run through fn.
            env_id, agent_id = ac_data.env_id, ac_data.agent_id
            payload = fn(ac_data.data)
            return AgentConnectorDataType(env_id, agent_id, payload)

        def to_config(self):
            # There are no parameters to serialize, hence None.
            return (name, None)

        @staticmethod
        def from_config(ctx: ConnectorContext, params: List[Any]):
            # params is not used: the connector is rebuilt from ctx alone.
            return LambdaAgentConnector(ctx)

    # Expose the generated class under the caller-supplied name instead of
    # the local placeholder name.
    LambdaAgentConnector.__name__ = LambdaAgentConnector.__qualname__ = name

    register_connector(name, LambdaAgentConnector)

    return LambdaAgentConnector
@PublicAPI(stability="alpha")
def flatten_data(data: AgentConnectorsOutput):
    """Flatten the columns of the ``for_action`` part of a connector output.

    Each column of ``data.for_action`` is handled as follows:

    * ``SampleBatch.INFOS``, ``SampleBatch.ACTIONS`` and any key starting
      with ``"state_out_"`` are copied over as they are.
    * A ``None`` value stays ``None``.
    * Any other value is passed through ``tree.flatten`` and the result is
      wrapped with ``np.array``.

    ``data.for_training`` is handed through without modification.

    Args:
        data: The connector output to process. Must be an
            AgentConnectorsOutput instance.

    Returns:
        A new AgentConnectorsOutput holding the original ``for_training``
        value and a SampleBatch (created with ``is_training=False``) built
        from the processed ``for_action`` columns.

    Raises:
        AssertionError: If ``data`` is not an AgentConnectorsOutput.
    """
    assert isinstance(
        data, AgentConnectorsOutput
    ), "Single agent data must be of type AgentConnectorsOutput"

    training_part = data.for_training
    action_columns = data.for_action

    def _flatten_column(key, value):
        # Infos, actions, and state_out_ columns are never flattened.
        keep_as_is = key in (
            SampleBatch.INFOS,
            SampleBatch.ACTIONS,
        ) or key.startswith("state_out_")
        # A None value is kept so the column is still present in the batch.
        if keep_as_is or value is None:
            return value
        return np.array(tree.flatten(value))

    action_batch = SampleBatch(
        {key: _flatten_column(key, value) for key, value in action_columns.items()},
        is_training=False,
    )

    return AgentConnectorsOutput(training_part, action_batch)
# Agent connector that applies flatten_data to the data of each item it
# transforms: the `for_action` columns are returned as a flattened
# SampleBatch, alongside the original, unmodified `for_training` data.
# Evaluating this statement registers the connector class under the name
# "FlattenDataAgentConnector" and marks it as alpha-stability public API.
FlattenDataAgentConnector = PublicAPI(stability="alpha")(
register_lambda_agent_connector("FlattenDataAgentConnector", flatten_data)
)