Skip to content

Commit

Permalink
[RLlib] Fix LearnerAPI + tf2 slowness due to wrong batch -> tensor conversion. (ray-project#35818)
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 committed May 26, 2023
1 parent 3a7f8d9 commit ab8fd0a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 28 deletions.
10 changes: 4 additions & 6 deletions rllib/core/learner/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,16 +559,14 @@ def get_parameters(self, module: RLModule) -> Sequence[Param]:

@abc.abstractmethod
def _convert_batch_type(self, batch: MultiAgentBatch) -> NestedDict[TensorType]:
"""Converts a MultiAgentBatch to a NestedDict of Tensors.
This should convert the input batch from a MultiAgentBatch format to framework
specific tensor format located on the correct device.
"""Converts a MultiAgentBatch to a NestedDict of Tensors on the correct device.
Args:
batch: A MultiAgentBatch.
batch: The MultiAgentBatch object to convert.
Returns:
A NestedDict.
The resulting NestedDict with framework-specific tensor values placed
on the correct device.
"""

@OverrideToImplementCustomLogic_CallToSuperRecommended
Expand Down
26 changes: 4 additions & 22 deletions rllib/core/learner/tf/tf_learner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import logging
import numpy as np
import pathlib
from typing import (
Any,
Expand All @@ -27,6 +26,7 @@
SingleAgentRLModuleSpec,
)
from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule
from ray.rllib.policy.eager_tf_policy import _convert_to_tf
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import (
override,
Expand Down Expand Up @@ -330,27 +330,9 @@ def _check_structure_param_optim_pair(self, param_optim_pair: Any) -> None:

@override(Learner)
def _convert_batch_type(self, batch: MultiAgentBatch) -> NestedDict[TensorType]:
"""Convert the arrays of batch to tf.Tensor's.
Note: This is an in place operation.
Args:
batch: The batch to convert.
Returns:
The converted batch.
"""
# TODO(avnishn): This is a hack to get around the fact that
# SampleBatch.count becomes 0 after decorating the function with
# tf.function. This messes with input spec checking. Other fields of
# the sample batch are possibly modified by tf.function which may lead
# to unwanted consequences. We'll need to further investigate this.
ma_batch = NestedDict(batch.policy_batches)
for key, value in ma_batch.items():
if isinstance(value, np.ndarray):
ma_batch[key] = tf.convert_to_tensor(value, dtype=tf.float32)
return ma_batch
batch = _convert_to_tf(batch.policy_batches)
batch = NestedDict(batch)
return batch

@override(Learner)
def add_module(
Expand Down

0 comments on commit ab8fd0a

Please sign in to comment.