diff --git a/rllib/core/learner/tf/tf_learner.py b/rllib/core/learner/tf/tf_learner.py
index 004ffa55825f7..0183f0fc2e648 100644
--- a/rllib/core/learner/tf/tf_learner.py
+++ b/rllib/core/learner/tf/tf_learner.py
@@ -78,10 +78,15 @@ def configure_optimizer_per_module(
     ) -> Union[ParamOptimizerPair, NamedParamOptimizerPairs]:
         module = self._module[module_id]
         lr = self._optimizer_config["lr"]
+        optim = tf.keras.optimizers.Adam(learning_rate=lr)
         pair: ParamOptimizerPair = (
             self.get_parameters(module),
-            tf.keras.optimizers.Adam(learning_rate=lr),
+            optim,
         )
+        # This isn't strictly necessary, but makes it so that if a checkpoint is
+        # computed before training actually starts, then it will be the same in
+        # shape / size as a checkpoint after training starts.
+        optim.build(module.trainable_variables)
         return pair
 
     @override(Learner)
@@ -139,30 +144,124 @@ def load_state(
         with self._strategy.scope():
             super().load_state(path)
 
+    def _save_optimizer_hparams(
+        self,
+        path: pathlib.Path,
+        optim: "tf.keras.optimizers.Optimizer",
+        optim_name: str,
+    ) -> None:
+        """Save the hyperparameters of optim to path/optim_name_hparams.json.
+
+        Args:
+            path: The path to the directory to save the hyperparameters to.
+            optim: The optimizer to save the hyperparameters of.
+            optim_name: The name of the optimizer.
+
+        """
+        hparams = tf.keras.optimizers.serialize(optim)
+        hparams = tf.nest.map_structure(convert_numpy_to_python_primitives, hparams)
+        with open(path / f"{optim_name}_hparams.json", "w") as f:
+            json.dump(hparams, f)
+
+    def _save_optimizer_state(
+        self,
+        path: pathlib.Path,
+        optim: "tf.keras.optimizers.Optimizer",
+        optim_name: str,
+    ) -> None:
+        """Save the state variables of optim to path/optim_name_state.txt.
+
+        Args:
+            path: The path to the directory to save the state to.
+            optim: The optimizer to save the state of.
+            optim_name: The name of the optimizer.
+
+        """
+        state = optim.variables()
+        serialized_tensors = [tf.io.serialize_tensor(tensor) for tensor in state]
+        contents = tf.strings.join(serialized_tensors, separator="tensor: ")
+        tf.io.write_file(str(path / f"{optim_name}_state.txt"), contents)
+
     @override(Learner)
     def _save_optimizers(self, path: Union[str, pathlib.Path]) -> None:
         path = pathlib.Path(path)
         path.mkdir(parents=True, exist_ok=True)
         for name, optim in self._named_optimizers.items():
-            state = tf.keras.optimizers.serialize(optim)
-            state = tf.nest.map_structure(convert_numpy_to_python_primitives, state)
-            with open(path / f"{name}.json", "w") as f:
-                json.dump(state, f)
+            self._save_optimizer_hparams(path, optim, name)
+            self._save_optimizer_state(path, optim, name)
 
+    def _load_optimizer_from_hparams(
+        self, path: pathlib.Path, optim_name: str
+    ) -> "tf.keras.optimizers.Optimizer":
+        """Load an optimizer from the hyperparameters saved at path/optim_name_hparams.json.
+
+        Args:
+            path: The path to the directory to load the hyperparameters from.
+            optim_name: The name of the optimizer.
+
+        Returns:
+            The optimizer loaded from the hyperparameters.
+
+        """
+        with open(path / f"{optim_name}_hparams.json", "r") as f:
+            state = json.load(f)
+        return tf.keras.optimizers.deserialize(state)
+
+    def _load_optimizer_state(
+        self,
+        path: pathlib.Path,
+        optim: "tf.keras.optimizers.Optimizer",
+        optim_name: str,
+    ) -> None:
+        """Load the state of optim from the state saved at path/optim_name_state.txt.
+
+        Args:
+            path: The path to the directory to load the state from.
+            optim: The optimizer to load the state into.
+            optim_name: The name of the optimizer.
+
+        """
+        contents = tf.io.read_file(str(path / f"{optim_name}_state.txt"))
+        serialized_tensors = tf.strings.split(contents, sep="tensor: ")
+        unserialized_optim_state = []
+        for serialized_tensor, optim_tensor in zip(
+            serialized_tensors, optim.variables()
+        ):
+            unserialized_optim_state.append(
+                tf.io.parse_tensor(serialized_tensor, optim_tensor.dtype)
+            )
+
+        # set the state of the optimizer to the state that was saved
+        optim.set_weights(unserialized_optim_state)
+
     @override(Learner)
     def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None:
         path = pathlib.Path(path)
         for name in self._named_optimizers.keys():
-            with open(path / f"{name}.json", "r") as f:
-                state = json.load(f)
-            new_optim = tf.keras.optimizers.deserialize(state)
+            new_optim = self._load_optimizer_from_hparams(path, name)
             old_optim = self._named_optimizers[name]
+
+            # replace the old optim with the new optim in the learner's state
             self._named_optimizers[name] = new_optim
             param_seq = self._optimizer_parameters.pop(old_optim)
             self._optimizer_parameters[new_optim] = []
             for param_ref in param_seq:
                 self._optimizer_parameters[new_optim].append(param_ref)
 
+            # delete the old optimizer / free its memory
+            del old_optim
+            # these are the variables that the optimizer is supposed to optimize over
+            variable_list = [
+                self._params[param_ref]
+                for param_ref in self._optimizer_parameters[new_optim]
+            ]
+            # initialize the optimizer with the variables that it is supposed to
+            # optimize over
+            new_optim.build(variable_list)
+
+            # This loads in the actual state of the optimizer.
+            self._load_optimizer_state(path, new_optim, name)
+
     @override(Learner)
     def set_weights(self, weights: Mapping[str, Any]) -> None:
         self._module.set_state(weights)