[RLLib] Make movement of tensors to device only happen once #36091

Merged · 13 commits · Jun 8, 2023
19 changes: 11 additions & 8 deletions rllib/core/learner/learner.py
@@ -765,14 +765,14 @@ def get_parameters(self, module: RLModule) -> Sequence[Param]:
"""

@abc.abstractmethod
def _convert_batch_type(self, batch: MultiAgentBatch) -> NestedDict[TensorType]:
"""Converts a MultiAgentBatch to a NestedDict of Tensors on the correct device.
def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
"""Converts the elements of a MultiAgentBatch to Tensors on the correct device.

Args:
batch: The MultiAgentBatch object to convert.

Returns:
The resulting NestedDict with framework-specific tensor values placed
The resulting MultiAgentBatch with framework-specific tensor values placed
on the correct device.
"""

@@ -1143,19 +1143,22 @@ def update(
batch_iter = MiniBatchDummyIterator

results = []
for minibatch in batch_iter(batch, minibatch_size, num_iters):
# Convert minibatch into a tensor batch (NestedDict).
tensor_minibatch = self._convert_batch_type(minibatch)
# Convert the whole batch into a tensor batch on the correct device
# (e.g. GPU). We move the batch to the device here once, to avoid moving every
# minibatch that is created in the `batch_iter` below.
batch = self._convert_batch_type(batch)
for tensor_minibatch in batch_iter(batch, minibatch_size, num_iters):
# Make the actual in-graph/traced `_update` call. This should return
# all tensor values (no numpy).
nested_tensor_minibatch = NestedDict(tensor_minibatch.policy_batches)
(
fwd_out,
loss_per_module,
metrics_per_module,
) = self._update(tensor_minibatch)
) = self._update(nested_tensor_minibatch)

result = self.compile_results(
batch=minibatch,
batch=tensor_minibatch,
fwd_out=fwd_out,
loss_per_module=loss_per_module,
metrics_per_module=defaultdict(dict, **metrics_per_module),
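A minimal, framework-agnostic sketch (torch flavor, not RLlib code) of why converting the full batch once pays off: the host-to-device copy happens a single time, and the per-minibatch work is just slicing tensors that already live on the device. In the actual hunk the single move is `self._convert_batch_type(batch)` and the slicing is done by `batch_iter`.

```python
import numpy as np
import torch

def iter_minibatches(batch: dict, minibatch_size: int, device: str = "cpu"):
    """Yield minibatch views after a single host->device move of the full batch."""
    # One transfer per column, up front (this is the expensive part on a GPU).
    batch = {k: torch.as_tensor(v).to(device) for k, v in batch.items()}
    n = len(next(iter(batch.values())))
    for start in range(0, n, minibatch_size):
        # Slicing device-resident tensors creates views; no further transfers.
        yield {k: v[start:start + minibatch_size] for k, v in batch.items()}

full_batch = {
    "obs": np.zeros((128, 4), dtype=np.float32),
    "rewards": np.zeros(128, dtype=np.float32),
}
for mb in iter_minibatches(full_batch, minibatch_size=32):
    assert mb["obs"].shape[0] == 32
```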
5 changes: 3 additions & 2 deletions rllib/core/learner/tf/tf_learner.py
@@ -316,9 +316,10 @@ def _check_registered_optimizer(
)

@override(Learner)
def _convert_batch_type(self, batch: MultiAgentBatch) -> NestedDict[TensorType]:
def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
batch = _convert_to_tf(batch.policy_batches)
batch = NestedDict(batch)
length = max(len(b) for b in batch.values())
batch = MultiAgentBatch(batch, env_steps=length)
return batch

@override(Learner)
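A small sketch of what the TF-side contract amounts to, using `tf.convert_to_tensor` over plain dicts as a stand-in for RLlib's internal `_convert_to_tf` (the wrapping back into a `MultiAgentBatch` is omitted here):

```python
import numpy as np
import tensorflow as tf
import tree  # dm_tree, the same `tree` used by RLlib

policy_batches = {
    "module_a": {"obs": np.zeros((8, 4), dtype=np.float32)},
    "module_b": {"obs": np.zeros((5, 4), dtype=np.float32)},
}
# Convert every leaf to a tf.Tensor (a stand-in for RLlib's _convert_to_tf).
converted = tree.map_structure(tf.convert_to_tensor, policy_batches)
# env_steps is re-derived as the length of the longest per-module batch,
# mirroring `length = max(len(b) for b in batch.values())` in the hunk above.
length = max(int(b["obs"].shape[0]) for b in converted.values())
assert length == 8
```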
5 changes: 3 additions & 2 deletions rllib/core/learner/torch/torch_learner.py
@@ -185,9 +185,10 @@ def get_parameters(self, module: RLModule) -> Sequence[Param]:
return list(module.parameters())

@override(Learner)
def _convert_batch_type(self, batch: MultiAgentBatch):
def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
batch = convert_to_torch_tensor(batch.policy_batches, device=self._device)
batch = NestedDict(batch)
length = max(len(b) for b in batch.values())
batch = MultiAgentBatch(batch, env_steps=length)
return batch

@override(Learner)
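A quick usage sketch of the torch conversion step, assuming CPU as the target device (swap in e.g. `"cuda:0"` when a GPU is available); `convert_to_torch_tensor` handles nested structures, which is why passing `batch.policy_batches` works:

```python
import numpy as np
import torch
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

policy_batches = {"module_a": {"obs": np.zeros((4, 2), dtype=np.float32)}}
converted = convert_to_torch_tensor(policy_batches, device="cpu")
assert isinstance(converted["module_a"]["obs"], torch.Tensor)
assert converted["module_a"]["obs"].device.type == "cpu"
```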
4 changes: 1 addition & 3 deletions rllib/examples/learner/ppo_tuner.py
@@ -41,9 +41,7 @@ def _parse_args():

config = (
PPOConfig()
.framework(args.framework)
.training(_enable_learner_api=True)
.rl_module(_enable_rl_module_api=True)
.framework(args.framework, eager_tracing=True)
.environment("CartPole-v1")
.resources(**RESOURCE_CONFIG[args.config])
)
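With the learner/RL-module flags dropped from the example, the config builds with defaults; a minimal, hypothetical way to exercise it outside the tuner script (the `RESOURCE_CONFIG`/`Tuner` wiring of `ppo_tuner.py` is not shown in this hunk):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .framework("tf2", eager_tracing=True)  # eager_tracing only has an effect for tf2
    .environment("CartPole-v1")
)
algo = config.build()
print(algo.train()["episode_reward_mean"])
```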
25 changes: 22 additions & 3 deletions rllib/policy/sample_batch.py
@@ -1,4 +1,5 @@
import collections
from functools import partial
import numpy as np
import sys
import itertools
@@ -1532,11 +1533,14 @@ def concat_samples(samples: List[SampleBatchType]) -> SampleBatchType:
try:
if k == "infos":
concatd_data[k] = _concat_values(
*[s[k] for s in concated_samples], time_major=time_major
*[s[k] for s in concated_samples],
time_major=time_major,
)
else:
values_to_concat = [c[k] for c in concated_samples]
_concat_values_w_time = partial(_concat_values, time_major=time_major)
concatd_data[k] = tree.map_structure(
_concat_values, *[c[k] for c in concated_samples]
_concat_values_w_time, *values_to_concat
)
except RuntimeError as e:
# This should catch torch errors that occur when concatenating
@@ -1631,7 +1635,22 @@ def _concat_values(*values, time_major=None) -> TensorType:
time_major: Whether to concatenate along the first axis
(time_major=False) or the second axis (time_major=True).
"""
return np.concatenate(list(values), axis=1 if time_major else 0)
if torch and torch.is_tensor(values[0]):
return torch.cat(values, dim=1 if time_major else 0)
elif isinstance(values[0], np.ndarray):
return np.concatenate(values, axis=1 if time_major else 0)
elif tf and tf.is_tensor(values[0]):
return tf.concat(values, axis=1 if time_major else 0)
elif isinstance(values[0], list):
concatenated_list = []
for sublist in values:
concatenated_list.extend(sublist)
return concatenated_list
else:
raise ValueError(
f"Unsupported type for concatenation: {type(values[0])} "
f"first element: {values[0]}"
)


@DeveloperAPI
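A standalone toy mirroring the new per-type dispatch together with the `functools.partial` + `tree.map_structure` pattern from `concat_samples` above; `concat_values` here is a stand-in, not an import of the private RLlib helper, and the TF branch is omitted to keep it short:

```python
from functools import partial

import numpy as np
import torch
import tree  # dm_tree

def concat_values(*values, time_major=False):
    """Standalone stand-in for the private _concat_values helper."""
    axis = 1 if time_major else 0
    if torch.is_tensor(values[0]):
        return torch.cat(values, dim=axis)
    if isinstance(values[0], np.ndarray):
        return np.concatenate(values, axis=axis)
    if isinstance(values[0], list):
        return [item for v in values for item in v]
    raise ValueError(f"Unsupported type for concatenation: {type(values[0])}")

# Two batches with mixed numpy / torch leaves, as concat_samples may now see:
a = {"obs": np.zeros((2, 3)), "state": torch.zeros(2, 4)}
b = {"obs": np.ones((3, 3)), "state": torch.ones(3, 4)}

concat_w_time = partial(concat_values, time_major=False)
out = tree.map_structure(concat_w_time, a, b)
assert out["obs"].shape == (5, 3) and tuple(out["state"].shape) == (5, 4)
```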
5 changes: 3 additions & 2 deletions rllib/utils/minibatch_utils.py
@@ -54,8 +54,8 @@ def __init__(
self._num_covered_epochs = {mid: 0 for mid in batch.policy_batches.keys()}

def __iter__(self):

while min(self._num_covered_epochs.values()) < self._num_iters:

minibatch = {}
for module_id, module_batch in self._batch.policy_batches.items():

@@ -90,7 +90,8 @@ def __iter__(self):
# TODO (Kourosh): len(batch) is not correct here. However it's also not
# clear what the correct value should be. Since training does not depend on
# this it will be fine for now.
minibatch = MultiAgentBatch(minibatch, len(self._batch))
length = max(len(b) for b in minibatch.values())
minibatch = MultiAgentBatch(minibatch, length)
yield minibatch


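A hedged usage sketch of the cyclic minibatch iterator with two modules of different lengths (hypothetical data; assumes `MiniBatchCyclicIterator(batch, minibatch_size, num_iters)`, as used in `Learner.update()`, accepts this toy input):

```python
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.minibatch_utils import MiniBatchCyclicIterator

batch = MultiAgentBatch(
    {
        "module_a": SampleBatch({"obs": np.arange(8, dtype=np.float32).reshape(8, 1)}),
        "module_b": SampleBatch({"obs": np.arange(6, dtype=np.float32).reshape(6, 1)}),
    },
    env_steps=8,
)
for mb in MiniBatchCyclicIterator(batch, minibatch_size=4, num_iters=1):
    # env_steps() of each yielded MultiAgentBatch is now the longest per-module
    # slice length, not the length of the full parent batch.
    print(mb.env_steps(), {mid: len(b) for mid, b in mb.policy_batches.items()})
```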