Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Extend estimator.evaluate() to support event handlers (#16971)
Browse files Browse the repository at this point in the history

* fix unittest failures for the new api interface

* Add comments in the code for readability

* Remove unused argument val_metrics

* merge changes with the master branch

* fix some regression errors

* fix bugs introduced in the merging phase
  • Loading branch information
liuzh47 authored and leezu committed Dec 10, 2019
1 parent 60f77f5 commit 0c17ddd
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 100 deletions.
105 changes: 74 additions & 31 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ class Estimator(object):
The model used for training.
loss : gluon.loss.Loss
Loss (objective) function to calculate during training.
metrics : EvalMetric or list of EvalMetric
Metrics for evaluating models.
train_metrics : EvalMetric or list of EvalMetric
Training metrics for evaluating models on training dataset.
val_metrics : EvalMetric or list of EvalMetric
Validation metrics for evaluating models on validation dataset.
initializer : Initializer
Initializer to initialize the network.
trainer : Trainer
Expand Down Expand Up @@ -105,15 +107,17 @@ class Estimator(object):

def __init__(self, net,
loss,
metrics=None,
train_metrics=None,
val_metrics=None,
initializer=None,
trainer=None,
context=None,
evaluation_loss=None,
eval_net=None):
self.net = net
self.loss = self._check_loss(loss)
self._train_metrics = _check_metrics(metrics)
self._train_metrics = _check_metrics(train_metrics)
self._val_metrics = _check_metrics(val_metrics)
self._add_default_training_metrics()
self._add_validation_metrics()
self.evaluation_loss = self.loss
Expand Down Expand Up @@ -226,13 +230,21 @@ def _add_default_training_metrics(self):
self._train_metrics.append(metric_loss(loss_name))

for metric in self._train_metrics:
metric.name = "training " + metric.name
# add training prefix to the metric name
# it is useful for event handlers to distinguish them from validation metrics
metric.name = 'training ' + metric.name

def _add_validation_metrics(self):
self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]
if not self._val_metrics:
self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]

for metric in self._val_metrics:
metric.name = "validation " + metric.name
# add validation prefix to the metric name
# it is useful for event handlers to distinguish them from training metrics
if 'training' in metric.name:
metric.name = metric.name.replace('training', 'validation')
else:
metric.name = 'validation ' + metric.name

@property
def train_metrics(self):
Expand All @@ -244,33 +256,26 @@ def val_metrics(self):

def evaluate_batch(self,
val_batch,
val_metrics,
batch_axis=0):
"""Evaluate model on a batch of validation data.
Parameters
----------
val_batch : tuple
Data and label of a batch from the validation data loader.
val_metrics : EvalMetric or list of EvalMetrics
Metrics to update validation result.
batch_axis : int, default 0
Batch axis to split the validation data into devices.
"""
data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
pred = [self.eval_net(x) for x in data]
loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
for metric in val_metrics:
if isinstance(metric, metric_loss):
metric.update(0, loss)
else:
metric.update(label, pred)

return data, label, pred, loss

def evaluate(self,
             val_data,
             batch_axis=0,
             event_handlers=None):
    """Evaluate model on validation data.

    This function calls :py:func:`evaluate_batch` on each of the batches from the
    validation data loader. Thus, for custom use cases, it's possible to inherit the
    estimator class and override :py:func:`evaluate_batch`.

    Parameters
    ----------
    val_data : DataLoader
        Validation data loader with data and labels.
    batch_axis : int, default 0
        Batch axis to split the validation data into devices.
    event_handlers : EventHandler or list of EventHandler
        List of :py:class:`EventHandlers` to apply during validation. Besides
        event handlers specified here, a default MetricHandler and a LoggingHandler
        will be added if not specified explicitly.
    """
    if not isinstance(val_data, DataLoader):
        raise ValueError("Estimator only supports input as Gluon DataLoader. Alternatively, you "
                         "can transform your DataIter or any NDArray into Gluon DataLoader. "
                         "Refer to gluon.data.DataLoader")

    # start from a clean slate so results do not accumulate across calls
    for metric in self.val_metrics:
        metric.reset()

    event_handlers = self._prepare_default_validation_handlers(event_handlers)

    # only epoch/batch begin/end events fire during evaluation
    _, epoch_begin, batch_begin, batch_end, \
        epoch_end, _ = self._categorize_handlers(event_handlers)

    estimator_ref = self

    for handler in epoch_begin:
        handler.epoch_begin(estimator_ref)

    for batch in val_data:
        for handler in batch_begin:
            handler.batch_begin(estimator_ref, batch=batch)

        _, label, pred, loss = self.evaluate_batch(batch, batch_axis)

        # handlers (e.g. MetricHandler) update the validation metrics here
        for handler in batch_end:
            handler.batch_end(estimator_ref, batch=batch,
                              pred=pred, label=label, loss=loss)

    for handler in epoch_end:
        handler.epoch_end(estimator_ref)

