From d914e9a93261b62e015aa3a5a6dfdb7723193cb4 Mon Sep 17 00:00:00 2001 From: Avnish Narayan <38871737+avnishn@users.noreply.github.com> Date: Tue, 11 Apr 2023 18:04:19 -0700 Subject: [PATCH] [RLlib] Actually save the optimizer state for tf learners (#34252) It turns out you can get the actual optimizer state by calling optimizer.variables for tf keras. this pr enables us to save the full optimizer state and restore it. To do this I added a new file called optimizer_name_state.txt to the checkpoint. This holds a bytestring serialized representation of the optimizer's state. It looks like the optimizer's variable state doesn't include things like the learning rate, so I still need to save those as a separate file and reconstruct the optimizer first before loading the state. --------- Signed-off-by: Avnish Signed-off-by: Jack He --- rllib/core/learner/tf/tf_learner.py | 115 ++++++++++++++++++++++++++-- 1 file changed, 107 insertions(+), 8 deletions(-) diff --git a/rllib/core/learner/tf/tf_learner.py b/rllib/core/learner/tf/tf_learner.py index 004ffa55825f7..0183f0fc2e648 100644 --- a/rllib/core/learner/tf/tf_learner.py +++ b/rllib/core/learner/tf/tf_learner.py @@ -78,10 +78,15 @@ def configure_optimizer_per_module( ) -> Union[ParamOptimizerPair, NamedParamOptimizerPairs]: module = self._module[module_id] lr = self._optimizer_config["lr"] + optim = tf.keras.optimizers.Adam(learning_rate=lr) pair: ParamOptimizerPair = ( self.get_parameters(module), - tf.keras.optimizers.Adam(learning_rate=lr), + optim, ) + # this isn't strictly necessary, but makes it so that if a checkpoint is + # computed before training actually starts, then it will be the same in + # shape / size as a checkpoint after training starts. + optim.build(module.trainable_variables) return pair @override(Learner) @@ -139,30 +144,124 @@ def load_state( with self._strategy.scope(): super().load_state(path) + def _save_optimizer_hparams( + self, + path: pathlib.Path, + optim: "tf.keras.optimizers.Optimizer", + optim_name: str, + ) -> None: + """Save the hyperparameters of optim to path/optim_name_hparams.json. + + Args: + path: The path to the directory to save the hyperparameters to. + optim: The optimizer to save the hyperparameters of. + optim_name: The name of the optimizer. + + """ + hparams = tf.keras.optimizers.serialize(optim) + hparams = tf.nest.map_structure(convert_numpy_to_python_primitives, hparams) + with open(path / f"{optim_name}_hparams.json", "w") as f: + json.dump(hparams, f) + + def _save_optimizer_state( + self, + path: pathlib.Path, + optim: "tf.keras.optimizers.Optimizer", + optim_name: str, + ) -> None: + """Save the state variables of optim to path/optim_name_state.txt. + + Args: + path: The path to the directory to save the state to. + optim: The optimizer to save the state of. + optim_name: The name of the optimizer. + + """ + state = optim.variables() + serialized_tensors = [tf.io.serialize_tensor(tensor) for tensor in state] + contents = tf.strings.join(serialized_tensors, separator="tensor: ") + tf.io.write_file(str(path / f"{optim_name}_state.txt"), contents) + @override(Learner) def _save_optimizers(self, path: Union[str, pathlib.Path]) -> None: path = pathlib.Path(path) path.mkdir(parents=True, exist_ok=True) for name, optim in self._named_optimizers.items(): - state = tf.keras.optimizers.serialize(optim) - state = tf.nest.map_structure(convert_numpy_to_python_primitives, state) - with open(path / f"{name}.json", "w") as f: - json.dump(state, f) + self._save_optimizer_hparams(path, optim, name) + self._save_optimizer_state(path, optim, name) + + def _load_optimizer_from_hparams( + self, path: pathlib.Path, optim_name: str + ) -> "tf.keras.optimizers.Optimizer": + """Load an optimizer from the hyperparameters saved at path/optim_name_hparams.json. + + Args: + path: The path to the directory to load the hyperparameters from. + optim_name: The name of the optimizer. + + Returns: + The optimizer loaded from the hyperparameters. + + """ + with open(path / f"{optim_name}_hparams.json", "r") as f: + state = json.load(f) + return tf.keras.optimizers.deserialize(state) + + def _load_optimizer_state( + self, + path: pathlib.Path, + optim: "tf.keras.optimizers.Optimizer", + optim_name: str, + ) -> None: + """Load the state of optim from the state saved at path/optim_name_state.txt. + + Args: + path: The path to the directory to load the state from. + optim: The optimizer to load the state into. + optim_name: The name of the optimizer. + + """ + contents = tf.io.read_file(str(path / f"{optim_name}_state.txt")) + serialized_tensors = tf.strings.split(contents, sep="tensor: ") + unserialized_optim_state = [] + for serialized_tensor, optim_tensor in zip( + serialized_tensors, optim.variables() + ): + unserialized_optim_state.append( + tf.io.parse_tensor(serialized_tensor, optim_tensor.dtype) + ) + + # set the state of the optimizer to the state that was saved + optim.set_weights(unserialized_optim_state) @override(Learner) def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None: path = pathlib.Path(path) for name in self._named_optimizers.keys(): - with open(path / f"{name}.json", "r") as f: - state = json.load(f) - new_optim = tf.keras.optimizers.deserialize(state) + new_optim = self._load_optimizer_from_hparams(path, name) old_optim = self._named_optimizers[name] + + # assign replace the old optim with the new optim in the learner's state self._named_optimizers[name] = new_optim param_seq = self._optimizer_parameters.pop(old_optim) self._optimizer_parameters[new_optim] = [] for param_ref in param_seq: self._optimizer_parameters[new_optim].append(param_ref) + # delete the old optimizer / free its memory + del old_optim + # these are the variables that the optimizer is supposed to optimize over + variable_list = [ + self._params[param_ref] + for param_ref in self._optimizer_parameters[new_optim] + ] + # initialize the optimizer with the variables that it is supposed to + # optimize over + new_optim.build(variable_list) + + # This loads in the actual state of the optimizer. + self._load_optimizer_state(path, new_optim, name) + @override(Learner) def set_weights(self, weights: Mapping[str, Any]) -> None: self._module.set_state(weights)