Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Add Optimizer State To Learner get_state #34760

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2283,8 +2283,8 @@ py_test(

py_test(
name = "utils/tests/test_torch_utils",
tags = ["team:rllib", "utils"],
size = "small",
tags = ["team:rllib", "utils", "gpu"],
size = "medium",
srcs = ["utils/tests/test_torch_utils.py"]
)

Expand Down
43 changes: 40 additions & 3 deletions rllib/core/learner/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,14 +849,29 @@ def set_state(self, state: Mapping[str, Any]) -> None:

Args:
state: The state of the optimizer and module. Can be obtained
from `get_state`.
from `get_state`. State is a dictionary with two keys:
"module_state" and "optimizer_state". The value of each key
is a dictionary that can be passed to `set_weights` and
`set_optimizer_weights` respectively.

"""
# TODO (Kourosh): We have both get(set)_state and get(set)_weights. I think
# having both can become confusing. Can we simplify this API requirement?
self._check_is_built()
# TODO: once we figure out the optimizer format, we can set/get the state
self._module.set_state(state.get("module_state", {}))
if "module_state" not in state:
raise ValueError(
"state must have a key 'module_state' for the module weights"
)
if "optimizer_state" not in state:
raise ValueError(
"state must have a key 'optimizer_state' for the optimizer weights"
)

module_state = state.get("module_state")
optimizer_state = state.get("optimizer_state")
self.set_weights(module_state)
self.set_optimizer_weights(optimizer_state)

def get_state(self) -> Mapping[str, Any]:
"""Get the state of the learner.
Expand All @@ -867,7 +882,29 @@ def get_state(self) -> Mapping[str, Any]:
"""
self._check_is_built()
# TODO: once we figure out the optimizer format, we can set/get the state
return {"module_state": self._module.get_state()}
return {
"module_state": self.get_weights(),
"optimizer_state": self.get_optimizer_weights(),
}
# return {"module_state": self.get_weights(), "optimizer_state": {}}

def set_optimizer_weights(self, weights: Mapping[str, Any]) -> None:
"""Set the weights of the optimizer.

Args:
weights: The weights of the optimizer.

"""
raise NotImplementedError

def get_optimizer_weights(self) -> Mapping[str, Any]:
"""Get the weights of the optimizer.

Returns:
The weights of the optimizer.

"""
raise NotImplementedError

def _get_metadata(self) -> Dict[str, Any]:
metadata = {
Expand Down
27 changes: 25 additions & 2 deletions rllib/core/learner/tests/test_learner_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

LOCAL_SCALING_CONFIGS = {
"local-cpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=0),
"local-gpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=0.5),
"local-gpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=1),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't actually support fractional gpu, so this doesn't matter.

}


Expand All @@ -45,6 +45,17 @@
@ray.remote(num_gpus=1)
class RemoteTrainingHelper:
def local_training_helper(self, fw, scaling_mode) -> None:
if fw == "torch":
import torch

torch.manual_seed(0)
elif fw == "tf":
import tensorflow as tf

# this is done by rllib already inside of the policy class, but we need to
# do it here for testing purposes
tf.compat.v1.enable_eager_execution()
tf.random.set_seed(0)
env = gym.make("CartPole-v1")
scaling_config = LOCAL_SCALING_CONFIGS[scaling_mode]
lr = 1e-3
Expand All @@ -71,13 +82,25 @@ def local_training_helper(self, fw, scaling_mode) -> None:

# make the state of the learner and the local learner_group identical
local_learner.set_state(learner_group.get_state())
# learner_group.set_state(learner_group.get_state())
check(local_learner.get_state(), learner_group.get_state())

# do another update
batch = reader.next()
ma_batch = MultiAgentBatch(
{new_module_id: batch, DEFAULT_POLICY_ID: batch}, env_steps=batch.count
)
check(local_learner.update(ma_batch), learner_group.update(ma_batch))
# the optimizer state is not initialized fully until the first time that
# training is completed. A call to get state before that won't contain the
# optimizer state. So we do a dummy update here to initialize the optimizer
local_learner.update(ma_batch)
learner_group.update(ma_batch)

check(local_learner.get_state(), learner_group.get_state())
local_learner_results = local_learner.update(ma_batch)
learner_group_results = learner_group.update(ma_batch)
avnishn marked this conversation as resolved.
Show resolved Hide resolved

check(local_learner_results, learner_group_results)

check(local_learner.get_state(), learner_group.get_state())

Expand Down
19 changes: 19 additions & 0 deletions rllib/core/learner/tf/tf_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,25 @@ def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None:
def set_weights(self, weights: Mapping[str, Any]) -> None:
self._module.set_state(weights)

@override(Learner)
def get_optimizer_weights(self) -> Mapping[str, Any]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to find a way to reuse these functions when saving the optimizer state, but its difficult since there is actually little overlap -- when saving the optimizer state, we actually save in native tensorflow format instead of numpy.

optim_weights = {}
with tf.init_scope():
for name, optim in self._named_optimizers.items():
optim_weights[name] = [var.numpy() for var in optim.variables()]
return optim_weights

@override(Learner)
def set_optimizer_weights(self, weights: Mapping[str, Any]) -> None:
for name, weight_array in weights.items():
if name not in self._named_optimizers:
raise ValueError(
f"Optimizer {name} in weights is not known."
f"Known optimizers are {self._named_optimizers.keys()}"
)
optim = self._named_optimizers[name]
optim.set_weights(weight_array)

@override(Learner)
def get_param_ref(self, param: ParamType) -> Hashable:
return param.ref()
Expand Down
38 changes: 34 additions & 4 deletions rllib/core/learner/torch/torch_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchDDPRLModule
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.torch_utils import clip_gradients, convert_to_torch_tensor
from ray.rllib.utils.typing import TensorType
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.torch_utils import (
clip_gradients,
convert_to_torch_tensor,
copy_torch_tensors,
)
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()
Expand Down Expand Up @@ -119,16 +123,42 @@ def set_weights(self, weights: Mapping[str, Any]) -> None:
def _save_optimizers(self, path: Union[str, pathlib.Path]) -> None:
path = pathlib.Path(path)
path.mkdir(parents=True, exist_ok=True)
for name, optim in self._named_optimizers.items():
torch.save(optim.state_dict(), path / f"{name}.pt")
optim_weights = self.get_optimizer_weights()
for name, weights in optim_weights.items():
torch.save(weights, path / f"{name}.pt")

@override(Learner)
def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None:
path = pathlib.Path(path)
if not path.exists():
raise ValueError(f"Directory {path} does not exist.")
weights = {}
for name in self._named_optimizers.keys():
weights[name] = torch.load(path / f"{name}.pt")
self.set_optimizer_weights(weights)

@override(Learner)
def get_optimizer_weights(self) -> Mapping[str, Any]:
optimizer_name_weights = {}
for name, optim in self._named_optimizers.items():
optim.load_state_dict(torch.load(path / f"{name}.pt"))
optim_state_dict = optim.state_dict()
optim_state_dict_cpu = copy_torch_tensors(optim_state_dict, device="cpu")
optimizer_name_weights[name] = optim_state_dict_cpu
return optimizer_name_weights

@override(Learner)
def set_optimizer_weights(self, weights: Mapping[str, Any]) -> None:
for name, weight_dict in weights.items():
if name not in self._named_optimizers:
raise ValueError(
f"Optimizer {name} in weights is not known."
f"Known optimizers are {self._named_optimizers.keys()}"
)
optim = self._named_optimizers[name]
weight_dict_correct_device = copy_torch_tensors(
weight_dict, device=self._device
)
optim.load_state_dict(weight_dict_correct_device)

@override(Learner)
def get_param_ref(self, param: ParamType) -> Hashable:
Expand Down
53 changes: 52 additions & 1 deletion rllib/utils/tests/test_torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import torch.cuda

import ray
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.torch_utils import (
convert_to_torch_tensor,
copy_torch_tensors,
)


class TestTorchUtils(unittest.TestCase):
Expand Down Expand Up @@ -43,6 +46,54 @@ def test_convert_to_torch_tensor(self):
self.assertTrue(converted["b"].dtype is torch.float32)
self.assertTrue(converted["c"] is None)

def test_copy_torch_tensors(self):
array = np.array([1, 2, 3], dtype=np.float32)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tensor = torch.from_numpy(array).to(device)
tensor_2 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64).to(device)

# Test single tensor
copied_tensor = copy_torch_tensors(tensor, device)
self.assertTrue(copied_tensor.device == device)
self.assertNotEqual(id(copied_tensor), id(tensor))
self.assertTrue(all(copied_tensor == tensor))

# check that dtypes aren't modified
copied_tensor_2 = copy_torch_tensors(tensor_2, device)
self.assertTrue(copied_tensor_2.dtype == tensor_2.dtype)
self.assertFalse(copied_tensor_2.dtype == torch.float32)

# Test nested structure can be converted
nested_structure = {"a": tensor, "b": tensor_2, "c": 1}
copied_nested_structure = copy_torch_tensors(nested_structure, device)
self.assertTrue(copied_nested_structure["a"].device == device)
self.assertTrue(copied_nested_structure["b"].device == device)
self.assertTrue(copied_nested_structure["c"] == 1)
self.assertNotEqual(id(copied_nested_structure["a"]), id(tensor))
self.assertNotEqual(id(copied_nested_structure["b"]), id(tensor_2))
self.assertTrue(all(copied_nested_structure["a"] == tensor))
self.assertTrue(all(copied_nested_structure["b"] == tensor_2))
avnishn marked this conversation as resolved.
Show resolved Hide resolved

# if gpu is available test moving tensor from cpu to gpu and vice versa
if torch.cuda.is_available():
tensor = torch.from_numpy(array).to("cpu")
copied_tensor = copy_torch_tensors(tensor, "cuda:0")
self.assertFalse(copied_tensor.device == torch.device("cpu"))
self.assertTrue(copied_tensor.device == torch.device("cuda:0"))
self.assertNotEqual(id(copied_tensor), id(tensor))
self.assertTrue(
all(copied_tensor.detach().cpu().numpy() == tensor.detach().numpy())
)

tensor = torch.from_numpy(array).to("cuda:0")
copied_tensor = copy_torch_tensors(tensor, "cpu")
self.assertFalse(copied_tensor.device == torch.device("cuda:0"))
self.assertTrue(copied_tensor.device == torch.device("cpu"))
self.assertNotEqual(id(copied_tensor), id(tensor))
self.assertTrue(
all(copied_tensor.detach().numpy() == tensor.detach().cpu().numpy())
)


if __name__ == "__main__":
import pytest
Expand Down
31 changes: 31 additions & 0 deletions rllib/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,37 @@ def mapping(item):
return tree.map_structure(mapping, x)


@PublicAPI
def copy_torch_tensors(x: TensorStructType, device: Optional[str] = None):
"""Creates a copy of `x` and makes deep copies torch.Tensors in x.

Also moves the copied tensors to the specified device (if not None).

Note if an object in x is not a torch.Tensor, it will be shallow-copied.

Args:
x : Any (possibly nested) struct possibly containing torch.Tensors.
device : The device to move the tensors to.

Returns:
Any: A new struct with the same structure as `x`, but with all
torch.Tensors deep-copied and moved to the specified device.

"""

def mapping(item):
if isinstance(item, torch.Tensor):
return (
torch.clone(item.detach())
if device is None
else item.detach().to(device)
)
else:
return item

return tree.map_structure(mapping, x)


@PublicAPI
def explained_variance(y: TensorType, pred: TensorType) -> TensorType:
"""Computes the explained variance for a pair of labels and predictions.
Expand Down