Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1333] Estimator and Fit API #14629

Merged
merged 13 commits into from
May 18, 2019
Merged

[MXNET-1333] Estimator and Fit API #14629

merged 13 commits into from
May 18, 2019

Conversation

roywei
Copy link
Member

@roywei roywei commented Apr 5, 2019

Description

This PR introduce an Estimator class in contrib with easy fit method to help beginners with model training.
It's been developed on a branch, and we hope to merge it to contrib and get feedback for first iteration.

Design: https://cwiki.apache.org/confluence/display/MXNET/Gluon+Fit+API+-+Tech+Design
JIRA epics: https://issues.apache.org/jira/browse/MXNET-1333
Dev list discussion: https://lists.apache.org/thread.html/13e3dee0fc9dd8e45b6616f97d282096a1ee67cde78a93dada295577@%3Cdev.mxnet.apache.org%3E
Feedbacks: currently all feedbacks are captured in cwiki comment section. We have created JIRA issues for each feedback and will continue to work on it
Follow up PRs:
We currently have the following PRs to address feedback, will create more and track using JIRA issue

  1. [MXNET-1358][Fit API] Fit api tutorial
  2. [MXNET-1340, 1339][Fit API]Multi input/output

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with MXNET-1333
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http:https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Estimator class for fit and evaluate
  • Evenhandler class for callbacks in fit methods
  • Unit tests
  • Integration/nightly tests

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@marcoabreu
Copy link
Contributor

Would it be possible to make these tests part of the training test suite instead of introducing new jobs?

@roywei
Copy link
Member Author

roywei commented Apr 5, 2019

@marcoabreu are you referring to training test suite here? we want to only test this on nightly. Please point to the correct test suite if I m wrong. Thanks!

@nswamy
Copy link
Member

nswamy commented Apr 5, 2019

Suggest to move it to contrib as there is some feedback from Mu on the dev@. We could also gather feedback from the users to see what other changes are required.

Could you please break all the backlog items into Jira tasks and paste the master ticket to this PR ? Any contributor interested to further contribute to this could pick up those tasks.

@roywei
Copy link
Member Author

roywei commented Apr 5, 2019

@nswamy done, all JIRA tickets has detailed description and reference to the feedback (either from cwiki or dev list discussion)

import numpy as np


class EventHandler(object):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please break this down into different event classes and use the same approach as gluon's forward hook through weakref. this has the benefit of:

  • people can mix what they need into a unified handler through multi-inheritence, without the unnecessary pass calls.
  • a handler can be detached at will.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i added a method to check if a handlers has implemented train_begin ect to avoid unnecessary pass calls. The time it takes should be the same as using multi-inheritence and do a bunch of isinstance() at the beginning. The benefit is user can override any event call without inherit multiple class.

Still looking into how to use forward hook to register different input args

Thanks!

Copy link
Member

@szha szha left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution. see above. more comments to come.

@piyushghai
Copy link
Contributor

Thanks for your contributions @roywei.
This is a very useful API for users and it simplifies a lot of pain points.

@mxnet-label-bot Add [Gluon, pr-awaiting-review]

@marcoabreu marcoabreu added Gluon pr-awaiting-review PR is waiting for code review labels Apr 5, 2019
@roywei
Copy link
Member Author

roywei commented Apr 5, 2019

Thanks @szha for the feedback! I m noting the summary of offlien discussion here:

  1. Currently event handlers has access to entire estimator. Estimator has to maintain different stats/states (e.g. current epoch, num of steps, metrics) and do the book keeping to ensure eventhandlers has them when they are called.
  2. If user want to add an event handler that want to use some info estimator does not have. He has to add it in estimator and change the estimator code to do book keeping. He has to know he can access current epoch from estimator.currrent_epoch, not some other variable name.
  3. Separate event handlers into 6 classes (train begin class, train end class ect) instead of single eventhandler parent class. Each class maintain it's own info/state and know what args to pass when called. So estimator does not do the state management.
  4. Using the gluon forward hook approch to provide avaibilty to detach a event handler. (e.g. remove some handler after n epochs/steps)

