Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Fit API] improve event handlers #14685

Merged
merged 10 commits into from
Apr 19, 2019
Merged

[Fit API] improve event handlers #14685

merged 10 commits into from
Apr 19, 2019

Conversation

roywei
Copy link
Member

@roywei roywei commented Apr 12, 2019

Description

Making the follwing on evetn handlers based on the design here:
https://cwiki.apache.org/confluence/display/MXNET/Callback+Design+for+Fit+Loop

  1. Making metric update and validation logic in event handlers
  2. Each event handler maintain it's own states
  3. Passing a weak reference of estimator at each callback call, so some attributes are passed(net, trainer, etc)

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http:https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@roywei roywei requested a review from szha as a code owner April 12, 2019 18:07
@nswamy nswamy added Gluon pr-awaiting-review PR is waiting for code review labels Apr 15, 2019
@roywei roywei mentioned this pull request Apr 17, 2019
5 tasks
@roywei roywei requested a review from marcoabreu as a code owner April 18, 2019 23:40
@szha szha merged commit 7e10355 into apache:fit-api Apr 19, 2019
Copy link
Member

@nswamy nswamy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for patiently accommodating the last minute design change requests. I have a few comments would like you to know what you think and create a follow up PR if necessary.

losses = []
for loss in self.loss:
losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if the model had multiple loss functions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

multi loss will be supported in #14628, let's get the first version into master and iterate on that.