def fit_batch(self, train_batch, batch_axis=0):
"""Trains the model on a batch of training data.
Expand Down Expand Up @@ -441,23 +467,17 @@ def _prepare_default_handlers(self, val_data, event_handlers):
added_default_handlers.append(GradientUpdateHandler())

if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))
added_default_handlers.append(MetricHandler(metrics=self.train_metrics))

if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
# no validation handler
if val_data:
val_metrics = self.val_metrics
# add default validation handler if validation data found
added_default_handlers.append(ValidationHandler(val_data=val_data,
eval_fn=self.evaluate,
val_metrics=val_metrics))
else:
# set validation metrics to None if no validation data and no validation handler
val_metrics = []
eval_fn=self.evaluate))

if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
added_default_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
val_metrics=val_metrics))
added_default_handlers.append(LoggingHandler(metrics=self.train_metrics))

# if there is a mix of user defined event handlers and default event handlers
# they should have the same set of metrics
Expand All @@ -474,6 +494,29 @@ def _prepare_default_handlers(self, val_data, event_handlers):
event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
return event_handlers

def _prepare_default_validation_handlers(self, event_handlers):
    """Return the validation handler list, completed with defaults.

    Guarantees that a MetricHandler and a LoggingHandler driven by the
    validation metrics are present, verifies that user handlers refer to
    known validation metrics when mixed with defaults, and returns the
    handlers sorted by priority.
    """
    event_handlers = _check_event_handlers(event_handlers)
    defaults = []

    # supply a metric handler and a logging handler for validation
    # unless the user already registered their own
    if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
        defaults.append(MetricHandler(metrics=self.val_metrics))
    if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
        defaults.append(LoggingHandler(metrics=self.val_metrics))

    # metric references only need cross-checking when user handlers and
    # default handlers are mixed together
    should_check_refs = bool(event_handlers) and bool(defaults)
    event_handlers += defaults

    if should_check_refs:
        valid_metrics = set(self.val_metrics)
        for handler in event_handlers:
            _check_handler_metric_ref(handler, valid_metrics)

    event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
    return event_handlers

def _categorize_handlers(self, event_handlers):
"""
categorize handlers into 6 event lists to avoid calling empty methods
Expand Down
59 changes: 26 additions & 33 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,28 +128,28 @@ class MetricHandler(EpochBegin, BatchEnd):
Parameters
----------
train_metrics : List of EvalMetrics
Training metrics to be updated at batch end.
metrics : List of EvalMetrics
Metrics to be updated at batch end.
priority : scalar
Priority level of the MetricHandler. Priority level is sorted in ascending
order. The lower the number is, the higher priority level the handler is.
"""

def __init__(self, metrics, priority=-1000):
    """Create a MetricHandler.

    Parameters
    ----------
    metrics : EvalMetric or list of EvalMetrics
        Metrics to be reset at epoch begin and updated at batch end.
    priority : scalar, default -1000
        Priority level of the MetricHandler. Priority level is sorted in
        ascending order. The lower the number is, the higher priority
        level the handler is.
    """
    self.metrics = _check_metrics(metrics)
    # order to be called among all callbacks
    # metrics need to be calculated before other callbacks can access them
    self.priority = priority

def epoch_begin(self, estimator, *args, **kwargs):
    # start every epoch from a clean slate so values don't accumulate
    # across epochs
    for metric in self.metrics:
        metric.reset()

def batch_end(self, estimator, *args, **kwargs):
pred = kwargs['pred']
label = kwargs['label']
loss = kwargs['loss']
for metric in self.train_metrics:
for metric in self.metrics:
if isinstance(metric, metric_loss):
# metric wrapper for loss values
metric.update(0, loss)
Expand All @@ -171,8 +171,6 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
eval_fn : function
A function defines how to run evaluation and
calculate loss and metrics.
val_metrics : List of EvalMetrics
Validation metrics to be updated.
epoch_period : int, default 1
How often to run validation at epoch end, by default
:py:class:`ValidationHandler` validate every epoch.
Expand All @@ -188,15 +186,13 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
def __init__(self,
val_data,
eval_fn,
val_metrics=None,
epoch_period=1,
batch_period=None,
priority=-1000):
self.val_data = val_data
self.eval_fn = eval_fn
self.epoch_period = epoch_period
self.batch_period = batch_period
self.val_metrics = _check_metrics(val_metrics)
self.current_batch = 0
self.current_epoch = 0
# order to be called among all callbacks
Expand All @@ -211,20 +207,12 @@ def train_begin(self, estimator, *args, **kwargs):
def batch_end(self, estimator, *args, **kwargs):
    """Run validation every ``batch_period`` batches, when configured.

    Logging of the validation results is handled by the event handlers
    attached to ``eval_fn`` itself, not here.
    """
    self.current_batch += 1
    if self.batch_period and self.current_batch % self.batch_period == 0:
        self.eval_fn(val_data=self.val_data)

def epoch_end(self, estimator, *args, **kwargs):
    """Run validation every ``epoch_period`` epochs, when configured."""
    self.current_epoch += 1
    if self.epoch_period and self.current_epoch % self.epoch_period == 0:
        self.eval_fn(val_data=self.val_data)


class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
Expand All @@ -239,32 +227,29 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
Logging interval during training.
log_interval='epoch': display metrics every epoch
log_interval=integer k: display metrics every interval of k batches
train_metrics : list of EvalMetrics
Training metrics to be logged, logged at batch end, epoch end, train end.
val_metrics : list of EvalMetrics
Validation metrics to be logged, logged at epoch end, train end.
metrics : list of EvalMetrics
Metrics to be logged, logged at batch end, epoch end, train end.
priority : scalar, default np.Inf
Priority level of the LoggingHandler. Priority level is sorted in
ascending order. The lower the number is, the higher priority level the
handler is.
"""

