[RLlib] Actually save the optimizer state for tf learners (ray-project#34252)

It turns out you can get the actual optimizer state by calling optimizer.variables() on a
tf.keras optimizer. This PR enables us to save the full optimizer state and restore it. To do
this, I added a new file, <optimizer_name>_state.txt, to the checkpoint; it holds a byte-string
serialized representation of the optimizer's state variables. It looks like the optimizer's
variable state doesn't include hyperparameters such as the learning rate, so I still save those
in a separate file and reconstruct the optimizer from them before loading the state.
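
A minimal standalone sketch of this mechanism (not the RLlib code itself), assuming the tf.keras
optimizer API used in the diff below (serialize()/deserialize(), build(), variables(),
set_weights()); the toy parameter list is hypothetical:

```python
import tensorflow as tf

# Toy trainable variables standing in for an RLModule's parameters (hypothetical).
params = [tf.Variable(tf.zeros((8, 4))), tf.Variable(tf.zeros((4,)))]

optim = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Building the optimizer creates its state variables (Adam moments, iteration
# counter) up front, so a checkpoint taken before training has the same
# shape / size as one taken after training starts.
optim.build(params)

# "Save": hyperparameters and state variables are kept separately, mirroring
# the <optimizer_name>_hparams.json / <optimizer_name>_state.txt split.
hparams = tf.keras.optimizers.serialize(optim)   # learning rate, betas, ... (no state)
state = [v.numpy() for v in optim.variables()]   # moments, iteration count

# "Restore": rebuild the optimizer from its hyperparameters first, build it
# against the module's parameters, then load the saved state variables.
new_optim = tf.keras.optimizers.deserialize(hparams)
new_optim.build(params)
new_optim.set_weights(state)
```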

---------

Signed-off-by: Avnish <[email protected]>
avnishn committed Apr 12, 2023
1 parent 74325ef commit fa238f7
Showing 1 changed file with 107 additions and 8 deletions.
115 changes: 107 additions & 8 deletions rllib/core/learner/tf/tf_learner.py
@@ -78,10 +78,15 @@ def configure_optimizer_per_module(
    ) -> Union[ParamOptimizerPair, NamedParamOptimizerPairs]:
        module = self._module[module_id]
        lr = self._optimizer_config["lr"]
        optim = tf.keras.optimizers.Adam(learning_rate=lr)
        pair: ParamOptimizerPair = (
            self.get_parameters(module),
            tf.keras.optimizers.Adam(learning_rate=lr),
            optim,
        )
        # this isn't strictly necessary, but makes it so that if a checkpoint is
        # computed before training actually starts, then it will be the same in
        # shape / size as a checkpoint after training starts.
        optim.build(module.trainable_variables)
        return pair

    @override(Learner)
@@ -139,30 +144,124 @@ def load_state(
        with self._strategy.scope():
            super().load_state(path)

    def _save_optimizer_hparams(
        self,
        path: pathlib.Path,
        optim: "tf.keras.optimizers.Optimizer",
        optim_name: str,
    ) -> None:
        """Save the hyperparameters of optim to path/optim_name_hparams.json.

        Args:
            path: The path to the directory to save the hyperparameters to.
            optim: The optimizer to save the hyperparameters of.
            optim_name: The name of the optimizer.
        """
        hparams = tf.keras.optimizers.serialize(optim)
        hparams = tf.nest.map_structure(convert_numpy_to_python_primitives, hparams)
        with open(path / f"{optim_name}_hparams.json", "w") as f:
            json.dump(hparams, f)

    def _save_optimizer_state(
        self,
        path: pathlib.Path,
        optim: "tf.keras.optimizers.Optimizer",
        optim_name: str,
    ) -> None:
        """Save the state variables of optim to path/optim_name_state.txt.

        Args:
            path: The path to the directory to save the state to.
            optim: The optimizer to save the state of.
            optim_name: The name of the optimizer.
        """
        state = optim.variables()
        serialized_tensors = [tf.io.serialize_tensor(tensor) for tensor in state]
        contents = tf.strings.join(serialized_tensors, separator="tensor: ")
        tf.io.write_file(str(path / f"{optim_name}_state.txt"), contents)

    @override(Learner)
    def _save_optimizers(self, path: Union[str, pathlib.Path]) -> None:
        path = pathlib.Path(path)
        path.mkdir(parents=True, exist_ok=True)
        for name, optim in self._named_optimizers.items():
            state = tf.keras.optimizers.serialize(optim)
            state = tf.nest.map_structure(convert_numpy_to_python_primitives, state)
            with open(path / f"{name}.json", "w") as f:
                json.dump(state, f)
            self._save_optimizer_hparams(path, optim, name)
            self._save_optimizer_state(path, optim, name)

    def _load_optimizer_from_hparams(
        self, path: pathlib.Path, optim_name: str
    ) -> "tf.keras.optimizers.Optimizer":
        """Load an optimizer from the hyperparameters saved at path/optim_name_hparams.json.

        Args:
            path: The path to the directory to load the hyperparameters from.
            optim_name: The name of the optimizer.

        Returns:
            The optimizer loaded from the hyperparameters.
        """
        with open(path / f"{optim_name}_hparams.json", "r") as f:
            state = json.load(f)
        return tf.keras.optimizers.deserialize(state)

    def _load_optimizer_state(
        self,
        path: pathlib.Path,
        optim: "tf.keras.optimizers.Optimizer",
        optim_name: str,
    ) -> None:
        """Load the state of optim from the state saved at path/optim_name_state.txt.

        Args:
            path: The path to the directory to load the state from.
            optim: The optimizer to load the state into.
            optim_name: The name of the optimizer.
        """
        contents = tf.io.read_file(str(path / f"{optim_name}_state.txt"))
        serialized_tensors = tf.strings.split(contents, sep="tensor: ")
        unserialized_optim_state = []
        for serialized_tensor, optim_tensor in zip(
            serialized_tensors, optim.variables()
        ):
            unserialized_optim_state.append(
                tf.io.parse_tensor(serialized_tensor, optim_tensor.dtype)
            )

        # set the state of the optimizer to the state that was saved
        optim.set_weights(unserialized_optim_state)

    @override(Learner)
    def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None:
        path = pathlib.Path(path)
        for name in self._named_optimizers.keys():
            with open(path / f"{name}.json", "r") as f:
                state = json.load(f)
            new_optim = tf.keras.optimizers.deserialize(state)
            new_optim = self._load_optimizer_from_hparams(path, name)
            old_optim = self._named_optimizers[name]

            # Replace the old optim with the new optim in the learner's state.
            self._named_optimizers[name] = new_optim
            param_seq = self._optimizer_parameters.pop(old_optim)
            self._optimizer_parameters[new_optim] = []
            for param_ref in param_seq:
                self._optimizer_parameters[new_optim].append(param_ref)

            # delete the old optimizer / free its memory
            del old_optim
            # these are the variables that the optimizer is supposed to optimize over
            variable_list = [
                self._params[param_ref]
                for param_ref in self._optimizer_parameters[new_optim]
            ]
            # initialize the optimizer with the variables that it is supposed to
            # optimize over
            new_optim.build(variable_list)

            # This loads in the actual state of the optimizer.
            self._load_optimizer_state(path, new_optim, name)

    @override(Learner)
    def set_weights(self, weights: Mapping[str, Any]) -> None:
        self._module.set_state(weights)
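For illustration, a standalone sketch of the tensor round trip that the new
_save_optimizer_state / _load_optimizer_state helpers rely on; the toy tensors and the file path
below are hypothetical, not what RLlib actually writes:

```python
import pathlib

import tensorflow as tf

# Hypothetical stand-ins for an optimizer's state variables and a checkpoint path.
state = [tf.constant([1.0, 2.0, 3.0]), tf.constant([[4.0], [5.0]])]
path = pathlib.Path("/tmp") / "adam_state.txt"

# Save: serialize each state tensor to a byte string and join the pieces with
# the "tensor: " separator, as the save helper does.
serialized_tensors = [tf.io.serialize_tensor(tensor) for tensor in state]
contents_to_write = tf.strings.join(serialized_tensors, separator="tensor: ")
tf.io.write_file(str(path), contents_to_write)

# Load: split the file contents on the separator and parse each chunk back,
# using the dtype of the corresponding live variable, as the load helper does.
contents = tf.io.read_file(str(path))
chunks = tf.strings.split(contents, sep="tensor: ")
restored = [
    tf.io.parse_tensor(chunk, tensor.dtype) for chunk, tensor in zip(chunks, state)
]
```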
