Add support for saving to and restoring from checkpoints #399

Merged (10 commits) on Apr 19, 2023


ramon-garcia
Contributor

This pull request adds support for saving and restoring variables using the TensorFlow SaveV2/RestoreV2 ops, so it is compatible with checkpoints created with tf.keras.Model.save_weights.

It includes a test case.

@ramon-garcia
Contributor Author

A point for discussion: in my own code, I use the anyhow crate to attach context to errors. Otherwise, one gets an error code without any context.

I didn't add anyhow here, to avoid adding a new dependency to the project and to stay consistent with the rest of this library. But error messages without context are very difficult to diagnose.

@adamcrume
Contributor

I don't think anyhow is appropriate, because it's intended for application code and forces you to use its error type. That would be awkward for a library, and switching to it would be a breaking change (since it would change the return type of functions).

I'm definitely in favor of including context in errors. We can manually add context in error messages; it's just a matter of doing it. To make it easier, we could add

impl Status {
  pub fn context(&self, msg: &str) -> Status {
    ...
  }
}

and use it with

my_tensorflow_function().map_err(|s| s.context("foo"))?;

To make this even more concise, we could do the same thing anyhow does and provide an extension trait for Result, but if we were to expose it outside this crate, we'd probably want to call the methods something other than context so they don't conflict for people using both tensorflow and anyhow.

Changes like this to Status should probably be its own PR, though.
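
A minimal sketch of the suggestion above, assuming a simplified stand-in for `Status` (the real type wraps a TensorFlow status and is not reproduced here). The extension trait `ResultContext` and its method `with_status_context` are hypothetical names, picked only so they would not clash with anyhow's `context`; `my_tensorflow_function` and `restore_checkpoint` are placeholders.

```rust
use std::fmt;

/// Hypothetical stand-in for the crate's `Status`; fields are assumptions.
#[derive(Debug, Clone, PartialEq)]
pub struct Status {
    code: i32,
    message: String,
}

impl Status {
    /// Returns a copy of this status with `msg` prepended to the message.
    pub fn context(&self, msg: &str) -> Status {
        Status {
            code: self.code,
            message: format!("{}: {}", msg, self.message),
        }
    }
}

impl fmt::Display for Status {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "code {}: {}", self.code, self.message)
    }
}

/// Assumed extension trait so call sites can skip the `map_err` closure.
pub trait ResultContext<T> {
    fn with_status_context(self, msg: &str) -> Result<T, Status>;
}

impl<T> ResultContext<T> for Result<T, Status> {
    fn with_status_context(self, msg: &str) -> Result<T, Status> {
        self.map_err(|s| s.context(msg))
    }
}

/// Placeholder for a fallible call into TensorFlow.
fn my_tensorflow_function() -> Result<(), Status> {
    Err(Status {
        code: 5,
        message: "file not found".to_string(),
    })
}

/// Shows the extension trait at a call site.
fn restore_checkpoint() -> Result<(), Status> {
    my_tensorflow_function().with_status_context("restoring checkpoint")?;
    Ok(())
}
```

Because the method name differs from anyhow's `context`, a crate importing both traits would not hit an ambiguous method call on `Result`.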

@ramon-garcia
Contributor Author

Thank you for your reply. Sounds good to me.

@adamcrume left a comment
Contributor

Looks great overall, just a few readability concerns.


#[derive(Debug)]
struct SaveRestoreOps {
    pub prefix_save: Operation,
Contributor

There's no need to make these fields pub.

Contributor Author

Done.

/// Add save and restore ops to the graph
fn build_save_ops(&mut self) -> Result<SaveRestoreOps, Status> {
    let mut all_variable_ops_opt: Option<Vec<Operation>> = None;
    fn make_all_variable_ops(myself: &mut CheckpointMaker) -> Result<Vec<Operation>, Status> {
Contributor

This should just be a method on CheckpointMaker, since it doesn't capture anything. build_save_ops probably needs to be broken up a bit into smaller methods for readability, anyway.

Contributor Author

The function make_all_variable_ops can be removed if we change from variable names to Variable objects. This is what I am going to implement. Then we can go back if desired.

@@ -0,0 +1,18 @@
// Similar to Option<T>.get_or_insert_with, for a function that returns a result.
pub trait OptionInsertWithResult<T> {
    fn get_or_insert_with_result<F, E>(&mut self, f: F) -> Result<&T, E>
Contributor

Shouldn't this return Result<&mut T, E> for consistency with get_or_insert_with?

Contributor Author

Ok, done
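
For reference, a sketch of what the trait could look like with the `&mut T` return type requested above. Only the trait name and the first line of the method signature come from the diff; the `where` clause and the `impl` for `Option<T>` are assumptions about how it might be written.

```rust
/// Similar to `Option::get_or_insert_with`, for an initializer that can fail.
pub trait OptionInsertWithResult<T> {
    fn get_or_insert_with_result<F, E>(&mut self, f: F) -> Result<&mut T, E>
    where
        F: FnOnce() -> Result<T, E>;
}

impl<T> OptionInsertWithResult<T> for Option<T> {
    fn get_or_insert_with_result<F, E>(&mut self, f: F) -> Result<&mut T, E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        if self.is_none() {
            // On error, `?` returns early and the option stays `None`.
            *self = Some(f()?);
        }
        // The option is always `Some` at this point.
        Ok(self.as_mut().expect("value is present"))
    }
}
```

Returning `&mut T` lets callers modify the cached value in place, which matches what `Option::get_or_insert_with` allows.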

nd.set_attr_type_list("dtypes", &dtypes[..])?;
nd.finish()?
};
{
Contributor

This scope doesn't do anything.

Contributor Author

These two scopes (one ending on line 138 and the other starting on line 139) look redundant, but they were there to avoid borrowing errors. The conflict is between the mutable borrow

    let mut g = self.scope.graph_mut();

and the immutable borrow

    let var_op = self
        .scope
        .graph()

because the mutable borrow lasts until the end of its scope.

It can be done better, though, by releasing the mutable borrow held by g explicitly with drop(g).
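
The two approaches can be compared in a small sketch. It assumes, purely for illustration, that the graph sits behind a `RefCell`, so `borrow_mut` and `borrow` stand in for `graph_mut()` and `graph()`; `Graph`, `Scope`, and the function names below are hypothetical simplifications, not the crate's types.

```rust
use std::cell::RefCell;

/// Hypothetical stand-ins; the real `Scope` and `Graph` are more involved.
struct Graph {
    ops: Vec<String>,
}

struct Scope {
    graph: RefCell<Graph>,
}

/// Uses an inner block so the mutable guard is released at the closing brace.
fn add_op_with_block(scope: &Scope, name: &str) -> usize {
    {
        let mut g = scope.graph.borrow_mut();
        g.ops.push(name.to_string());
    } // `g` is dropped here
    scope.graph.borrow().ops.len()
}

/// Same behavior, but releases the guard explicitly with `drop`.
fn add_op_with_drop(scope: &Scope, name: &str) -> usize {
    let mut g = scope.graph.borrow_mut();
    g.ops.push(name.to_string());
    drop(g); // release the mutable borrow before borrowing again
    scope.graph.borrow().ops.len()
}

/// Reports whether a shared borrow is possible while a mutable guard is alive.
fn can_read_while_writing(scope: &Scope) -> bool {
    let _g = scope.graph.borrow_mut();
    scope.graph.try_borrow().is_ok()
}
```

Both helpers behave the same; `drop(g)` simply makes the release point visible without the extra nesting that prompted the review comment.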

}

struct MyScopeData {
    pub scope: Scope,
Contributor

There's no need to make these pub.

Contributor Author

Done

    pub restore_op: Operation,
}

/// Checkpointing and restoring struct
Contributor

This documentation needs to be a little more in-depth.

Contributor Author

Ok, I added documentation and a small example at the start of the module.

Contributor

Since the module is private, the module docs are not visible by default.
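
One possible way to address this, sketched under the assumption that the struct is re-exported from the crate root: put the detailed docs on the item itself, since rustdoc skips private modules unless run with `--document-private-items`. The fields and the `add_variable` method here are hypothetical and exist only to make the sketch compile.

```rust
mod checkpoint {
    //! Module-level docs: not rendered publicly while this module is private.

    /// Saves and restores a set of variables (hypothetical, simplified type).
    ///
    /// Docs written on the item travel with the re-export below.
    #[derive(Debug, Default)]
    pub struct CheckpointMaker {
        variable_names: Vec<String>,
    }

    impl CheckpointMaker {
        /// Registers a variable by name and returns how many are registered.
        pub fn add_variable(&mut self, name: &str) -> usize {
            self.variable_names.push(name.to_string());
            self.variable_names.len()
        }
    }
}

// Re-export so the type and its item-level docs are visible to users.
pub use checkpoint::CheckpointMaker;
```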

@ramon-garcia
Contributor Author

Thank you for your review.
In addition, there is something I am not comfortable with: wouldn't it be better to use an array of variables rather than variable names? When I wrote it, I was concerned about not being able to create a Variable object (for instance, from an imported graph), but that seems to be an independent problem, better handled in its own PR.

@ramon-garcia
Contributor Author

Added two commits. The first one is 013a829. The next one, 2247ed8, changes the specification of variables from strings to Variable objects.

Best regards.

@ramon-garcia
Contributor Author

Hello?

@ramon-garcia
Contributor Author

ping.

@adamcrume
Contributor

Sorry for the delay, and thank you!

@adamcrume merged commit 33977d3 into tensorflow:master on Apr 19, 2023