
Commit

[RLLib] Make movement of tensors to device only happen once (ray-project#36091)

Signed-off-by: Avnishn <[email protected]>
Co-authored-by: Sven Mika <[email protected]>
Co-authored-by: kourosh hakhamaneshi <[email protected]>
3 people committed Jun 8, 2023
1 parent 69584e5 commit d950281
Showing 6 changed files with 43 additions and 20 deletions.
19 changes: 11 additions & 8 deletions rllib/core/learner/learner.py
@@ -765,14 +765,14 @@ def get_parameters(self, module: RLModule) -> Sequence[Param]:
         """

     @abc.abstractmethod
-    def _convert_batch_type(self, batch: MultiAgentBatch) -> NestedDict[TensorType]:
-        """Converts a MultiAgentBatch to a NestedDict of Tensors on the correct device.
+    def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
+        """Converts the elements of a MultiAgentBatch to Tensors on the correct device.

         Args:
             batch: The MultiAgentBatch object to convert.

         Returns:
-            The resulting NestedDict with framework-specific tensor values placed
+            The resulting MultiAgentBatch with framework-specific tensor values placed
             on the correct device.
         """
@@ -1143,19 +1143,22 @@ def update(
             batch_iter = MiniBatchDummyIterator

         results = []
-        for minibatch in batch_iter(batch, minibatch_size, num_iters):
-            # Convert minibatch into a tensor batch (NestedDict).
-            tensor_minibatch = self._convert_batch_type(minibatch)
+        # Convert input batch into a tensor batch (MultiAgentBatch) on the correct
+        # device (e.g. GPU). We move the batch already here to avoid having to move
+        # every single minibatch that is created in the `batch_iter` below.
+        batch = self._convert_batch_type(batch)
+        for tensor_minibatch in batch_iter(batch, minibatch_size, num_iters):
             # Make the actual in-graph/traced `_update` call. This should return
             # all tensor values (no numpy).
+            nested_tensor_minibatch = NestedDict(tensor_minibatch.policy_batches)
             (
                 fwd_out,
                 loss_per_module,
                 metrics_per_module,
-            ) = self._update(tensor_minibatch)
+            ) = self._update(nested_tensor_minibatch)

             result = self.compile_results(
-                batch=minibatch,
+                batch=tensor_minibatch,
                 fwd_out=fwd_out,
                 loss_per_module=loss_per_module,
                 metrics_per_module=defaultdict(dict, **metrics_per_module),
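
The key behavioral change is in the `update()` hunk above: the full batch is moved to the target device once, and the minibatch iterator then only slices the already-moved tensors instead of triggering one host-to-device copy per minibatch. A minimal sketch of that pattern outside of RLlib, assuming a plain dict batch and a caller-supplied `update_fn` (both names are illustrative, not RLlib API):

import torch

def update(full_batch, minibatch_size, device, update_fn, num_epochs=1):
    # One host-to-device copy for the whole batch, up front.
    full_batch = {k: torch.as_tensor(v).to(device) for k, v in full_batch.items()}
    batch_len = len(next(iter(full_batch.values())))
    results = []
    for _ in range(num_epochs):
        for start in range(0, batch_len, minibatch_size):
            # Slicing device tensors creates views; no further copies are needed.
            minibatch = {k: v[start:start + minibatch_size] for k, v in full_batch.items()}
            results.append(update_fn(minibatch))
    return results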
5 changes: 3 additions & 2 deletions rllib/core/learner/tf/tf_learner.py
@@ -316,9 +316,10 @@ def _check_registered_optimizer(
         )

     @override(Learner)
-    def _convert_batch_type(self, batch: MultiAgentBatch) -> NestedDict[TensorType]:
+    def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
         batch = _convert_to_tf(batch.policy_batches)
-        batch = NestedDict(batch)
+        length = max(len(b) for b in batch.values())
+        batch = MultiAgentBatch(batch, env_steps=length)
         return batch

     @override(Learner)
5 changes: 3 additions & 2 deletions rllib/core/learner/torch/torch_learner.py
@@ -185,9 +185,10 @@ def get_parameters(self, module: RLModule) -> Sequence[Param]:
         return list(module.parameters())

     @override(Learner)
-    def _convert_batch_type(self, batch: MultiAgentBatch):
+    def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
         batch = convert_to_torch_tensor(batch.policy_batches, device=self._device)
-        batch = NestedDict(batch)
+        length = max(len(b) for b in batch.values())
+        batch = MultiAgentBatch(batch, env_steps=length)
         return batch

     @override(Learner)
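
Both learners now hand back a MultiAgentBatch whose per-module batches already hold framework tensors on the learner's device. A hedged sketch of the torch-side conversion using plain torch rather than RLlib's `convert_to_torch_tensor` helper (the module id and array shapes below are illustrative):

import numpy as np
import torch

def to_torch_on_device(policy_batches, device="cpu"):
    # Convert every leaf array of every module batch to a torch tensor on `device`.
    return {
        module_id: {k: torch.as_tensor(v, device=torch.device(device)) for k, v in batch.items()}
        for module_id, batch in policy_batches.items()
    }

policy_batches = {"default_policy": {"obs": np.zeros((32, 4), dtype=np.float32)}}
converted = to_torch_on_device(policy_batches)
# env_steps is derived as in the diff above: the length of the longest module batch.
env_steps = max(len(next(iter(b.values()))) for b in converted.values())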
4 changes: 1 addition & 3 deletions rllib/examples/learner/ppo_tuner.py
@@ -41,9 +41,7 @@ def _parse_args():

 config = (
     PPOConfig()
-    .framework(args.framework)
-    .training(_enable_learner_api=True)
-    .rl_module(_enable_rl_module_api=True)
+    .framework(args.framework, eager_tracing=True)
     .environment("CartPole-v1")
     .resources(**RESOURCE_CONFIG[args.config])
 )
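
With the learner and RL-module APIs no longer toggled explicitly in the example, the config reduces to framework, environment, and resources. A hedged usage sketch for such a config with Ray Tune; the stop criterion is illustrative and the exact Tuner/RunConfig arguments may vary between Ray versions:

from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .framework("tf2", eager_tracing=True)
    .environment("CartPole-v1")
)

tuner = tune.Tuner(
    "PPO",
    param_space=config.to_dict(),
    run_config=air.RunConfig(stop={"training_iteration": 5}),
)
results = tuner.fit()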
25 changes: 22 additions & 3 deletions rllib/policy/sample_batch.py
@@ -1,4 +1,5 @@
 import collections
+from functools import partial
 import numpy as np
 import sys
 import itertools
@@ -1532,11 +1533,14 @@ def concat_samples(samples: List[SampleBatchType]) -> SampleBatchType:
         try:
             if k == "infos":
                 concatd_data[k] = _concat_values(
-                    *[s[k] for s in concated_samples], time_major=time_major
+                    *[s[k] for s in concated_samples],
+                    time_major=time_major,
                 )
             else:
+                values_to_concat = [c[k] for c in concated_samples]
+                _concat_values_w_time = partial(_concat_values, time_major=time_major)
                 concatd_data[k] = tree.map_structure(
-                    _concat_values, *[c[k] for c in concated_samples]
+                    _concat_values_w_time, *values_to_concat
                 )
         except RuntimeError as e:
             # This should catch torch errors that occur when concatenating
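
The `functools.partial` wrapper is needed because `tree.map_structure` calls its function with one positional leaf per input structure and forwards no keyword arguments, so `time_major` has to be bound before mapping. A small standalone illustration of the same pattern with dm-tree and toy numpy data:

from functools import partial
import numpy as np
import tree  # dm-tree

def _concat(*values, time_major=None):
    return np.concatenate(values, axis=1 if time_major else 0)

a = {"obs": np.ones((2, 3)), "actions": np.zeros((2,))}
b = {"obs": np.ones((4, 3)), "actions": np.zeros((4,))}

concat_w_time = partial(_concat, time_major=False)
merged = tree.map_structure(concat_w_time, a, b)
assert merged["obs"].shape == (6, 3)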
@@ -1631,7 +1635,22 @@ def _concat_values(*values, time_major=None) -> TensorType:
         time_major: Whether to concatenate along the first axis
             (time_major=False) or the second axis (time_major=True).
     """
-    return np.concatenate(list(values), axis=1 if time_major else 0)
+    if torch and torch.is_tensor(values[0]):
+        return torch.cat(values, dim=1 if time_major else 0)
+    elif isinstance(values[0], np.ndarray):
+        return np.concatenate(values, axis=1 if time_major else 0)
+    elif tf and tf.is_tensor(values[0]):
+        return tf.concat(values, axis=1 if time_major else 0)
+    elif isinstance(values[0], list):
+        concatenated_list = []
+        for sublist in values:
+            concatenated_list.extend(sublist)
+        return concatenated_list
+    else:
+        raise ValueError(
+            f"Unsupported type for concatenation: {type(values[0])} "
+            f"first element: {values[0]}"
+        )


 @DeveloperAPI
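
`_concat_values` previously assumed numpy inputs; now that batches may already live on the device as framework tensors, it dispatches on the leaf type. A quick illustration of that dispatch for just the torch and numpy cases (a simplified stand-in, not the RLlib function itself):

import numpy as np
import torch

def concat_values(*values, time_major=None):
    axis = 1 if time_major else 0
    if torch.is_tensor(values[0]):
        return torch.cat(values, dim=axis)
    if isinstance(values[0], np.ndarray):
        return np.concatenate(values, axis=axis)
    raise ValueError(f"Unsupported type: {type(values[0])}")

print(concat_values(torch.ones(2, 3), torch.zeros(4, 3)).shape)  # torch.Size([6, 3])
print(concat_values(np.ones((2, 3)), np.zeros((4, 3))).shape)    # (6, 3)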
5 changes: 3 additions & 2 deletions rllib/utils/minibatch_utils.py
@@ -54,8 +54,8 @@ def __init__(
         self._num_covered_epochs = {mid: 0 for mid in batch.policy_batches.keys()}

     def __iter__(self):
-
         while min(self._num_covered_epochs.values()) < self._num_iters:
+
             minibatch = {}
             for module_id, module_batch in self._batch.policy_batches.items():

@@ -90,7 +90,8 @@ def __iter__(self):
             # TODO (Kourosh): len(batch) is not correct here. However it's also not
             # clear what the correct value should be. Since training does not depend on
             # this it will be fine for now.
-            minibatch = MultiAgentBatch(minibatch, len(self._batch))
+            length = max(len(b) for b in minibatch.values())
+            minibatch = MultiAgentBatch(minibatch, length)
             yield minibatch
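
The yielded minibatch's env_steps is now the length of the longest per-module slice rather than the length of the whole input batch. A hedged sketch of a cyclic minibatch iterator in the same spirit, with plain numpy arrays standing in for SampleBatch objects (all names are illustrative):

import numpy as np

def iter_minibatches(policy_batches, minibatch_size, num_iters):
    cursors = {mid: 0 for mid in policy_batches}
    for _ in range(num_iters):
        minibatch = {}
        for mid, batch in policy_batches.items():
            start = cursors[mid]
            end = start + minibatch_size
            minibatch[mid] = batch[start:end]
            cursors[mid] = end % len(batch)  # wrap around once the batch is exhausted
        # env_steps of the yielded minibatch: the longest per-module slice.
        env_steps = max(len(b) for b in minibatch.values())
        yield minibatch, env_steps

batches = {"module_a": np.arange(10), "module_b": np.arange(6)}
for mb, steps in iter_minibatches(batches, minibatch_size=4, num_iters=3):
    print({k: len(v) for k, v in mb.items()}, steps)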
