[RLLib] Make movement of tensors to device only happen once #36091
Conversation
rllib/core/learner/learner.py
Outdated
for minibatch in batch_iter(batch, minibatch_size, num_iters):
    # Convert minibatch into a tensor batch (NestedDict).
    tensor_minibatch = self._convert_batch_type(minibatch)
Can we add "... on the correct device (e.g. GPU)"? This would clarify further.
Also note that we only perform the copy to the correct device once, so we do not have to move data in each minibatch iteration below. Something like this.
done
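A minimal runnable sketch of the pattern settled on above: convert the whole batch to the correct device once, then slice minibatches from the already-converted batch. All names here are hypothetical stand-ins (`CountingConverter` mimics what `Learner._convert_batch_type` plus a device transfer would do; no real GPU is involved):

```python
from typing import Iterator, List


def batch_iter(batch: List[int], minibatch_size: int, num_iters: int) -> Iterator[List[int]]:
    """Yield `num_iters` contiguous minibatches, wrapping around the batch."""
    idx = 0
    for _ in range(num_iters):
        yield batch[idx:idx + minibatch_size]
        idx = (idx + minibatch_size) % len(batch)


class CountingConverter:
    """Hypothetical stand-in for a batch-to-device conversion; counts transfers."""

    def __init__(self):
        self.copies = 0

    def to_device(self, batch):
        self.copies += 1  # pretend this is one host-to-GPU copy
        return list(batch)


def run_update_loop(converter, batch, minibatch_size, num_iters):
    # Convert/copy the whole batch to the correct device ONCE ...
    tensor_batch = converter.to_device(batch)
    processed = []
    # ... so the loop below only slices; no per-minibatch transfer happens.
    for minibatch in batch_iter(tensor_batch, minibatch_size, num_iters):
        processed.append(minibatch)
    return processed


converter = CountingConverter()
out = run_update_loop(converter, [1, 2, 3, 4], minibatch_size=2, num_iters=2)
print(converter.copies)  # 1 transfer, regardless of num_iters
```

The transfer count stays at one however many minibatch iterations run, which is the point of the change under review.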
rllib/utils/minibatch_utils.py
Outdated
@@ -1,5 +1,6 @@
import math

two empty lines?
I can get rid of these; they're left over from previous commits.
rllib/policy/sample_batch.py
Outdated
@@ -1631,7 +1635,22 @@ def _concat_values(*values, time_major=None) -> TensorType:
        time_major: Whether to concatenate along the first axis
            (time_major=False) or the second axis (time_major=True).
    """
    return np.concatenate(list(values), axis=1 if time_major else 0)
    if torch and isinstance(values[0], torch.Tensor):
Can you explain why we need these changes (add comment)?
Also, let's use torch.is_tensor().
done.
rllib/policy/sample_batch.py
Outdated
        return torch.cat(values, dim=1 if time_major else 0)
    elif isinstance(values[0], np.ndarray):
        return np.concatenate(values, axis=1 if time_major else 0)
    elif tf and isinstance(values[0], tf.Tensor):
Can we use tf.is_tensor() instead?
done.
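Putting the review suggestions together, a hedged sketch of the framework-dispatched concatenation helper might look like the following. This is not the final RLlib implementation, only an illustration of the agreed pattern: `torch.is_tensor()` / `tf.is_tensor()` instead of `isinstance` checks, with a NumPy fallback. The guarded imports keep the helper usable when torch or TensorFlow is not installed:

```python
import numpy as np

try:
    import torch
except ImportError:
    torch = None
try:
    import tensorflow as tf
except ImportError:
    tf = None


def concat_values(*values, time_major=False):
    """Concatenate along axis 1 if time_major else axis 0, preserving the input type."""
    axis = 1 if time_major else 0
    # Dispatch on the type of the first value, as in the diff above.
    if torch is not None and torch.is_tensor(values[0]):
        return torch.cat(list(values), dim=axis)
    if tf is not None and tf.is_tensor(values[0]):
        return tf.concat(list(values), axis=axis)
    return np.concatenate(list(values), axis=axis)


a = np.array([[1, 2]])
b = np.array([[3, 4]])
print(concat_values(a, b).shape)                   # (2, 2)
print(concat_values(a, b, time_major=True).shape)  # (1, 4)
```

Using `is_tensor()` avoids importing the tensor classes directly and also matches tensor subclasses the frameworks may hand back.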
Looks great. Just a few nits and questions. Thanks for this important enhancement @avnishn !
LGTM. Let's wait for all tests to pass again, then I'll merge. ...
Our current minibatching logic in the Learner stack forces each individual minibatch to be moved to the GPU after it has been sliced. This is wasteful: it creates unnecessary copies of batches and adds unnecessary host-to-device transfers. This PR addresses that by moving the whole batch to the GPU first, then performing any minibatching operations on it.
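To make the saving concrete, here is a toy comparison of the two strategies. A copy counter stands in for a host-to-GPU transfer (no real GPU, and none of these names come from RLlib):

```python
class Transfer:
    """Hypothetical device bridge: counts how many host-to-GPU copies occur."""

    def __init__(self):
        self.count = 0

    def to_gpu(self, data):
        self.count += 1
        return list(data)


def old_style(transfer, batch, size):
    # Previous behavior: copy each minibatch to the device after slicing it.
    for i in range(0, len(batch), size):
        transfer.to_gpu(batch[i:i + size])


def new_style(transfer, batch, size):
    # New behavior: copy the whole batch once, then slice on-device.
    on_device = transfer.to_gpu(batch)
    for i in range(0, len(on_device), size):
        _ = on_device[i:i + size]  # slicing incurs no extra transfer


old, new = Transfer(), Transfer()
batch = list(range(8))
old_style(old, batch, 2)
new_style(new, batch, 2)
print(old.count, new.count)  # 4 1
```

With 8 samples and minibatches of 2, the old path performs four transfers while the new path performs one; the gap grows with batch size and the number of SGD iterations.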
Why are these changes needed?
Related issue number
Checks
- I've signed every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I have added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.