val_metrics=val_metrics))
event_handlers.append(LoggingHandler(train_metrics=train_metrics,
val_metrics=val_metrics))
warnings.warn("No Event Handler specified, default %s are used. "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you write this warning using the LoggingHandler's logger? so the user has one place to control the log levels and look for.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! for now we can only do this for estimator and handlers, any other warning from mxnet and gluon still can't be controlled. tracked here: https://issues.apache.org/jira/browse/MXNET-1395

losses = []
for loss in self.loss:
losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same thing, using only a single loss?

Copy link
Member Author

@roywei roywei Apr 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as above

multi loss will be supported in #14628, let's get the first version into master and iterate on that.

for metric in self.train_metrics:
metric.reset()

def batch_end(self, estimator, *args, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to capture this for every batch by default. I think we should update once per epoch by default and let the user control.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

once batch end we lost that batch's label and pred

self.train_metrics = train_metrics or []
# order to be called among all callbacks
# metrics need to be calculated before other callbacks can access them
self.priority = -np.Inf
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the priority should be exposed in the base class, otherwise the user who writes custom handlers has no clue about this and the order is based on this.
I am not sure how python resolves the ambiguity.

Copy link
Member Author

@roywei roywei Apr 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mechanism is mainly for internal use. I m making sure metric and validation are called first and logging are called last. I'm trying to reduce the information user need to know, they can order their own event handlers in the list before passing to fit()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets call this out explicitly in the documentation.

self.epoch_period = epoch_period
self.batch_period = batch_period
self.val_metrics = val_metrics
self.num_batches = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how will the user control batch_period and num_batches when you are using this when no handlers are specified. does he have to specify all the handlers to make this change?
do you think we should make this static, so user can independently update this, one drawback is if there are multiple ValidationHandlers used in the same process all of them get the same value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point! if we provide default handler one by one (so user don't need to re-create all just to custom one of them). We need a mechanism to make sure all handlers has the reference of the same set of metric objects. or make handlers an attributes so they can be configured after default handlers been created. tracked https://issues.apache.org/jira/browse/MXNET-1396

file_location=None,
verbose=LOG_VERBOSITY_PER_EPOCH,
train_metrics=None,
val_metrics=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to customize, given that prepare_loss_and_metrics happen after the estimator is created. If at all the order should be

e = Estimator()
e.prepare_loss_and_metrics()
lh = LoggingHandler(..., train_metrics=[e.train_metrics[0], e.train_metrics[1]], ..)

is this what you were thinking?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes that's correct

self.logger.info(msg)
self.batch_index += 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be in the estimator itself, why should all handlers maintain this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not all handlers need the same set of these infomation, so they maintain whatever they want to use. This also prevents if one handler changed self.estimator.total_steps wrongly, it will cause all other handlers to fail

'for example val_accuracy', self.monitor))
self.estimator.net.save_parameters(self.filepath)
if np.isnan(monitor_value):
warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use logger, so the user can control.

@@ -191,18 +287,23 @@ class CheckpointHandler(EventHandler):

def __init__(self,
filepath,
monitor='val_accuracy',
monitor=None,
verbose=0,
save_best_only=False,
mode='auto',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you expand what different modes mean in the doc.

Copy link
Member Author

@roywei roywei Apr 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's explained in the doc string

roywei added a commit to roywei/incubator-mxnet that referenced this pull request May 15, 2019
* improve event handlers

* update tests

* passing weakref of estimator

* fix unit test

* fix test

* fix pylint

* fix test

* fix pylint

* move default metric logic

* combine nightly tests
roywei added a commit to roywei/incubator-mxnet that referenced this pull request May 15, 2019
* improve event handlers

* update tests

* passing weakref of estimator

* fix unit test

* fix test

* fix pylint

* fix test

* fix pylint

* move default metric logic

* combine nightly tests
szha pushed a commit that referenced this pull request May 18, 2019
* [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346)

* base class for estimator and eventhandler

* add license

* add event handlers

* fix pylint

* improve arg check

* fix pylint

* add unit tests

* Fixed issue where the estimator was printing beyond the dataset size … (#14464)

* Fixed issue where the estimator was printing beyond the dataset size for the last batch

* Added comments

* Nudge to CI

* [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (#14442)

* added estimator unittests

* add more tests for estimator

* added validation logic

* added error handlers, unittests

* improve val stats

* fix pylint

* fix pylint

* update unit test

* fix tests

* fix tests

* updated metrics, val logic

* trigger ci

* trigger ci

* update metric, batch_fn error handler

* update context logic, add default metric

* [MXNet-1340][Fit API]Update train stats (#14494)

* add train history

* update history

* update test

* avoid calling empty methods

* remove train history object

* fix pylint

* add unit test

* fix test

* update categorize handlers

* [MXNet-1375][Fit API]Added RNN integration test for fit() API (#14547)

* Added RNN integration test for fit() API

* Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports

* CPU test doesn't require nvidiadocker container

* Modified the structure by removing the redundant code

* [MXNet-1343][Fit API]Add CNN integration test for fit() API (#14405)

* added cnn intg tests for fit api

* updated cnn intg tests

* added functions for nightly test

* updated runtime_function

* updated intg tests

* updated init, datapath, refs

* added validation data

* update cpu test

* refactor code

* updated context

* [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (#14587)

* Retrieve Batch size and Logging verbose support for Gluon fit() API

* NIT changes

* Addressed review comments: shifted the batch size code to a separate method, sentence correction

* Modified unittest

* removed redundant parameter

* Resolve CI test failure

* only support DataLoader for now, future PRs will include DataIter to DataLoader converter

* Get the number of samples from shape attribute instead of length due to low space complexity

* Simplified batch size retrieval code

* removed batch_size parameter from fit() method and fixed the tests

* Verbose exception handling

* Assigning constant to a verbose

* Modified exception message

* Resolved undefined class reference

* Addressed review comments: Modified verbose level names, docs, variable names

* Update estimator.py

* move estimator to contrib (#14633)

* move to gluon contrib (#14635)

* [Fit API] improve event handlers (#14685)

* improve event handlers

* update tests

* passing weakref of estimator

* fix unit test

* fix test

* fix pylint

* fix test

* fix pylint

* move default metric logic

* combine nightly tests

* [MXNET-1396][Fit-API] Update default handler logic (#14765)

* move to nightly for binaries

* update default handler

* fix pylint

* trigger ci

* trigger ci

* [Fit API] update estimator (#14849)

* address comments

* add comment

* check available context

* fix bug

* change cpu check

* [Fit-API] Adress PR comments (#14885)

* address comments

* update checkpoint

* test symbol save

* address comments

* add resume

* update doc and resume checkpoint

* update docs

* trigger ci

* trigger ci
szha pushed a commit that referenced this pull request May 20, 2019
* improve event handlers

* update tests

* passing weakref of estimator

* fix unit test

* fix test

* fix pylint

* fix test

* fix pylint

* move default metric logic

* combine nightly tests
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* [MXNet-1334][Fit API]base class for estimator and eventhandler (apache#14346)

* base class for estimator and eventhandler

* add license

* add event handlers

* fix pylint

* improve arg check

* fix pylint

* add unit tests

* Fixed issue where the estimator was printing beyond the dataset size … (apache#14464)

* Fixed issue where the estimator was printing beyond the dataset size for the last batch

* Added comments

* Nudge to CI

* [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (apache#14442)

* added estimator unittests

* add more tests for estimator

* added validation logic

* added error handlers, unittests

* improve val stats

* fix pylint

* fix pylint

* update unit test

* fix tests

* fix tests

* updated metrics, val logic

* trigger ci

* trigger ci

* update metric, batch_fn error handler

* update context logic, add default metric

* [MXNet-1340][Fit API]Update train stats (apache#14494)

* add train history

* update history

* update test

* avoid calling empty methods

* remove train history object

* fix pylint

* add unit test

* fix test

* update categorize handlers

* [MXNet-1375][Fit API]Added RNN integration test for fit() API (apache#14547)

* Added RNN integration test for fit() API

* Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports

* CPU test doesn't require nvidiadocker container

* Modified the structure by removing the redundant code

* [MXNet-1343][Fit API]Add CNN integration test for fit() API (apache#14405)

* added cnn intg tests for fit api

* updated cnn intg tests

* added functions for nightly test

* updated runtime_function

* updated intg tests

* updated init, datapath, refs

* added validation data

* update cpu test

* refactor code

* updated context

* [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (apache#14587)

* Retrieve Batch size and Logging verbose support for Gluon fit() API

* NIT changes

* Addressed review comments: shifted the batch size code to a separate method, sentence correction

* Modified unittest

* removed redundant parameter

* Resolve CI test failure

* only support DataLoader for now, future PRs will include DataIter to DataLoader converter

* Get the number of samples from shape attribute instead of length due to low space complexity

* Simplified batch size retrieval code

* removed batch_size parameter from fit() method and fixed the tests

* Verbose exception handling

* Assigning constant to a verbose

* Modified exception message

* Resolved undefined class reference

* Addressed review comments: Modified verbose level names, docs, variable names

* Update estimator.py

* move estimator to contrib (apache#14633)

* move to gluon contrib (apache#14635)

* [Fit API] improve event handlers (apache#14685)

* improve event handlers

* update tests

* passing weakref of estimator

* fix unit test

* fix test

* fix pylint

* fix test

* fix pylint

* move default metric logic

* combine nightly tests

* [MXNET-1396][Fit-API] Update default handler logic (apache#14765)

* move to nightly for binaries

* update default handler

* fix pylint

* trigger ci

* trigger ci

* [Fit API] update estimator (apache#14849)

* address comments

* add comment

* check available context

* fix bug

* change cpu check

* [Fit-API] Adress PR comments (apache#14885)

* address comments

* update checkpoint

* test symbol save

* address comments

* add resume

* update doc and resume checkpoint

* update docs

* trigger ci

* trigger ci
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* improve event handlers

* update tests

* passing weakref of estimator

* fix unit test

* fix test

* fix pylint

* fix test

* fix pylint

* move default metric logic

* combine nightly tests
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Gluon pr-awaiting-review PR is waiting for code review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants