
[MXNet-1334][Fit API]base class for estimator and eventhandler #14346

Merged: 7 commits merged into apache:fit-api on Mar 16, 2019

Conversation

@roywei (Member) commented on Mar 6, 2019

Description

This is the first PR for the Gluon Fit API proposed here: Gluon Fit API tech design

It adds base classes for the estimator and event handlers, plus an example. We would like to merge this into the fit-api branch so other contributors can work on top of it before making it available on the master branch.

JIRA epic can be found here: https://issues.apache.org/jira/browse/MXNET-1333
We are separating this into a few PRs, all tracked by JIRA issues under the epic.

We will open follow up PRs for the following tasks:

  1. CNN and RNN examples using the fit API (see [MXNET-1335][fit api] Text Sentiment Classification examples using Gluon fit() API #14350)
  2. User experience improvements on initialization, context management, and hybridization
  3. More event handlers, such as checkpointing and early stopping
  4. Multi-output support for multi-task and SSD use cases and examples
  5. Multi-trainer support for encoder-decoder and neural machine translation examples
  6. Unit tests and integration tests
  7. Documentation

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-1334]
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  1. Base class for estimator with fit method
  2. Base class for event handlers
  3. Simple CNN example on the MNIST dataset (a usage sketch of the API follows below)
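
For orientation, here is a minimal usage sketch assembled from the example and unit-test snippets quoted later in this conversation. The net/train_data stand-ins are made up for illustration, and the mxnet.gluon.estimator module path reflects this branch only (the code was later moved to gluon.contrib in follow-up PRs):

import numpy as np
from mxnet import gluon, metric
from mxnet.gluon.estimator import estimator, event_handler  # module path on the fit-api branch

# tiny stand-in network and dataset, just to exercise the API
net = gluon.nn.Dense(10)
net.initialize()
x = np.random.rand(100, 28 * 28).astype('float32')
y = np.random.randint(0, 10, (100,)).astype('float32')
train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(x, y), batch_size=16)

ce_loss = gluon.loss.SoftmaxCrossEntropyLoss()
acc = metric.Accuracy()

est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
handlers = [event_handler.LoggingHandler(est, file_name='train.log', file_location='./')]
est.fit(train_data, event_handlers=handlers, epochs=2)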

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@ankkhedia

@roywei requested a review from szha as a code owner on March 6, 2019, 17:52
@vandanavk (Contributor)

@mxnet-label-bot add [Gluon, pr-work-in-progress]

metric.reset()

for i, batch in enumerate(train_data):
data, label = self._batch_fn(batch, self.context)
@ptrendx (Member): Hi @roywei, in NVIDIA DALI we return already sliced data. Could you make _batch_fn optional?

@roywei (Member, Author): Hi @ptrendx, sounds good! Any examples on multi-GPU? I would love to check it out.
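
A sketch of what the request above could look like: keep the split_and_load-based slicing as the default, and let a pipeline such as DALI, which already returns sliced data, supply its own function. This is only an illustration of the suggestion, not code from this PR; the batch_fn argument shown is hypothetical.

from mxnet import gluon

def default_batch_fn(batch, ctx_list):
    # the split-and-load behaviour the Estimator currently applies to each batch
    data = gluon.utils.split_and_load(batch[0], ctx_list=ctx_list, batch_axis=0)
    label = gluon.utils.split_and_load(batch[1], ctx_list=ctx_list, batch_axis=0)
    return data, label

def passthrough_batch_fn(batch, ctx_list):
    # for pipelines (e.g. DALI) that already return per-device data and labels
    return batch

# hypothetical call once the argument is made optional:
# est.fit(train_data, epochs=1, batch_fn=passthrough_batch_fn)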

example/gluon/estimator_example/mnist_cnn.py: 3 outdated review threads, resolved
self.loss = [loss]
else:
self.loss = loss or []
if isinstance(metrics, EvalMetric):
Member: This does not work for a list of metrics?

>>> class A:
...     def __init__(self):
...             pass
...
>>> A
<class __main__.A at 0x721d0fe88>
>>> x = A
>>> x = [A(), A()]
>>> isinstance(x, A)
False

@roywei (Member, Author): Yes, it accepts a single metric object or a list of metrics (handled in the else branch).

Member: Regardless of whether a single metric or a list of them is passed, we should have this validation.

@roywei (Member, Author): @nswamy Updated.
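
A minimal sketch of the validation discussed in this thread, assuming the constructor accepts either a single EvalMetric or a list of them (the helper name is illustrative, not the final API):

from mxnet.metric import EvalMetric

def check_metrics(metrics):
    # normalize to a list and verify every entry is an EvalMetric
    if isinstance(metrics, EvalMetric):
        metrics = [metrics]
    else:
        metrics = metrics or []
    if not all(isinstance(m, EvalMetric) for m in metrics):
        raise ValueError("metrics must be an EvalMetric or a list of EvalMetric, got %s" % str(metrics))
    return metrics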

python/mxnet/gluon/estimator/estimator.py: review thread resolved
self.context = [cpu()]

# initialize the network
if self.initializer:
Member: Suggest moving the initialization and the check for whether to initialize into one method.

Member: ?

label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
return data, label

def fit(self, train_data,
Member: Nit: train_data -> train_dataloader, etc. If valid_data is not passed, are you going to split the training data with some percentage? If not, I think we should.

@roywei (Member, Author): Currently a DataLoader does not support splitting into train and validation sets after it is created. Users can create that split outside the estimator using a random sampler; this is the same behavior as PyTorch. Keras allows it, but only because its fit method accepts numpy arrays directly, and that can be slow.
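
A minimal sketch of doing that split outside the estimator with an index-based sampler. The SubsetSampler helper is hypothetical (Gluon does not ship one); it only illustrates the random-sampler approach mentioned above:

import numpy as np
from mxnet import gluon

class SubsetSampler(gluon.data.Sampler):
    # samples from a fixed list of indices (hypothetical helper, not part of Gluon)
    def __init__(self, indices):
        self._indices = indices
    def __iter__(self):
        return iter(self._indices)
    def __len__(self):
        return len(self._indices)

def train_val_loaders(dataset, batch_size, val_ratio=0.1, seed=0):
    # split a Gluon Dataset into separate train/val DataLoaders before calling fit()
    indices = np.random.RandomState(seed).permutation(len(dataset))
    split = int(len(indices) * (1 - val_ratio))
    train_loader = gluon.data.DataLoader(dataset, batch_size=batch_size,
                                         sampler=SubsetSampler(indices[:split]))
    val_loader = gluon.data.DataLoader(dataset, batch_size=batch_size,
                                       sampler=SubsetSampler(indices[split:]))
    return train_loader, val_loader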

Member: val_data is missing from the signature? Also, it should not be optional if you expect the user to pass it.

Member: Please rename to train_dataloader.

batch_size=None,
event_handlers=None):

if not batch_size:
Member: Add a log.info() with the batch size you are using.

@roywei (Member, Author): Will be added in train_begin in the logging handler, together with the change to train_stats, in a follow-up PR.

self.logger.addHandler(filehandler)

def train_begin(self):
pass
Member: Might be a good idea to log all the hyperparameters and defaults used here.

@roywei (Member, Author): Will be added with the change to train_stats in a follow-up PR.

pass
# logger.info(opt)

def train_end(self):
Member: Output informational messages such as the number of epochs run, train_loss and valid_loss, and other metrics at the end.

@roywei (Member, Author): Will be added with the change to train_stats in a follow-up PR.

self.trainers = [trainers]
else:
self.trainers = trainers or []
if not self.trainers:
Contributor: Are we dealing with the multiple-trainer case here (e.g. multi-task classification)?

@roywei (Member, Author): Will be addressed in another PR (TODO item 4).

handler.batch_end()

for metric in self.metrics + self.loss_metrics:
self.train_stats['train_' + metric.name].append(metric.get()[1])
Contributor: Validation metrics would also be needed in train_stats. How are we dealing with that?

@roywei (Member, Author): It's tracked here and will be addressed in a follow-up PR.

@roywei (Member, Author) commented on Mar 12, 2019

@nswamy @ankkhedia I have addressed the comments and added checkpoint and early stopping handlers. Could you take another look? Thanks!

Remaining to-do items from your comments, to be addressed in follow-up PRs:

  1. Update train stats (change to an object instead of a dictionary) and initialization logic: MXNet-1340
  2. Validation logic: MXNet-1349
  3. Multi-trainer support: MXNet-1339

@roywei changed the title from "[MXNet-1334][WIP][Fit API]base class for estimator and eventhandler" to "[MXNet-1334][Fit API]base class for estimator and eventhandler" on Mar 13, 2019
@roywei (Member, Author) commented on Mar 14, 2019

@mxnet-label-bot update [Gluon, pr-awaiting-review]

@marcoabreu added the pr-awaiting-review label and removed the pr-work-in-progress label on Mar 14, 2019
elif isinstance(train_data, DataIter):
data, label = self._batch_fn(batch, self.context, is_iterator=True)
else:
raise ValueError("You are using a custom iteration, please also provide "
Contributor: Nit: custom iterator.

early_stopping = [event_handler.EarlyStoppingHandler(est, monitor,
patience=patience,
mode=mode)]
est.fit(test_data, event_handlers=early_stopping, epochs=1)
Contributor: No assertions here?

epoch = self._estimator.train_stats['epochs'][-1]
msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
for key in self._estimator.train_stats.keys():
if key.startswith('train_') or key.startswith('test_'):
Contributor: Key should be val_?
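
For illustration, the epoch-end message with the prefix suggested above, assuming validation metrics are stored in train_stats under val_-prefixed keys (names follow the snippet quoted here, but this is not code from the PR):

def epoch_end_message(train_stats, epoch_time):
    # build the end-of-epoch log line, picking up both training and validation metrics
    epoch = train_stats['epochs'][-1]
    msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
    for key, values in train_stats.items():
        if key.startswith('train_') or key.startswith('val_'):
            msg += '%s: %.4f ' % (key, values[-1])
    return msg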

if isinstance(metrics, EvalMetric):
self.metrics = [metrics]
else:
self.metrics = metrics or []
Member: Do we allow anything but EvalMetric? If not, please validate that each metric in the list is an EvalMetric.


est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
logging_handler = [event_handler.LoggingHandler(est, file_name=file_name, file_location=tmpdir)]
est.fit(test_data, event_handlers=logging_handler, epochs=1)
assert os.path.isfile(output_dir)
Member: Do you want to do a sanity check that the validation accuracy for the last epoch is logged (and also that the training process completed)?
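
A sketch of the kind of sanity check being suggested, assuming the handler writes the epoch summaries shown elsewhere in this PR to the log file; the exact log format and metric key are illustrative:

import os

def check_training_log(log_path, num_epochs):
    # the log file exists, training reached the final epoch, and a validation metric was written
    assert os.path.isfile(log_path)
    with open(log_path) as f:
        log = f.read()
    assert '[Epoch %d]' % (num_epochs - 1) in log  # assumes 0-based epoch numbering in the log
    assert 'val_accuracy' in log                   # hypothetical validation metric key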

batch_time = time.time() - self.batch_start
epoch = self._estimator.train_stats['epochs'][-1]
step = self._estimator.train_stats['step']
msg = '[Epoch %d] [Step %s] time/step: %.3fs ' % (epoch, step, batch_time)
Member: I don't think we use the term "step". Let's not introduce new terms; let's continue to use "batch". Look at the existing logging.

self._estimator.net.save_parameters(self.filepath)


class EarlyStoppingHandler(EventHandler):
Member: Can you please see if we can have an integration test for this?

@nswamy (Member) commented on Mar 16, 2019

I will merge this code since it's on a branch. Please make sure to address all comments in subsequent PRs before opening a PR to master.

@nswamy merged commit 41392fa into apache:fit-api on Mar 16, 2019
piyushghai pushed a commit to piyushghai/incubator-mxnet that referenced this pull request Apr 5, 2019
…e#14346)

* base class for estimator and eventhandler

* add license

* add event handlers

* fix pylint

* improve arg check

* fix pylint

* add unit tests
nswamy pushed a commit to nswamy/incubator-mxnet that referenced this pull request Apr 5, 2019
roywei added a commit to roywei/incubator-mxnet that referenced this pull request May 15, 2019
roywei added a commit to roywei/incubator-mxnet that referenced this pull request May 15, 2019
szha pushed a commit that referenced this pull request May 18, 2019
* [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346)

* base class for estimator and eventhandler

* add license

* add event handlers

* fix pylint

* improve arg check

* fix pylint

* add unit tests

* Fixed issue where the estimator was printing beyond the dataset size … (#14464)

* Fixed issue where the estimator was printing beyond the dataset size for the last batch

* Added comments

* Nudge to CI

* [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (#14442)

* added estimator unittests

* add more tests for estimator

* added validation logic

* added error handlers, unittests

* improve val stats

* fix pylint

* fix pylint

* update unit test

* fix tests

* fix tests

* updated metrics, val logic

* trigger ci

* trigger ci

* update metric, batch_fn error handler

* update context logic, add default metric

* [MXNet-1340][Fit API]Update train stats (#14494)

* add train history

* update history

* update test

* avoid calling empty methods

* remove train history object

* fix pylint

* add unit test

* fix test

* update categorize handlers

* [MXNet-1375][Fit API]Added RNN integration test for fit() API (#14547)

* Added RNN integration test for fit() API

* Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports

* CPU test doesn't require nvidiadocker container

* Modified the structure by removing the redundant code

* [MXNet-1343][Fit API]Add CNN integration test for fit() API (#14405)

* added cnn intg tests for fit api

* updated cnn intg tests

* added functions for nightly test

* updated runtime_function

* updated intg tests

* updated init, datapath, refs

* added validation data

* update cpu test

* refactor code

* updated context

* [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (#14587)

* Retrieve Batch size and Logging verbose support for Gluon fit() API

* NIT changes

* Addressed review comments: shifted the batch size code to a separate method, sentence correction

* Modified unittest

* removed redundant parameter

* Resolve CI test failure

* only support DataLoader for now, future PRs will include DataIter to DataLoader converter

* Get the number of samples from shape attribute instead of length due to low space complexity

* Simplified batch size retrieval code

* removed batch_size parameter from fit() method and fixed the tests

* Verbose exception handling

* Assigning constant to a verbose

* Modified exception message

* Resolved undefined class reference

* Addressed review comments: Modified verbose level names, docs, variable names

* Update estimator.py

* move estimator to contrib (#14633)

* move to gluon contrib (#14635)

* [Fit API] improve event handlers (#14685)

* improve event handlers

* update tests

* passing weakref of estimator

* fix unit test

* fix test

* fix pylint

* fix test

* fix pylint

* move default metric logic

* combine nightly tests

* [MXNET-1396][Fit-API] Update default handler logic (#14765)

* move to nightly for binaries

* update default handler

* fix pylint

* trigger ci

* trigger ci

* [Fit API] update estimator (#14849)

* address comments

* add comment

* check available context

* fix bug

* change cpu check

* [Fit-API] Adress PR comments (#14885)

* address comments

* update checkpoint

* test symbol save

* address comments

* add resume

* update doc and resume checkpoint

* update docs

* trigger ci

* trigger ci
szha pushed a commit that referenced this pull request May 20, 2019
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
Labels: Gluon, pr-awaiting-review

7 participants