[RLlib] Unify TensorSpecs to a single framework-agnostic class. (ray-…
ArturNiederfahrenhorst committed Apr 24, 2023
1 parent 92d570f commit 1531823
Showing 18 changed files with 451 additions and 406 deletions.
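
The headline change: instead of one spec class per framework, specs are now built from a single TensorSpec that takes the framework as a constructor argument. A minimal before/after sketch of the call site, using only names that appear in the diffs below; the two halves come from different revisions, so they are not meant to run together:

    # Before this commit: a framework-specific spec class per backend.
    import torch
    from ray.rllib.core.models.specs.specs_torch import TorchTensorSpec

    spec = TorchTensorSpec("b", dtype=torch.float32)  # "b" = batch dimension

    # After this commit: one framework-agnostic class; the backend is data.
    from ray.rllib.core.models.specs.specs_base import TensorSpec

    spec = TensorSpec("b", framework="torch")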
50 changes: 48 additions & 2 deletions rllib/algorithms/ppo/ppo_base_rl_module.py
@@ -4,9 +4,15 @@

import abc

-from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig
-from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.core.models.base import ActorCriticEncoder
+from ray.rllib.core.models.specs.specs_base import TensorSpec
+from ray.rllib.core.models.specs.specs_dict import SpecDict
+from ray.rllib.core.rl_module.rl_module import RLModule
+from ray.rllib.core.rl_module.rl_module import RLModuleConfig
+from ray.rllib.models.distributions import Distribution
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.utils.annotations import ExperimentalAPI
+from ray.rllib.utils.annotations import override


@ExperimentalAPI
@@ -27,3 +33,43 @@ def setup(self):
        # __sphinx_doc_end__

        assert isinstance(self.encoder, ActorCriticEncoder)
+
+    @override(RLModule)
+    def input_specs_inference(self) -> SpecDict:
+        return self.input_specs_exploration()
+
+    @override(RLModule)
+    def output_specs_inference(self) -> SpecDict:
+        return SpecDict({SampleBatch.ACTION_DIST: Distribution})
+
+    @override(RLModule)
+    def input_specs_exploration(self):
+        return []
+
+    @override(RLModule)
+    def output_specs_exploration(self) -> SpecDict:
+        return [
+            SampleBatch.VF_PREDS,
+            SampleBatch.ACTION_DIST,
+            SampleBatch.ACTION_DIST_INPUTS,
+        ]
+
+    @override(RLModule)
+    def input_specs_train(self) -> SpecDict:
+        specs = self.input_specs_exploration()
+        specs.append(SampleBatch.ACTIONS)
+        if SampleBatch.OBS in specs:
+            specs.append(SampleBatch.NEXT_OBS)
+        return specs
+
+    @override(RLModule)
+    def output_specs_train(self) -> SpecDict:
+        spec = SpecDict(
+            {
+                SampleBatch.ACTION_DIST: Distribution,
+                SampleBatch.ACTION_LOGP: TensorSpec("b", framework=self.framework),
+                SampleBatch.VF_PREDS: TensorSpec("b", framework=self.framework),
+                "entropy": TensorSpec("b", framework=self.framework),
+            }
+        )
+        return spec
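
With the framework baked into the spec object, the shared base class above can own all six spec methods and parameterize only on self.framework; the torch and TF subclasses below then delete their duplicated copies. A toy sketch of why a framework-as-data spec enables this (hypothetical names, not RLlib's implementation):

    # Toy illustration only: a spec that stores its framework as data can
    # be declared once in a shared base class and validated per backend.
    import numpy as np

    class ToyTensorSpec:
        def __init__(self, dims: str, framework: str = None):
            self.dim_names = dims.split(",")  # e.g. "b,h" -> ["b", "h"]
            self.framework = framework

        def validate(self, tensor) -> None:
            # A real implementation would also dispatch type checks on
            # self.framework ("torch", "tf2", "np", ...); here we only
            # check the rank against the declared dimension names.
            if len(tensor.shape) != len(self.dim_names):
                raise ValueError(
                    f"expected {len(self.dim_names)} dims, got {len(tensor.shape)}"
                )

    # One definition serves every backend:
    ToyTensorSpec("b", framework="np").validate(np.zeros((8,)))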
38 changes: 1 addition & 37 deletions rllib/algorithms/ppo/tf/ppo_tf_rl_module.py
@@ -1,12 +1,10 @@
-from typing import Mapping, Any, List
+from typing import Mapping, Any

from ray.rllib.algorithms.ppo.ppo_base_rl_module import PPORLModuleBase
from ray.rllib.core.models.base import ACTOR, CRITIC
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
-from ray.rllib.models.distributions import Distribution
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule
-from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
@@ -30,40 +28,6 @@ def __init__(self, *args, **kwargs):
# else:
# return NestedDict({})

-    @override(RLModule)
-    def input_specs_train(self) -> List[str]:
-        return [SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.ACTION_LOGP]
-
-    @override(RLModule)
-    def output_specs_train(self) -> List[str]:
-        return [
-            SampleBatch.ACTION_DIST_INPUTS,
-            SampleBatch.ACTION_DIST,
-            SampleBatch.ACTION_LOGP,
-            SampleBatch.VF_PREDS,
-            "entropy",
-        ]
-
-    @override(RLModule)
-    def input_specs_exploration(self):
-        return []
-
-    @override(RLModule)
-    def output_specs_exploration(self) -> List[str]:
-        return [
-            SampleBatch.ACTION_DIST,
-            SampleBatch.VF_PREDS,
-            SampleBatch.ACTION_DIST_INPUTS,
-        ]
-
-    @override(RLModule)
-    def input_specs_inference(self) -> SpecDict:
-        return self.input_specs_exploration()
-
-    @override(RLModule)
-    def output_specs_inference(self) -> SpecDict:
-        return SpecDict({SampleBatch.ACTION_DIST: Distribution})
-
@override(RLModule)
def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
output = {}
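
Everything removed from this TF module is now inherited from PPORLModuleBase, so only the forward passes stay framework-specific. A quick hypothetical sanity check (the PPOTfRLModule class name is an assumption; only the module paths appear in this diff):

    # After this commit, the TF module should resolve its spec methods
    # from the shared base class via normal attribute lookup.
    from ray.rllib.algorithms.ppo.ppo_base_rl_module import PPORLModuleBase
    from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule  # name assumed

    assert PPOTfRLModule.output_specs_train is PPORLModuleBase.output_specs_train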
43 changes: 0 additions & 43 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -5,9 +5,6 @@
from ray.rllib.core.models.base import ACTOR, CRITIC, ENCODER_OUT, STATE_IN
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch import TorchRLModule
-from ray.rllib.core.models.specs.specs_dict import SpecDict
-from ray.rllib.core.models.specs.specs_torch import TorchTensorSpec
-from ray.rllib.models.distributions import Distribution
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
@@ -36,14 +33,6 @@ def __init__(self, *args, **kwargs):
TorchRLModule.__init__(self, *args, **kwargs)
PPORLModuleBase.__init__(self, *args, **kwargs)

-    @override(RLModule)
-    def input_specs_inference(self) -> SpecDict:
-        return self.input_specs_exploration()
-
-    @override(RLModule)
-    def output_specs_inference(self) -> SpecDict:
-        return SpecDict({SampleBatch.ACTION_DIST: Distribution})
-
@override(RLModule)
def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
output = {}
@@ -69,18 +58,6 @@ def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:

return output

-    @override(RLModule)
-    def input_specs_exploration(self):
-        return []
-
-    @override(RLModule)
-    def output_specs_exploration(self) -> SpecDict:
-        return [
-            SampleBatch.VF_PREDS,
-            SampleBatch.ACTION_DIST,
-            SampleBatch.ACTION_DIST_INPUTS,
-        ]
-
@override(RLModule)
def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
"""PPO forward pass during exploration.
@@ -118,26 +95,6 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
)
return output

-    @override(RLModule)
-    def input_specs_train(self) -> SpecDict:
-        specs = self.input_specs_exploration()
-        specs.append(SampleBatch.ACTIONS)
-        if SampleBatch.OBS in specs:
-            specs.append(SampleBatch.NEXT_OBS)
-        return specs
-
-    @override(RLModule)
-    def output_specs_train(self) -> SpecDict:
-        spec = SpecDict(
-            {
-                SampleBatch.ACTION_DIST: Distribution,
-                SampleBatch.ACTION_LOGP: TorchTensorSpec("b", dtype=torch.float32),
-                SampleBatch.VF_PREDS: TorchTensorSpec("b", dtype=torch.float32),
-                "entropy": TorchTensorSpec("b", dtype=torch.float32),
-            }
-        )
-        return spec
-
def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]:
output = {}

9 changes: 5 additions & 4 deletions rllib/core/models/specs/checker.py
@@ -50,10 +50,11 @@ def convert_to_canonical_format(spec: SpecType) -> Union[Spec, SpecDict]:
# {"foo": TypeSpec(int), "bar": SpecDict({"baz": TypeSpec(str)})}
# )
spec = {"foo": int, "bar": {"baz": TorchTensorSpec("b,h")}}
spec = {"foo": int, "bar": {"baz": TensorSpec("b,h", framework="torch")}}
output = convert_to_canonical_format(spec)
# output = SpecDict(
# {"foo": TypeSpec(int), "bar": SpecDict({"baz": TorchTensorSpec("b,h")})}
# {"foo": TypeSpec(int), "bar": SpecDict({"baz": TensorSpec("b,h",
framework="torch")})}
# )
@@ -68,9 +69,9 @@ def convert_to_canonical_format(spec: SpecType) -> Union[Spec, SpecDict]:
output = convert_to_canonical_format(spec)
# output = None
-spec = TorchTensorSpec("b,h")
+spec = TensorSpec("b,h", framework="torch")
output = convert_to_canonical_format(spec)
-# output = TorchTensorSpec("b,h")
+# output = TensorSpec("b,h", framework="torch")
Args:
spec: The spec to convert to canonical format.
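
A short usage sketch of the updated docstring examples, using the import paths shown in this diff; the asserts mirror the return types the docstring documents:

    from ray.rllib.core.models.specs.checker import convert_to_canonical_format
    from ray.rllib.core.models.specs.specs_base import TensorSpec
    from ray.rllib.core.models.specs.specs_dict import SpecDict

    # Nested plain dicts are normalized to a SpecDict...
    spec = {"foo": int, "bar": {"baz": TensorSpec("b,h", framework="torch")}}
    assert isinstance(convert_to_canonical_format(spec), SpecDict)

    # ...while a bare spec passes through unchanged, per the docstring above.
    spec = TensorSpec("b,h", framework="torch")
    assert isinstance(convert_to_canonical_format(spec), TensorSpec)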
