import json
import logging
import os
import platform
from abc import ABCMeta, abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Collection,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

import gymnasium as gym
import numpy as np
import tree  # pip install dm_tree
from gymnasium.spaces import Box
from packaging import version

import ray
import ray.cloudpickle as pickle
from ray.actor import ActorHandle
from ray.train import Checkpoint
from ray.rllib.core.columns import Columns
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import (
    OldAPIStack,
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    is_overridden,
)
from ray.rllib.utils.checkpoints import (
    CHECKPOINT_VERSION,
    get_checkpoint_info,
    try_import_msgpack,
)
from ray.rllib.utils.deprecation import (
    DEPRECATED_VALUE,
    deprecation_warning,
)
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.serialization import (
    deserialize_type,
    space_from_dict,
    space_to_dict,
)
from ray.rllib.utils.spaces.space_utils import (
    get_base_struct_from_space,
    get_dummy_batch_for_space,
    unbatch,
)
from ray.rllib.utils.tensor_dtype import get_np_dtype
from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
from ray.rllib.utils.typing import (
    AgentID,
    AlgorithmConfigDict,
    ModelGradients,
    ModelWeights,
    PolicyID,
    PolicyState,
    T,
    TensorStructType,
    TensorType,
)

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

if TYPE_CHECKING:
    from ray.rllib.evaluation import Episode
    from ray.rllib.core.rl_module import RLModule

logger = logging.getLogger(__name__)


@OldAPIStack
class PolicySpec:
    """A policy spec used in the "config.multiagent.policies" specification dict.

    Used as the dict's values (the keys are the policy IDs (str)). E.g.:

    config:
        multiagent:
            policies: {
                "pol1": PolicySpec(None, Box, Discrete(2), {"lr": 0.0001}),
                "pol2": PolicySpec(config={"lr": 0.001}),
            }
    """

    def __init__(
        self, policy_class=None, observation_space=None, action_space=None, config=None
    ):
        # If None, use the Algorithm's default policy class stored under
        # `Algorithm._policy_class`.
        self.policy_class = policy_class
        # If None, use the env's observation space. If None and there is no Env
        # (e.g. offline RL), an error is thrown.
        self.observation_space = observation_space
        # If None, use the env's action space. If None and there is no Env
        # (e.g. offline RL), an error is thrown.
        self.action_space = action_space
        # Overrides defined keys in the main Algorithm config.
        # If None, use {}.
        self.config = config

    def __eq__(self, other: "PolicySpec"):
        return (
            self.policy_class == other.policy_class
            and self.observation_space == other.observation_space
            and self.action_space == other.action_space
            and self.config == other.config
        )

    def serialize(self) -> Dict:
        from ray.rllib.algorithms.registry import get_policy_class_name

        # Try to figure out a durable name for this policy.
        cls = get_policy_class_name(self.policy_class)
        if cls is None:
            logger.warning(
                f"Cannot figure out a durable policy name for {self.policy_class}. "
                "You are probably trying to checkpoint a custom policy. "
                "The raw policy class may cause problems when the checkpoint needs "
                "to be loaded in the future. To fix this, make sure you add your "
                "custom policy to rllib.algorithms.registry.POLICIES."
            )
            cls = self.policy_class
        return {
            "policy_class": cls,
            "observation_space": space_to_dict(self.observation_space),
            "action_space": space_to_dict(self.action_space),
            # TODO(jungong): Try making the config dict durable, e.g. by
            # getting rid of all the fields that are not JSON serializable.
            "config": self.config,
        }

    @classmethod
    def deserialize(cls, spec: Dict) -> "PolicySpec":
        if isinstance(spec["policy_class"], str):
            # Try to recover the actual policy class from its durable name.
            from ray.rllib.algorithms.registry import get_policy_class

            policy_class = get_policy_class(spec["policy_class"])
        elif isinstance(spec["policy_class"], type):
            # Policy spec is already a class type. Simply use it.
            policy_class = spec["policy_class"]
        else:
            raise AttributeError(f"Unknown policy class spec {spec['policy_class']}")
        return cls(
            policy_class=policy_class,
            observation_space=space_from_dict(spec["observation_space"]),
            action_space=space_from_dict(spec["action_space"]),
            config=spec["config"],
        )
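
# A minimal usage sketch (illustrative, not part of the original file): building
# a PolicySpec by hand. Note that `PolicySpec.deserialize()` requires
# `policy_class` to be an actual class or a class name registered in
# `rllib.algorithms.registry.POLICIES`; the spaces and config below are
# placeholders.
#
#   from gymnasium.spaces import Box, Discrete
#
#   spec = PolicySpec(
#       policy_class=None,  # None -> use the Algorithm's default policy class.
#       observation_space=Box(-1.0, 1.0, (4,)),
#       action_space=Discrete(2),
#       config={"lr": 0.0001},
#   )
#   serialized = spec.serialize()  # -> plain dict, suitable for checkpointing.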


@OldAPIStack
class Policy(metaclass=ABCMeta):
    """RLlib's base class for all Policy implementations.

    Policy is the abstract superclass for all DL-framework specific sub-classes
    (e.g. TFPolicy or TorchPolicy). It exposes APIs to:

    1. Compute actions from observation (and possibly other) inputs.
    2. Manage the Policy's NN model(s), like exporting and loading their weights.
    3. Postprocess a given trajectory from the environment or other input via the
       `postprocess_trajectory` method.
    4. Compute losses from a train batch.
    5. Perform updates from a train batch on the NN-models (this normally includes
       loss calculations) either:
       a. in one monolithic step (`learn_on_batch`), or
       b. via batch pre-loading, then n steps of actual loss computations and
          updates (`load_batch_into_buffer` + `learn_on_loaded_batch`).
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        config: AlgorithmConfigDict,
    ):
        """Initializes a Policy instance.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: A complete Algorithm/Policy config dict. For the default
                config keys and values, see rllib/algorithm/algorithm.py.
        """
        self.observation_space: gym.Space = observation_space
        self.action_space: gym.Space = action_space
        # The policy ID in the global context.
        self.__policy_id = config.get("__policy_id")
        # The base struct of the observation/action spaces.
        # E.g. action-space = gym.spaces.Dict({"a": Discrete(2)}) ->
        # action_space_struct = {"a": Discrete(2)}
        self.observation_space_struct = get_base_struct_from_space(observation_space)
        self.action_space_struct = get_base_struct_from_space(action_space)
        self.config: AlgorithmConfigDict = config
        self.framework = self.config.get("framework")
        # Create the callbacks object to use for handling custom callbacks.
        from ray.rllib.algorithms.callbacks import DefaultCallbacks

        callbacks = self.config.get("callbacks")
        if isinstance(callbacks, DefaultCallbacks):
            # Already an instantiated callbacks object: Use it as-is.
            self.callbacks = callbacks
        elif isinstance(callbacks, (str, type)):
            try:
                self.callbacks: "DefaultCallbacks" = deserialize_type(
                    self.config.get("callbacks")
                )()
            except Exception:
                # Silently ignore callbacks types that cannot be
                # deserialized/instantiated here.
                pass
        else:
            self.callbacks: "DefaultCallbacks" = DefaultCallbacks()
        # The global timestep, broadcast down from time to time from the
        # local worker to all remote workers.
        self.global_timestep: int = 0
        # The number of gradient updates this policy has undergone.
        self.num_grad_updates: int = 0
        # The action distribution class to use for action sampling, if any.
        # Child classes may set this.
        self.dist_class: Optional[Type] = None
        # Initialize view requirements.
        self.init_view_requirements()
        # Whether the Model's initial state (method) has been added
        # automatically based on the given view requirements of the model.
        self._model_init_state_automatically_added = False
        # Connectors.
        self.agent_connectors = None
        self.action_connectors = None

    @staticmethod
    def from_checkpoint(
        checkpoint: Union[str, Checkpoint],
        policy_ids: Optional[Collection[PolicyID]] = None,
    ) -> Union["Policy", Dict[PolicyID, "Policy"]]:
        """Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.

        Note: This method must remain backward compatible from 2.1.0 on, w.r.t.
        checkpoints created with Ray 2.0.0 or later.

        Args:
            checkpoint: The path (str) to a Policy or Algorithm checkpoint directory
                or an AIR Checkpoint (Policy or Algorithm) instance to restore from.
                If checkpoint is a Policy checkpoint, `policy_ids` must be None
                and only the Policy in that checkpoint is restored and returned.
                If checkpoint is an Algorithm checkpoint and `policy_ids` is None,
                will return a dict of all Policy objects found in the checkpoint,
                otherwise a dict of only those policies in `policy_ids`.
            policy_ids: List of policy IDs to extract from a given Algorithm
                checkpoint. If None and an Algorithm checkpoint is provided,
                will restore all policies found in that checkpoint. If a Policy
                checkpoint is given, this arg must be None.

        Returns:
            An instantiated Policy, if `checkpoint` is a Policy checkpoint. A dict
            mapping PolicyID to Policies, if `checkpoint` is an Algorithm
            checkpoint. In the latter case, returns all policies within the
            Algorithm if `policy_ids` is None, else a dict of only those Policies
            that are in `policy_ids`.
        """
        checkpoint_info = get_checkpoint_info(checkpoint)

        # Algorithm checkpoint: Extract one or more policies from it and return them
        # in a dict (mapping PolicyID to Policy instances).
        if checkpoint_info["type"] == "Algorithm":
            from ray.rllib.algorithms.algorithm import Algorithm

            policies = {}

            # Old Algorithm checkpoints: State must be completely retrieved from:
            # algo state file -> worker -> "state".
            if checkpoint_info["checkpoint_version"] < version.Version("1.0"):
                with open(checkpoint_info["state_file"], "rb") as f:
                    state = pickle.load(f)
                # In older checkpoint versions, the policy states are stored under
                # "state" within the worker state (which is pickled in itself).
                worker_state = pickle.loads(state["worker"])
                policy_states = worker_state["state"]
                for pid, policy_state in policy_states.items():
                    # Get spec and config, and merge the worker's config with
                    # the policy spec's config overrides.
                    serialized_policy_spec = worker_state["policy_specs"][pid]
                    policy_config = Algorithm.merge_algorithm_configs(
                        worker_state["policy_config"], serialized_policy_spec["config"]
                    )
                    serialized_policy_spec.update({"config": policy_config})
                    policy_state.update({"policy_spec": serialized_policy_spec})
                    policies[pid] = Policy.from_state(policy_state)
            # Newer versions: Get policy states from "policies/" sub-dirs.
            elif checkpoint_info["policy_ids"] is not None:
                for policy_id in checkpoint_info["policy_ids"]:
                    if policy_ids is None or policy_id in policy_ids:
                        policy_checkpoint_info = get_checkpoint_info(
                            os.path.join(
                                checkpoint_info["checkpoint_dir"],
                                "policies",
                                policy_id,
                            )
                        )
                        assert policy_checkpoint_info["type"] == "Policy"
                        with open(policy_checkpoint_info["state_file"], "rb") as f:
                            policy_state = pickle.load(f)
                        policies[policy_id] = Policy.from_state(policy_state)
            return policies

        # Policy checkpoint: Return a single Policy instance.
        else:
            msgpack = None
            if checkpoint_info.get("format") == "msgpack":
                msgpack = try_import_msgpack(error=True)
            with open(checkpoint_info["state_file"], "rb") as f:
                if msgpack is not None:
                    state = msgpack.load(f)
                else:
                    state = pickle.load(f)
            return Policy.from_state(state)
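
    # A hedged usage sketch: restoring policies from checkpoint directories.
    # The paths and the policy ID below are placeholders.
    #
    #   # Policy checkpoint -> a single Policy instance:
    #   policy = Policy.from_checkpoint("/tmp/my_policy_checkpoint")
    #
    #   # Algorithm checkpoint -> dict mapping PolicyID to Policy:
    #   policies = Policy.from_checkpoint(
    #       "/tmp/my_algo_checkpoint", policy_ids=["default_policy"]
    #   )
    #   policy = policies["default_policy"]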

    @staticmethod
    def from_state(state: PolicyState) -> "Policy":
        """Recovers a Policy from a state object.

        The `state` of an instantiated Policy can be retrieved by calling its
        `get_state` method. This only works for the V2 Policy classes
        (EagerTFPolicyV2, DynamicTFPolicyV2, and TorchPolicyV2). It contains all
        information necessary to create the Policy. No access to the original
        code (e.g. configs, knowledge of the policy's class, etc.) is needed.

        Args:
            state: The state to recover a new Policy instance from.

        Returns:
            A new Policy instance.
        """
        serialized_pol_spec: Optional[dict] = state.get("policy_spec")
        if serialized_pol_spec is None:
            raise ValueError(
                "No `policy_spec` key was found in given `state`! "
                "Cannot create new Policy."
            )
        pol_spec = PolicySpec.deserialize(serialized_pol_spec)
        actual_class = get_tf_eager_cls_if_necessary(
            pol_spec.policy_class,
            pol_spec.config,
        )
        if pol_spec.config["framework"] == "tf":
            from ray.rllib.policy.tf_policy import TFPolicy

            return TFPolicy._tf1_from_state_helper(state)

        # Create the new policy.
        new_policy = actual_class(
            # Note(jungong): We are intentionally not using keyword arguments
            # here because some policies name the observation space parameter
            # `obs_space`, while others name it `observation_space`.
            pol_spec.observation_space,
            pol_spec.action_space,
            pol_spec.config,
        )
        # Set the new policy's state (weights, optimizer vars, exploration state,
        # etc.).
        new_policy.set_state(state)
        # Return the new policy.
        return new_policy
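
    # A hedged sketch of the `get_state()`/`from_state()` round trip (assumes
    # `policy` is an already instantiated V2 Policy):
    #
    #   state = policy.get_state()
    #   clone = Policy.from_state(state)  # No access to the original code needed.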

    @OverrideToImplementCustomLogic
    def make_rl_module(self) -> "RLModule":
        """Returns the RL Module (only for when the RLModule API is enabled).

        If the RLModule API is enabled
        (`self.config.api_stack(enable_rl_module_and_learner=True)`), this method
        should be implemented and should return the RLModule instance to use for
        this Policy. Otherwise, RLlib will error out.
        """
        # If imported at the top, this would create a circular dependency.
        from ray.rllib.core.rl_module.rl_module import RLModuleSpec

        if self.__policy_id is None:
            raise ValueError(
                "When using the RLModule API, `policy_id` within the policies must "
                "be set. This should have happened automatically. If you see this "
                "bug, please file a GitHub issue."
            )

        spec = self.config["__multi_rl_module_spec"]
        if isinstance(spec, RLModuleSpec):
            module = spec.build()
        else:
            # Filter the module spec down to only this policy's policy_id.
            marl_spec = type(spec)(
                multi_rl_module_class=spec.multi_rl_module_class,
                module_specs={self.__policy_id: spec.module_specs[self.__policy_id]},
            )
            multi_rl_module = marl_spec.build()
            module = multi_rl_module[self.__policy_id]

        return module

    def init_view_requirements(self):
        """Maximal view requirements dict for `learn_on_batch()` and
        `compute_actions` calls.

        Specific policies can override this function to provide a custom
        list of view requirements.
        """
        # Maximal view requirements dict for `learn_on_batch()` and
        # `compute_actions` calls.
        # View requirements will be automatically filtered out later based
        # on the postprocessing and loss functions to ensure optimal data
        # collection and transfer performance.
        view_reqs = self._get_default_view_requirements()
        if not hasattr(self, "view_requirements"):
            self.view_requirements = view_reqs
        else:
            for k, v in view_reqs.items():
                if k not in self.view_requirements:
                    self.view_requirements[k] = v

    def get_connector_metrics(self) -> Dict:
        """Gets timing metrics from this policy's connectors."""
        return {
            "agent_connectors": {
                name + "_ms": 1000 * timer.mean
                for name, timer in self.agent_connectors.timers.items()
            },
            # Note: Use the action connectors' timers here (not the agent
            # connectors' ones).
            "action_connectors": {
                name + "_ms": 1000 * timer.mean
                for name, timer in self.action_connectors.timers.items()
            },
        }

    def reset_connectors(self, env_id) -> None:
        """Resets the action- and agent-connectors for this policy."""
        self.agent_connectors.reset(env_id=env_id)
        self.action_connectors.reset(env_id=env_id)

    def compute_single_action(
        self,
        obs: Optional[TensorStructType] = None,
        state: Optional[List[TensorType]] = None,
        *,
        prev_action: Optional[TensorStructType] = None,
        prev_reward: Optional[TensorStructType] = None,
        info: dict = None,
        input_dict: Optional[SampleBatch] = None,
        episode: Optional["Episode"] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        # Kwargs placeholder for future compatibility.
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        """Computes and returns a single (B=1) action value.

        Takes an input dict (usually a SampleBatch) as its main data input.
        This allows for using this method in case a more complex input pattern
        (view requirements) is needed, for example when the Model requires the
        last n observations, the last m actions/rewards, or a combination
        of any of these.
        Alternatively, in case no complex inputs are required, takes a single
        `obs` value (and possibly single state values, prev-action/reward
        values, etc.).

        Args:
            obs: Single observation.
            state: List of RNN state inputs, if any.
            prev_action: Previous action value, if any.
            prev_reward: Previous reward, if any.
            info: Info object, if any.
            input_dict: A SampleBatch or input dict containing the
                single (unbatched) Tensors to compute actions. If given, it'll
                be used instead of `obs`, `state`, `prev_action|reward`, and
                `info`.
            episode: This provides access to all of the internal episode state,
                which may be useful for model-based or multi-agent algorithms.
            explore: Whether to pick an exploitation or exploration action
                (default: None -> use self.config["explore"]).
            timestep: The current (sampling) time step.

        Keyword Args:
            kwargs: Forward compatibility placeholder.

        Returns:
            Tuple consisting of the action, the list of RNN state outputs (if
            any), and a dictionary of extra features (if any).
        """
        # Build the input-dict used for the call to
        # `self.compute_actions_from_input_dict()`.
        if input_dict is None:
            input_dict = {SampleBatch.OBS: obs}
            if state is not None:
                if self.config.get("enable_rl_module_and_learner", False):
                    input_dict["state_in"] = state
                else:
                    for i, s in enumerate(state):
                        input_dict[f"state_in_{i}"] = s
            if prev_action is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action
            if prev_reward is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward
            if info is not None:
                input_dict[SampleBatch.INFOS] = info

        # Batch all data in the input dict.
        input_dict = tree.map_structure_with_path(
            lambda p, s: (
                s
                if p == "seq_lens"
                else s.unsqueeze(0)
                if torch and isinstance(s, torch.Tensor)
                else np.expand_dims(s, 0)
            ),
            input_dict,
        )

        episodes = None
        if episode is not None:
            episodes = [episode]

        out = self.compute_actions_from_input_dict(
            input_dict=SampleBatch(input_dict),
            episodes=episodes,
            explore=explore,
            timestep=timestep,
        )

        # Some policies don't return a tuple, but always just a single action.
        # E.g. ES and ARS.
        if not isinstance(out, tuple):
            single_action = out
            state_out = []
            info = {}
        # Normal case: Policy should return (action, state, info) tuple.
        else:
            batched_action, state_out, info = out
            single_action = unbatch(batched_action)
            assert len(single_action) == 1
            single_action = single_action[0]

        # Return action, internal state(s), infos.
        return (
            single_action,
            tree.map_structure(lambda x: x[0], state_out),
            tree.map_structure(lambda x: x[0], info),
        )
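
    # A hedged usage sketch (`my_env_obs` is a placeholder for a single,
    # unbatched observation matching this Policy's observation space):
    #
    #   action, state_out, extra = policy.compute_single_action(
    #       obs=my_env_obs,
    #       state=policy.get_initial_state(),  # Only relevant for RNN policies.
    #       explore=False,  # Pick the greedy (exploitation) action.
    #   )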

    def compute_actions_from_input_dict(
        self,
        input_dict: Union[SampleBatch, Dict[str, TensorStructType]],
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["Episode"]] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions from collected samples (across multiple agents).

        Takes an input dict (usually a SampleBatch) as its main data input.
        This allows for using this method in case a more complex input pattern
        (view requirements) is needed, for example when the Model requires the
        last n observations, the last m actions/rewards, or a combination
        of any of these.

        Args:
            input_dict: A SampleBatch or input dict containing the Tensors
                to compute actions. `input_dict` already abides by the
                Policy's as well as the Model's view requirements and can
                thus be passed to the Model as-is.
            explore: Whether to pick an exploitation or exploration
                action (default: None -> use self.config["explore"]).
            timestep: The current (sampling) time step.
            episodes: This provides access to all of the internal episodes'
                state, which may be useful for model-based or multi-agent
                algorithms.

        Keyword Args:
            kwargs: Forward compatibility placeholder.

        Returns:
            actions: Batch of output actions, with shape like
                [BATCH_SIZE, ACTION_SHAPE].
            state_outs: List of RNN state output
                batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
            info: Dictionary of extra feature batches, if any, with shape like
                {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        # The default implementation just passes obs, prev-a/r, and states on to
        # `self.compute_actions()`.
        state_batches = [s for k, s in input_dict.items() if k.startswith("state_in")]
        return self.compute_actions(
            input_dict[SampleBatch.OBS],
            state_batches,
            prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
            info_batch=input_dict.get(SampleBatch.INFOS),
            explore=explore,
            timestep=timestep,
            episodes=episodes,
            **kwargs,
        )
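
    # A hedged sketch: calling this method directly with an already batched
    # SampleBatch (batch size 2 and obs shape (4,) are assumptions):
    #
    #   import numpy as np
    #   from ray.rllib.policy.sample_batch import SampleBatch
    #
    #   batch = SampleBatch({SampleBatch.OBS: np.random.random((2, 4))})
    #   actions, state_outs, extra = policy.compute_actions_from_input_dict(batch)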

    @abstractmethod
    def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions for the current policy.

        Args:
            obs_batch: Batch of observations.
            state_batches: List of RNN state input batches, if any.
            prev_action_batch: Batch of previous action values.
            prev_reward_batch: Batch of previous rewards.
            info_batch: Batch of info objects.
            episodes: List of Episode objects, one for each obs in
                obs_batch. This provides access to all of the internal
                episode state, which may be useful for model-based or
                multi-agent algorithms.
            explore: Whether to pick an exploitation or exploration action.
                Set to None (default) for using the value of
                `self.config["explore"]`.
            timestep: The current (sampling) time step.

        Keyword Args:
            kwargs: Forward compatibility placeholder.

        Returns:
            actions: Batch of output actions, with shape like
                [BATCH_SIZE, ACTION_SHAPE].
            state_outs (List[TensorType]): List of RNN state output
                batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
            info (List[dict]): Dictionary of extra feature batches, if any,
                with shape like
                {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        raise NotImplementedError

    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
        actions_normalized: bool = True,
        in_training: bool = True,
    ) -> TensorType:
        """Computes the log-prob/likelihood for a given action and observation.

        The log-likelihood is calculated using this Policy's action
        distribution class (self.dist_class).

        Args:
            actions: Batch of actions, for which to retrieve the
                log-probs/likelihoods (given all other inputs: obs,
                states, ...).
            obs_batch: Batch of observations.
            state_batches: List of RNN state input batches, if any.
            prev_action_batch: Batch of previous action values.
            prev_reward_batch: Batch of previous rewards.
            actions_normalized: Whether the given `actions` are already
                normalized (between -1.0 and 1.0). If not and
                `normalize_actions=True`, we need to normalize the given
                actions first, before calculating log likelihoods.
            in_training: Whether to use the forward_train() or
                forward_exploration() of the underlying RLModule.

        Returns:
            Batch of log probs/likelihoods, with shape: [BATCH_SIZE].
        """
        raise NotImplementedError

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[
            Dict[AgentID, Tuple["Policy", SampleBatch]]
        ] = None,
        episode: Optional["Episode"] = None,
    ) -> SampleBatch:
        """Implements algorithm-specific trajectory postprocessing.

        This will be called on each trajectory fragment computed during policy
        evaluation. Each fragment is guaranteed to be only from one episode.
        The given fragment may or may not contain the end of this episode,
        depending on the `batch_mode=truncate_episodes|complete_episodes`,
        `rollout_fragment_length`, and other settings.

        Args:
            sample_batch: Batch of experiences for the policy,
                which will contain at most one episode trajectory.
            other_agent_batches: In a multi-agent env, this contains a
                mapping of agent IDs to (policy, agent_batch) tuples
                containing the policy and experiences of the other agents.
            episode: An optional multi-agent episode object to provide
                access to all of the internal episode state, which may
                be useful for model-based or multi-agent algorithms.

        Returns:
            The postprocessed sample batch.
        """
        # The default implementation just returns the same, unaltered batch.
        return sample_batch
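
    # A hedged sketch of a custom override (class and column names are
    # illustrative): add a mean-baselined advantage column, then defer to the
    # default behavior via super().
    #
    #   class MyPolicy(TorchPolicy):
    #       def postprocess_trajectory(
    #           self, sample_batch, other_agent_batches=None, episode=None
    #       ):
    #           rewards = sample_batch[SampleBatch.REWARDS]
    #           sample_batch["advantages"] = rewards - rewards.mean()
    #           return super().postprocess_trajectory(
    #               sample_batch, other_agent_batches, episode
    #           )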

    @OverrideToImplementCustomLogic
    def loss(
        self, model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch
    ) -> Union[TensorType, List[TensorType]]:
        """Loss function for this Policy.

        Override this method in order to implement custom loss computations.

        Args:
            model: The model on which to calculate the loss(es).
            dist_class: The action distribution class for sampling actions
                from the model's outputs.
            train_batch: The input batch on which to calculate the loss.

        Returns:
            Either a single loss tensor or a list of loss tensors.
        """
        raise NotImplementedError

    def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]:
        """Performs one learning update, given `samples`.

        Either this method or the combination of `compute_gradients` and
        `apply_gradients` must be implemented by subclasses.

        Args:
            samples: The SampleBatch object to learn from.

        Returns:
            Dictionary of extra metadata from `compute_gradients()`.

        .. testcode::
            :skipif: True

            policy, sample_batch = ...
            policy.learn_on_batch(sample_batch)
        """
        # The default implementation is simply a fused `compute_gradients` plus
        # `apply_gradients` call.
        grads, grad_info = self.compute_gradients(samples)
        self.apply_gradients(grads)
        return grad_info

    def learn_on_batch_from_replay_buffer(
        self, replay_actor: ActorHandle, policy_id: PolicyID
    ) -> Dict[str, TensorType]:
        """Samples a batch from the given replay actor and performs an update.

        Args:
            replay_actor: The replay buffer actor to sample from.
            policy_id: The ID of this policy.

        Returns:
            Dictionary of extra metadata from `compute_gradients()`.
        """
        # Sample a batch from the given replay actor.
        # Note that for better performance (less data sent through the
        # network), this policy should be co-located on the same node
        # as `replay_actor`. Such a co-location step is usually done during
        # the Algorithm's `setup()` phase.
        batch = ray.get(replay_actor.replay.remote(policy_id=policy_id))
        if batch is None:
            return {}

        # Send to our own `learn_on_batch()` method for updating.
        # TODO: Replace this `hasattr` hack with a proper multi-GPU check.
        if hasattr(self, "devices") and len(self.devices) > 1:
            self.load_batch_into_buffer(batch, buffer_index=0)
            return self.learn_on_loaded_batch(offset=0, buffer_index=0)
        else:
            return self.learn_on_batch(batch)

    def load_batch_into_buffer(self, batch: SampleBatch, buffer_index: int = 0) -> int:
        """Bulk-loads the given SampleBatch into the devices' memories.

        The data is split equally across all of the Policy's devices.
        If the data is not evenly divisible by the batch size, excess data
        should be discarded.

        Args:
            batch: The SampleBatch to load.
            buffer_index: The index of the buffer (a MultiGPUTowerStack) to use
                on the devices. The number of buffers on each device depends
                on the value of the `num_multi_gpu_tower_stacks` config key.

        Returns:
            The number of tuples loaded per device.
        """
        raise NotImplementedError

    def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
        """Returns the number of currently loaded samples in the given buffer.

        Args:
            buffer_index: The index of the buffer (a MultiGPUTowerStack)
                to use on the devices. The number of buffers on each device
                depends on the value of the `num_multi_gpu_tower_stacks` config
                key.

        Returns:
            The number of tuples loaded per device.
        """
        raise NotImplementedError

    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        """Runs a single step of SGD on already loaded data in a buffer.

        Runs an SGD step over a slice of the pre-loaded batch, offset by
        the `offset` argument (useful for performing n minibatch SGD
        updates repeatedly on the same, already pre-loaded data).
        Updates the model weights based on the averaged per-device gradients.

        Args:
            offset: Offset into the preloaded data. Used for pre-loading
                a train-batch once to a device, then iterating over
                (subsampling through) this batch n times doing minibatch SGD.
            buffer_index: The index of the buffer (a MultiGPUTowerStack)
                to take the already pre-loaded data from. The number of buffers
                on each device depends on the value of the
                `num_multi_gpu_tower_stacks` config key.

        Returns:
            The outputs of extra_ops evaluated over the batch.
        """
        raise NotImplementedError
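
    # A hedged sketch of the pre-load-then-iterate pattern (`train_batch`,
    # `num_sgd_iter`, and the minibatch size are placeholders):
    #
    #   num_loaded = policy.load_batch_into_buffer(train_batch, buffer_index=0)
    #   minibatch_size = 128
    #   for _ in range(num_sgd_iter):
    #       for offset in range(0, num_loaded, minibatch_size):
    #           results = policy.learn_on_loaded_batch(
    #               offset=offset, buffer_index=0
    #           )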

    def compute_gradients(
        self, postprocessed_batch: SampleBatch
    ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
        """Computes gradients given a batch of experiences.

        Either this in combination with `apply_gradients()` or
        `learn_on_batch()` must be implemented by subclasses.

        Args:
            postprocessed_batch: The SampleBatch object to use
                for calculating gradients.

        Returns:
            grads: List of gradient output values.
            grad_info: Extra policy-specific info values.
        """
        raise NotImplementedError

    def apply_gradients(self, gradients: ModelGradients) -> None:
        """Applies the (previously) computed gradients.

        Either this in combination with `compute_gradients()` or
        `learn_on_batch()` must be implemented by subclasses.

        Args:
            gradients: The already calculated gradients to apply to this
                Policy.
        """
        raise NotImplementedError
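
    # A hedged sketch: the two-step alternative to `learn_on_batch()` (useful,
    # e.g., when gradients should be aggregated across workers before applying):
    #
    #   grads, grad_info = policy.compute_gradients(postprocessed_batch)
    #   policy.apply_gradients(grads)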

    def get_weights(self) -> ModelWeights:
        """Returns model weights.

        Note: The return value of this method will reside under the "weights"
        key in the return value of Policy.get_state(). Model weights are only
        one part of a Policy's state. Other state information includes:
        optimizer variables, exploration state, and global state vars such as
        the sampling timestep.

        Returns:
            Serializable copy or view of model weights.
        """
        raise NotImplementedError

    def set_weights(self, weights: ModelWeights) -> None:
        """Sets this Policy's model's weights.

        Note: Model weights are only one part of a Policy's state. Other
        state information includes: optimizer variables, exploration state,
        and global state vars such as the sampling timestep.

        Args:
            weights: Serializable copy or view of model weights.
        """
        raise NotImplementedError
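
    # A hedged sketch: syncing weights from one policy to another (both
    # policies are assumed to share the same model architecture):
    #
    #   weights = src_policy.get_weights()
    #   dst_policy.set_weights(weights)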

    def get_exploration_state(self) -> Dict[str, TensorType]:
        """Returns the state of this Policy's exploration component.

        Returns:
            Serializable information on the `self.exploration` object.
        """
        return self.exploration.get_state()

    def is_recurrent(self) -> bool:
        """Whether this Policy holds a recurrent Model.

        Returns:
            True if this Policy has-a RNN-based Model.
        """
        return False

    def num_state_tensors(self) -> int:
        """The number of internal states needed by the RNN-Model of the Policy.

        Returns:
            int: The number of RNN internal states kept by this Policy's Model.
        """
        return 0

    def get_initial_state(self) -> List[TensorType]:
        """Returns the initial RNN state for the current policy.

        Returns:
            List[TensorType]: Initial RNN state for the current policy.
        """
        return []
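
    # A hedged sketch of stepping a recurrent policy, threading the RNN state
    # through consecutive calls (`env` is a placeholder gymnasium.Env):
    #
    #   state = policy.get_initial_state()  # [] for non-recurrent policies.
    #   obs, _ = env.reset()
    #   while True:
    #       action, state, _ = policy.compute_single_action(obs, state=state)
    #       obs, reward, terminated, truncated, _ = env.step(action)
    #       if terminated or truncated:
    #           break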

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def get_state(self) -> PolicyState:
        """Returns the entire current state of this Policy.

        Note: Not to be confused with an RNN model's internal state.
        State includes the Model(s)' weights, optimizer weights,
        the exploration component's state, as well as global variables, such
        as sampling timesteps.

        Note that the state may contain references to the original variables.
        This means that you may need to deepcopy() the state before mutating it.

        Returns:
            Serialized local state.
        """
        state = {
            # All of the policy's weights.
            "weights": self.get_weights(),
            # The current global timestep.
            "global_timestep": self.global_timestep,
            # The current num_grad_updates counter.
            "num_grad_updates": self.num_grad_updates,
        }

        # Add this Policy's spec so it can be retrieved w/o access to the original
        # code.
        policy_spec = PolicySpec(
            policy_class=type(self),
            observation_space=self.observation_space,
            action_space=self.action_space,
            config=self.config,
        )
        state["policy_spec"] = policy_spec.serialize()

        if self.config.get("enable_connectors", False):
            # Checkpoint the connectors' states as well, if enabled.
            connector_configs = {}
            if self.agent_connectors:
                connector_configs["agent"] = self.agent_connectors.to_state()
            if self.action_connectors:
                connector_configs["action"] = self.action_connectors.to_state()
            state["connector_configs"] = connector_configs