Skip to content

Commit

Permalink
Expose Snapshot to pycaffe
Browse files Browse the repository at this point in the history
- Solver::Snapshot is made public
- It is also added as `snapshot` to pycaffe

Addressing #3077
  • Loading branch information
gustavla committed Sep 18, 2015
1 parent 3d12b5d commit b4f9add
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
10 changes: 5 additions & 5 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ class Solver {
// RestoreSolverStateFrom___ protected methods. You should implement these
// methods to restore the state from the appropriate snapshot type.
void Restore(const char* resume_file);
// The Solver::Snapshot function implements the basic snapshotting utility
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
void Snapshot();
virtual ~Solver() {}
inline const SolverParameter& param() const { return param_; }
inline shared_ptr<Net<Dtype> > net() { return net_; }
Expand Down Expand Up @@ -87,11 +92,6 @@ class Solver {
protected:
// Make and apply the update value for the current iteration.
virtual void ApplyUpdate() = 0;
// The Solver::Snapshot function implements the basic snapshotting utility
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
void Snapshot();
string SnapshotFilename(const string extension);
string SnapshotToBinaryProto();
string SnapshotToHDF5();
Expand Down
3 changes: 2 additions & 1 deletion python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
&Solver<Dtype>::Solve), SolveOverloads())
.def("step", &Solver<Dtype>::Step)
.def("restore", &Solver<Dtype>::Restore);
.def("restore", &Solver<Dtype>::Restore)
.def("snapshot", &Solver<Dtype>::Snapshot);

bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
Expand Down

0 comments on commit b4f9add

Please sign in to comment.