@nswamy
Copy link
Member

nswamy commented Apr 5, 2019

I am not sure I understand the concerns,

What is the problem with 1) ?

For 2), the user can create custom event Handler taking objects whatever it needs to keep track of.

MyEventHandler
def __init__(self, whatever1, whatever2):
     self.whatever1 = whatever1, ...

def train_begin():
     # dowhatever using self.whatever1

Given that the fit API is targeted at novice users, I think 3) is going to make it unnecessarily cumbersome.

What is the benefit of using the Forward Hook approach?

@nswamy
Copy link
Member

nswamy commented Apr 5, 2019

wouldn't it be easier for the user to write the training loop if they want more control instead of having the loop split 6 or more methods or hooks.
In my opinion trying to make it more flexible would complicate the usage and add to the users cognitive load.

nswamy
nswamy previously requested changes Apr 12, 2019
Copy link
Member

@nswamy nswamy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking the PR until my questions are answered.

I would like to understand why and how Sheng's proposal is better than the current design which was discussed offline and surfaced on dev@ months ago.
Last minute requests to fundamentally change the design should have very strong reasons.

@roywei
Copy link
Member Author

roywei commented Apr 17, 2019

@nswamy @szha I have addressed the concerns on callback in this doc: https://cwiki.apache.org/confluence/display/MXNET/Callback+Design+for+Fit+Loop
and created a PR here: #14685

Please help take a look, thanks!

@roywei roywei mentioned this pull request May 6, 2019
7 tasks
@roywei
Copy link
Member Author

roywei commented May 6, 2019

@eric-haibin-lin
Copy link
Member

My comments are addressed. Great work!!

roywei and others added 12 commits May 15, 2019 13:47
* base class for estimator and eventhandler

* add license

* add event handlers

* fix pylint

* improve arg check

* fix pylint

* add unit tests
#14464)

* Fixed issue where the estimator was printing beyond the dataset size for the last batch

* Added comments

