diff --git a/src/checkpoint.rs b/src/checkpoint.rs index ee1076772d..5854969765 100644 --- a/src/checkpoint.rs +++ b/src/checkpoint.rs @@ -1,17 +1,37 @@ -//! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format - +//! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format. +//! First, the user creates a [`CheckpointMaker`] attached to a [`Scope`] indicating the list of variables to be saved/restored. +//! The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring. +//! When one wants to save/restore from or into a session, one calls the save/restore methods. +//! # Example +//! let mut scope = Scope::new_root_scope(); +//! // add operations to define the graph +//! // ... +//! // let "w" and "b" be the names of the variables that we wish to save +//! let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"), +//! vec![String::from("w"), String::from("b")].into_boxed_slice(), +//! ); +//! let session = Session::new(&SessionOptions::new(), &scope.graph())?; +//! // run some training +//! // ... +//! // to save the training +//! checkpoint_maker.save(&session, "data/checkpoint")?; +//! // then we restore in a different session to continue there +//! let new_session = Session::new(&SessionOptions::new(), &scope.graph())?; +//! checkpoint_maker.restore(&new_session, "data/checkpoint")?; use crate::option_insert_result::OptionInsertWithResult; use crate::{ops, Operation, Scope, Session, SessionRunArgs, Status, Tensor}; #[derive(Debug)] struct SaveRestoreOps { - pub prefix_save: Operation, - pub prefix_restore: Operation, - pub save_op: Operation, - pub restore_op: Operation, + prefix_save: Operation, + prefix_restore: Operation, + save_op: Operation, + restore_op: Operation, } -/// Checkpointing and restoring struct +/// Checkpointing and restoring support for Tensorflow. 
+/// This struct manages a scope and lazily adds the Tensorflow ops +/// to perform the save/restore operations. #[derive(Debug)] pub struct CheckpointMaker { scope: Scope, @@ -33,19 +53,20 @@ impl CheckpointMaker { } } + fn make_all_variable_ops(&mut self) -> Result, Status> { + let graph = self.scope.graph(); + Ok(self + .variables + .iter() + .map(|v: &String| -> Result { + Ok(graph.operation_by_name_required(v.as_str())?.clone()) + }) + .collect::, Status>>()?) + } + /// Add save and restore ops to the graph fn build_save_ops(&mut self) -> Result { let mut all_variable_ops_opt: Option> = None; - fn make_all_variable_ops(myself: &mut CheckpointMaker) -> Result, Status> { - let graph = myself.scope.graph(); - Ok(myself - .variables - .iter() - .map(|v: &String| -> Result { - Ok(graph.operation_by_name_required(v.as_str())?.clone()) - }) - .collect::, Status>>()?) - } let existing_save_op = self.scope.graph().operation_by_name("save")?; let (prefix_save, save_op) = if let Some(op) = existing_save_op { @@ -56,7 +77,7 @@ impl CheckpointMaker { (prefix_save_op, op) } else { let all_variable_ops = - all_variable_ops_opt.get_or_insert_with_result(|| make_all_variable_ops(self))?; + all_variable_ops_opt.get_or_insert_with_result(|| self.make_all_variable_ops())?; let prefix_save = ops::Placeholder::new() .dtype(crate::DataType::String) .build(&mut self.scope.with_op_name("prefix_save"))?; @@ -105,39 +126,37 @@ impl CheckpointMaker { (the_prefix_restore, op) } else { let all_variable_ops = - all_variable_ops_opt.get_or_insert_with_result(|| make_all_variable_ops(self))?; + all_variable_ops_opt.get_or_insert_with_result(|| self.make_all_variable_ops())?; let prefix_restore = ops::Placeholder::new() .dtype(crate::DataType::String) .build(&mut self.scope.with_op_name("prefix_restore"))?; - let restore_op = { - let all_var_names = self + let all_var_names = self + .variables + .iter() + .map(|v| v.to_string()) + .collect::>(); + let tensor_names = 
ops::constant(&all_var_names[..], &mut self.scope)?; + let shape_and_slices = ops::constant( + &self .variables .iter() - .map(|v| v.to_string()) - .collect::>(); - let tensor_names = ops::constant(&all_var_names[..], &mut self.scope)?; - let shape_and_slices = ops::constant( - &self - .variables - .iter() - .map(|_| "".to_string()) - .collect::>()[..], - &mut self.scope, - )?; - let mut g = self.scope.graph_mut(); - let mut nd = g.new_operation("RestoreV2", "restore")?; - nd.add_input(prefix_restore.clone()); - nd.add_input(tensor_names); - nd.add_input(shape_and_slices); - let dtypes = all_variable_ops - .iter() - .map(|v| v.get_attr_type("dtype")) - .collect::, Status>>()?; - nd.set_attr_type_list("dtypes", &dtypes[..])?; - nd.finish()? - }; - { - let mut restore_var_ops = Vec::::new(); + .map(|_| "".to_string()) + .collect::>()[..], + &mut self.scope, + )?; + let mut g = self.scope.graph_mut(); + let mut nd = g.new_operation("RestoreV2", "restore")?; + nd.add_input(prefix_restore.clone()); + nd.add_input(tensor_names); + nd.add_input(shape_and_slices); + let dtypes = all_variable_ops + .iter() + .map(|v| v.get_attr_type("dtype")) + .collect::, Status>>()?; + nd.set_attr_type_list("dtypes", &dtypes[..])?; + let restore_op = nd.finish()?; + drop(g); + let mut restore_var_ops = Vec::::new(); for (i, var) in self.variables.iter().enumerate() { let var_op = self .scope @@ -157,7 +176,6 @@ impl CheckpointMaker { no_op = no_op.add_control_input(op); } (prefix_restore, no_op.build(&mut self.scope)?) 
- } }; Ok(SaveRestoreOps { prefix_save, @@ -236,8 +254,8 @@ mod tests { } struct MyScopeData { - pub scope: Scope, - pub variables: [Variable; 3], + scope: Scope, + variables: [Variable; 3], } // Initialize a scope and place same variables in it diff --git a/src/option_insert_result.rs b/src/option_insert_result.rs index d22bec295a..1d41ba50fe 100644 --- a/src/option_insert_result.rs +++ b/src/option_insert_result.rs @@ -1,18 +1,18 @@ // Similar to Option.get_or_insert_with, for a function that returns a result. pub trait OptionInsertWithResult { - fn get_or_insert_with_result(&mut self, f: F) -> Result<&T, E> + fn get_or_insert_with_result(&mut self, f: F) -> Result<&mut T, E> where F: FnOnce() -> Result; } impl OptionInsertWithResult for Option { - fn get_or_insert_with_result(&mut self, f: F) -> Result<&T, E> + fn get_or_insert_with_result(&mut self, f: F) -> Result<&mut T, E> where F: FnOnce() -> Result, { if self.is_none() { *self = Some(f()?); } - Ok(self.as_ref().unwrap()) + Ok(self.as_mut().unwrap()) } }