def __init__(self, log_interval='epoch',
             metrics=None,
             priority=np.Inf):
    """Create a LoggingHandler.

    Parameters
    ----------
    log_interval : int or str, default 'epoch'
        Logging interval during training.
        log_interval='epoch': display metrics every epoch
        log_interval=integer k: display metrics every interval of k batches
    metrics : list of EvalMetrics
        Metrics to be logged, logged at batch end, epoch end, train end.
    priority : scalar, default np.Inf
        Priority level of the LoggingHandler. Priority level is sorted in
        ascending order. The lower the number is, the higher priority
        level the handler is.

    Raises
    ------
    ValueError
        If ``log_interval`` is neither an integer nor the string 'epoch'.
    """
    super(LoggingHandler, self).__init__()
    if not isinstance(log_interval, int) and log_interval != 'epoch':
        raise ValueError("log_interval must be either an integer or string 'epoch'")
    self.metrics = _check_metrics(metrics)
    self.batch_index = 0
    self.current_epoch = 0
    self.processed_samples = 0
    # logging handler need to be called at last to make sure all states are updated
    # it will also shut down logging at train end
    self.priority = priority
    self.log_interval = log_interval
    self.log_interval_time = 0

def train_begin(self, estimator, *args, **kwargs):
self.train_start = time.time()
Expand All @@ -288,7 +273,7 @@ def train_end(self, estimator, *args, **kwargs):
train_time = time.time() - self.train_start
msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch)
# log every result in train stats including train/validation loss & metrics
for metric in self.train_metrics + self.val_metrics:
for metric in self.metrics:
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)
estimator.logger.info(msg.rstrip(', '))
Expand All @@ -307,7 +292,7 @@ def batch_end(self, estimator, *args, **kwargs):
if self.batch_index % self.log_interval == 0:
msg += 'time/interval: %.3fs ' % self.log_interval_time
self.log_interval_time = 0
for metric in self.train_metrics:
for metric in self.metrics:
# only log current training loss & metric after each interval
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)
Expand All @@ -316,15 +301,23 @@ def batch_end(self, estimator, *args, **kwargs):

def epoch_begin(self, estimator, *args, **kwargs):
    if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
        # the estimator prefixes metric names with 'training '/'validation ';
        # use that name hack to detect which phase is being logged
        is_training = any('training' in metric.name for metric in self.metrics)
        self.epoch_start = time.time()
        if is_training:
            estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
                                  self.current_epoch, estimator.trainer.learning_rate)
        else:
            estimator.logger.info("Validation Begin")

def epoch_end(self, estimator, *args, **kwargs):
    if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
        epoch_time = time.time() - self.epoch_start
        msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
        # log every metric this handler was configured with
        for monitor in self.metrics:
            name, value = monitor.get()
            msg += '%s: %.4f, ' % (name, value)
        estimator.logger.info(msg.rstrip(', '))
Expand Down
Loading

0 comments on commit 0c17ddd

Please sign in to comment.