Skip to content

Commit

Permalink
Address reviewer's requests
Browse files Browse the repository at this point in the history
  • Loading branch information
ramon-garcia committed Mar 8, 2023
1 parent 568ff56 commit 013a829
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 52 deletions.
116 changes: 67 additions & 49 deletions src/checkpoint.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
//! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format

//! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format.
//! First, the user creates a [`CheckpointMaker`] attached to an [`Scope`] indicating the list of variables to be saved/restored.
//! The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring.
//! When one wants to save/restore from or into a session, one calls the save/restore methods
//! # Example
//! let mut scope = Scope::new_root_scope();
//! // add operations to define the graph
//! // ...
//! // let "w" and "b" the name of the variables that we wish to save
//! let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
//! vec![String::from("w"), String::from("b")].into_boxed_slice(),
//! );
//! let session = Session::new(&SessionOptions::new(), &scope.graph())?;
//! // run some training
//! // ...
//! // to save the training
//! checkpoint_maker.save(&session, "data/checkpoint")?;
//! // then we restore in a different session to continue there
//! let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
//! checkpoint_maker.save(&new_session, "data/checkpoint")?;
use crate::option_insert_result::OptionInsertWithResult;
use crate::{ops, Operation, Scope, Session, SessionRunArgs, Status, Tensor};

#[derive(Debug)]
struct SaveRestoreOps {
pub prefix_save: Operation,
pub prefix_restore: Operation,
pub save_op: Operation,
pub restore_op: Operation,
prefix_save: Operation,
prefix_restore: Operation,
save_op: Operation,
restore_op: Operation,
}

/// Checkpointing and restoring struct
/// Checkpointing and restoring support for Tensorflow.
/// This struct is manages a scope, adds lazily the Tensorflow ops
/// to perform the save/restore operations
#[derive(Debug)]
pub struct CheckpointMaker {
scope: Scope,
Expand All @@ -33,19 +53,20 @@ impl CheckpointMaker {
}
}

fn make_all_variable_ops(&mut self) -> Result<Vec<Operation>, Status> {
let graph = self.scope.graph();
Ok(self
.variables
.iter()
.map(|v: &String| -> Result<Operation, Status> {
Ok(graph.operation_by_name_required(v.as_str())?.clone())
})
.collect::<Result<Vec<_>, Status>>()?)
}

/// Add save and restore ops to the graph
fn build_save_ops(&mut self) -> Result<SaveRestoreOps, Status> {
let mut all_variable_ops_opt: Option<Vec<Operation>> = None;
fn make_all_variable_ops(myself: &mut CheckpointMaker) -> Result<Vec<Operation>, Status> {
let graph = myself.scope.graph();
Ok(myself
.variables
.iter()
.map(|v: &String| -> Result<Operation, Status> {
Ok(graph.operation_by_name_required(v.as_str())?.clone())
})
.collect::<Result<Vec<_>, Status>>()?)
}

let existing_save_op = self.scope.graph().operation_by_name("save")?;
let (prefix_save, save_op) = if let Some(op) = existing_save_op {
Expand All @@ -56,7 +77,7 @@ impl CheckpointMaker {
(prefix_save_op, op)
} else {
let all_variable_ops =
all_variable_ops_opt.get_or_insert_with_result(|| make_all_variable_ops(self))?;
all_variable_ops_opt.get_or_insert_with_result(|| self.make_all_variable_ops())?;
let prefix_save = ops::Placeholder::new()
.dtype(crate::DataType::String)
.build(&mut self.scope.with_op_name("prefix_save"))?;
Expand Down Expand Up @@ -105,39 +126,37 @@ impl CheckpointMaker {
(the_prefix_restore, op)
} else {
let all_variable_ops =
all_variable_ops_opt.get_or_insert_with_result(|| make_all_variable_ops(self))?;
all_variable_ops_opt.get_or_insert_with_result(|| self.make_all_variable_ops())?;
let prefix_restore = ops::Placeholder::new()
.dtype(crate::DataType::String)
.build(&mut self.scope.with_op_name("prefix_restore"))?;
let restore_op = {
let all_var_names = self
let all_var_names = self
.variables
.iter()
.map(|v| v.to_string())
.collect::<Vec<_>>();
let tensor_names = ops::constant(&all_var_names[..], &mut self.scope)?;
let shape_and_slices = ops::constant(
&self
.variables
.iter()
.map(|v| v.to_string())
.collect::<Vec<_>>();
let tensor_names = ops::constant(&all_var_names[..], &mut self.scope)?;
let shape_and_slices = ops::constant(
&self
.variables
.iter()
.map(|_| "".to_string())
.collect::<Vec<_>>()[..],
&mut self.scope,
)?;
let mut g = self.scope.graph_mut();
let mut nd = g.new_operation("RestoreV2", "restore")?;
nd.add_input(prefix_restore.clone());
nd.add_input(tensor_names);
nd.add_input(shape_and_slices);
let dtypes = all_variable_ops
.iter()
.map(|v| v.get_attr_type("dtype"))
.collect::<Result<Vec<_>, Status>>()?;
nd.set_attr_type_list("dtypes", &dtypes[..])?;
nd.finish()?
};
{
let mut restore_var_ops = Vec::<Operation>::new();
.map(|_| "".to_string())
.collect::<Vec<_>>()[..],
&mut self.scope,
)?;
let mut g = self.scope.graph_mut();
let mut nd = g.new_operation("RestoreV2", "restore")?;
nd.add_input(prefix_restore.clone());
nd.add_input(tensor_names);
nd.add_input(shape_and_slices);
let dtypes = all_variable_ops
.iter()
.map(|v| v.get_attr_type("dtype"))
.collect::<Result<Vec<_>, Status>>()?;
nd.set_attr_type_list("dtypes", &dtypes[..])?;
let restore_op = nd.finish()?;
drop(g);
let mut restore_var_ops = Vec::<Operation>::new();
for (i, var) in self.variables.iter().enumerate() {
let var_op = self
.scope
Expand All @@ -157,7 +176,6 @@ impl CheckpointMaker {
no_op = no_op.add_control_input(op);
}
(prefix_restore, no_op.build(&mut self.scope)?)
}
};
Ok(SaveRestoreOps {
prefix_save,
Expand Down Expand Up @@ -236,8 +254,8 @@ mod tests {
}

struct MyScopeData {
pub scope: Scope,
pub variables: [Variable; 3],
scope: Scope,
variables: [Variable; 3],
}

// Initialize a scope and place same variables in it
Expand Down
6 changes: 3 additions & 3 deletions src/option_insert_result.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
// Similar to Option<T>.get_or_insert_with, for a function that returns a result.
pub trait OptionInsertWithResult<T> {
fn get_or_insert_with_result<F, E>(&mut self, f: F) -> Result<&T, E>
fn get_or_insert_with_result<F, E>(&mut self, f: F) -> Result<&mut T, E>
where
F: FnOnce() -> Result<T, E>;
}

impl<T> OptionInsertWithResult<T> for Option<T> {
fn get_or_insert_with_result<F, E>(&mut self, f: F) -> Result<&T, E>
fn get_or_insert_with_result<F, E>(&mut self, f: F) -> Result<&mut T, E>
where
F: FnOnce() -> Result<T, E>,
{
if self.is_none() {
*self = Some(f()?);
}
Ok(self.as_ref().unwrap())
Ok(self.as_mut().unwrap())
}
}

0 comments on commit 013a829

Please sign in to comment.