* Nudge to CI
…API (#14442)

* added estimator unittests

* add more tests for estimator

* added validation logic

* added error handlers, unittests

* improve val stats

* fix pylint

* fix pylint

* update unit test

* fix tests

* fix tests

* updated metrics, val logic

* trigger ci

* trigger ci

* update metric, batch_fn error handler

* update context logic, add default metric
* add train history

* update history

* update test

* avoid calling empty methods

* remove train history object

* fix pylint

* add unit test

* fix test

* update categorize handlers
* Added RNN integration test for fit() API

* Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports

* CPU test doesn't require nvidiadocker container

* Modified the structure by removing the redundant code
* added cnn intg tests for fit api

* updated cnn intg tests

* added functions for nightly test

* updated runtime_function

* updated intg tests

* updated init, datapath, refs

* added validation data

* update cpu test

* refactor code

* updated context
…upport for Gluon fit() API (#14587)

* Retrieve Batch size and Logging verbose support for Gluon fit() API

* NIT changes

* Addressed review comments: shifted the batch size code to a separate method, sentence correction

* Modified unittest

* removed redundant parameter

* Resolve CI test failure

* only support DataLoader for now, future PRs will include DataIter to DataLoader converter

* Get the number of samples from shape attribute instead of length due to low space complexity

* Simplified batch size retrieval code

* removed batch_size parameter from fit() method and fixed the tests

* Verbose exception handling

* Assigning constant to a verbose

* Modified exception message

* Resolved undefined class reference

* Addressed review comments: Modified verbose level names, docs, variable names

* Update estimator.py
* improve event handlers

* update tests

* passing weakref of estimator

* fix unit test

* fix test

* fix pylint

* fix test

* fix pylint

* move default metric logic

* combine nightly tests
* move to nightly for binaries

* update default handler

* fix pylint

* trigger ci

* trigger ci
* address comments

* add comment

* check available context

* fix bug

* change cpu check
* address comments

* update checkpoint

* test symbol save

* address comments

* add resume

* update doc and resume checkpoint

* update docs

* trigger ci

* trigger ci
@roywei
Copy link
Member Author

roywei commented May 18, 2019

@szha @eric-haibin-lin CI finally passed, validation/miscellaneous job status returned not correctly (passed instead of pending). Could you help merge if looks good?

@szha szha merged commit 9f451fb into master May 18, 2019
@szha
Copy link
Member

szha commented May 18, 2019

Nice work! Great job upholding the quality even at the cost of several iterations. Well done!

roywei added a commit to roywei/incubator-mxnet that referenced this pull request May 20, 2019
eric-haibin-lin pushed a commit that referenced this pull request May 20, 2019
@roywei roywei mentioned this pull request May 20, 2019
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* [MXNet-1334][Fit API]base class for estimator and eventhandler (apache#14346)

* base class for estimator and eventhandler

* add license

* add event handlers

* fix pylint

* improve arg check

* fix pylint

* add unit tests

* Fixed issue where the estimator was printing beyond the dataset size … (apache#14464)

* Fixed issue where the estimator was printing beyond the dataset size for the last batch

* Added comments

* Nudge to CI

* [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (apache#14442)

* added estimator unittests

* add more tests for estimator

* added validation logic

* added error handlers, unittests

* improve val stats

* fix pylint

* fix pylint

* update unit test

* fix tests

* fix tests

* updated metrics, val logic

* trigger ci

* trigger ci

* update metric, batch_fn error handler

* update context logic, add default metric

* [MXNet-1340][Fit API]Update train stats (apache#14494)

* add train history

* update history

* update test

* avoid calling empty methods

* remove train history object

* fix pylint

* add unit test

* fix test

* update categorize handlers

* [MXNet-1375][Fit API]Added RNN integration test for fit() API (apache#14547)

* Added RNN integration test for fit() API

* Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports

* CPU test doesn't require nvidiadocker container

* Modified the structure by removing the redundant code

* [MXNet-1343][Fit API]Add CNN integration test for fit() API (apache#14405)

* added cnn intg tests for fit api

* updated cnn intg tests

* added functions for nightly test

* updated runtime_function

* updated intg tests

* updated init, datapath, refs

* added validation data

* update cpu test

* refactor code

* updated context

* [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (apache#14587)

* Retrieve Batch size and Logging verbose support for Gluon fit() API

* NIT changes

* Addressed review comments: shifted the batch size code to a separate method, sentence correction

* Modified unittest

* removed redundant parameter

* Resolve CI test failure

* only support DataLoader for now, future PRs will include DataIter to DataLoader converter

* Get the number of samples from shape attribute instead of length due to low space complexity

* Simplified batch size retrieval code

* removed batch_size parameter from fit() method and fixed the tests

* Verbose exception handling

* Assigning constant to a verbose

* Modified exception message

* Resolved undefined class reference

* Addressed review comments: Modified verbose level names, docs, variable names

* Update estimator.py

* move estimator to contrib (apache#14633)

* move to gluon contrib (apache#14635)

* [Fit API] improve event handlers (apache#14685)

* improve event handlers

* update tests

* passing weakref of estimator

* fix unit test

* fix test

* fix pylint

* fix test

* fix pylint

* move default metric logic

* combine nightly tests

* [MXNET-1396][Fit-API] Update default handler logic (apache#14765)

* move to nightly for binaries

* update default handler

* fix pylint

* trigger ci

* trigger ci

* [Fit API] update estimator (apache#14849)

* address comments

* add comment

* check available context

* fix bug

* change cpu check

* [Fit-API] Adress PR comments (apache#14885)

* address comments

* update checkpoint

* test symbol save

* address comments

* add resume

* update doc and resume checkpoint

* update docs

* trigger ci

* trigger ci
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
@szha szha deleted the fit-api branch September 8, 2019 03:40
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Gluon pr-awaiting-review PR is waiting for code review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

9 participants