
Extend estimator.evaluate() to support event handlers #16971

Merged
merged 7 commits into from
Dec 10, 2019
105 changes: 74 additions & 31 deletions python/mxnet/gluon/contrib/estimator/estimator.py
@@ -51,8 +51,10 @@ class Estimator(object):
The model used for training.
loss : gluon.loss.Loss
Loss (objective) function to calculate during training.
metrics : EvalMetric or list of EvalMetric
Metrics for evaluating models.
train_metrics : EvalMetric or list of EvalMetric
Training metrics for evaluating models on training dataset.
val_metrics : EvalMetric or list of EvalMetric
Validation metrics for evaluating models on validation dataset.
initializer : Initializer
Initializer to initialize the network.
trainer : Trainer
@@ -105,15 +107,17 @@ class Estimator(object):

def __init__(self, net,
loss,
metrics=None,
train_metrics=None,
val_metrics=None,
initializer=None,
trainer=None,
context=None,
evaluation_loss=None,
eval_net=None):
self.net = net
self.loss = self._check_loss(loss)
self._train_metrics = _check_metrics(metrics)
self._train_metrics = _check_metrics(train_metrics)
self._val_metrics = _check_metrics(val_metrics)
self._add_default_training_metrics()
self._add_validation_metrics()
self.evaluation_loss = self.loss
@@ -226,13 +230,21 @@ def _add_default_training_metrics(self):
self._train_metrics.append(metric_loss(loss_name))

for metric in self._train_metrics:
metric.name = "training " + metric.name
# add training prefix to the metric name
# it is useful for event handlers to distinguish them from validation metrics
metric.name = 'training ' + metric.name

def _add_validation_metrics(self):
self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]
if not self._val_metrics:
self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]

for metric in self._val_metrics:
metric.name = "validation " + metric.name
# add validation prefix to the metric name
# it is useful for event handlers to distinguish them from training metrics
if 'training' in metric.name:
metric.name = metric.name.replace('training', 'validation')
else:
metric.name = 'validation ' + metric.name
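
To make the renaming above concrete, here is a standalone sketch of the prefixing logic (pure Python; _FakeMetric is a made-up stand-in for EvalMetric, only the name matters):

import copy

class _FakeMetric:
    """Hypothetical stand-in for mx.metric.EvalMetric."""
    def __init__(self, name):
        self.name = name

train_metrics = [_FakeMetric('training accuracy'), _FakeMetric('mae')]
val_metrics = [copy.deepcopy(m) for m in train_metrics]
for m in val_metrics:
    if 'training' in m.name:
        m.name = m.name.replace('training', 'validation')
    else:
        m.name = 'validation ' + m.name
print([m.name for m in val_metrics])  # ['validation accuracy', 'validation mae']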

@property
def train_metrics(self):
@@ -244,33 +256,26 @@ def val_metrics(self):

def evaluate_batch(self,
val_batch,
val_metrics,
batch_axis=0):
"""Evaluate model on a batch of validation data.

Parameters
----------
val_batch : tuple
Data and label of a batch from the validation data loader.
val_metrics : EvalMetric or list of EvalMetrics
Metrics to update validation result.
batch_axis : int, default 0
Batch axis to split the validation data into devices.
"""
data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
pred = [self.eval_net(x) for x in data]
loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
for metric in val_metrics:
if isinstance(metric, metric_loss):
metric.update(0, loss)
else:
metric.update(label, pred)

return data, label, pred, loss
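
The (data, label, pred, loss) tuple returned above is what batch-end handlers receive during evaluation; a sketch of a custom handler consuming it (PredLossLogger is illustrative, not part of the API):

from mxnet.gluon.contrib.estimator.event_handler import BatchEnd

class PredLossLogger(BatchEnd):
    """Illustrative handler: logs the mean loss of each validation batch."""
    def batch_end(self, estimator, *args, **kwargs):
        loss = kwargs['loss']  # list of per-device loss NDArrays
        mean = sum(l.mean().asscalar() for l in loss) / len(loss)
        estimator.logger.info('validation batch loss: %.4f', mean)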

def evaluate(self,
val_data,
val_metrics,
batch_axis=0):
batch_axis=0,
event_handlers=None):
"""Evaluate model on validation data.

This function calls :py:func:`evaluate_batch` on each of the batches from the
@@ -281,21 +286,42 @@ def evaluate(self,
----------
val_data : DataLoader
Validation data loader with data and labels.
val_metrics : EvalMetric or list of EvalMetrics
Metrics to update validation result.
batch_axis : int, default 0
Batch axis to split the validation data into devices.
event_handlers : EventHandler or list of EventHandler
List of :py:class:`EventHandler` objects to apply during validation. Besides
the event handlers specified here, a default MetricHandler and a LoggingHandler
will be added if not specified explicitly.
"""
if not isinstance(val_data, DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
"Refer to gluon.data.DataLoader")

for metric in val_metrics:
for metric in self.val_metrics:
metric.reset()

event_handlers = self._prepare_default_validation_handlers(event_handlers)

_, epoch_begin, batch_begin, batch_end, \
epoch_end, _ = self._categorize_handlers(event_handlers)

estimator_ref = self

for handler in epoch_begin:
handler.epoch_begin(estimator_ref)

for _, batch in enumerate(val_data):
self.evaluate_batch(batch, val_metrics, batch_axis)
for handler in batch_begin:
handler.batch_begin(estimator_ref, batch=batch)

_, label, pred, loss = self.evaluate_batch(batch, batch_axis)

for handler in batch_end:
handler.batch_end(estimator_ref, batch=batch, pred=pred, label=label, loss=loss)

for handler in epoch_end:
handler.epoch_end(estimator_ref)
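
A minimal usage sketch of the extended evaluate() (assumes `net` and `val_loader` already exist; the explicit handler is optional, since defaults are added when none are given):

import mxnet as mx
from mxnet import gluon
from mxnet.gluon.contrib.estimator import Estimator
from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler

est = Estimator(net=net,
                loss=gluon.loss.SoftmaxCrossEntropyLoss(),
                train_metrics=mx.metric.Accuracy())
# metrics are no longer passed to evaluate(); it reads est.val_metrics,
# and default Metric/Logging handlers are attached automatically
est.evaluate(val_data=val_loader,
             event_handlers=[LoggingHandler(metrics=est.val_metrics)])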

def fit_batch(self, train_batch, batch_axis=0):
"""Trains the model on a batch of training data.
@@ -441,23 +467,17 @@ def _prepare_default_handlers(self, val_data, event_handlers):
added_default_handlers.append(GradientUpdateHandler())

if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))
added_default_handlers.append(MetricHandler(metrics=self.train_metrics))

if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
# no validation handler
if val_data:
val_metrics = self.val_metrics
# add default validation handler if validation data found
added_default_handlers.append(ValidationHandler(val_data=val_data,
eval_fn=self.evaluate,
val_metrics=val_metrics))
else:
# set validation metrics to None if no validation data and no validation handler
val_metrics = []
eval_fn=self.evaluate))

if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
added_default_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
val_metrics=val_metrics))
added_default_handlers.append(LoggingHandler(metrics=self.train_metrics))

# if there is a mix of user defined event handlers and default event handlers
# they should have the same set of metrics
@@ -474,6 +494,29 @@ def _prepare_default_handlers(self, val_data, event_handlers):
event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
return event_handlers

def _prepare_default_validation_handlers(self, event_handlers):
event_handlers = _check_event_handlers(event_handlers)
added_default_handlers = []

# add default logging handler and metric handler for validation
if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
added_default_handlers.append(MetricHandler(metrics=self.val_metrics))

if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
added_default_handlers.append(LoggingHandler(metrics=self.val_metrics))

mixing_handlers = event_handlers and added_default_handlers
event_handlers.extend(added_default_handlers)

# check if all handlers refer to well-defined validation metrics
if mixing_handlers:
known_metrics = set(self.val_metrics)
for handler in event_handlers:
_check_handler_metric_ref(handler, known_metrics)

event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
return event_handlers
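
The final sort runs handlers in ascending priority; a standalone sketch of the ordering (the _H class is illustrative):

class _H:
    def __init__(self, name, priority):
        self.name, self.priority = name, priority

# MetricHandler defaults to -1000, LoggingHandler to np.Inf,
# so metrics are updated first and logging runs last
handlers = [_H('logging', float('inf')), _H('metric', -1000), _H('custom', 0)]
handlers.sort(key=lambda h: getattr(h, 'priority', 0))
print([h.name for h in handlers])  # ['metric', 'custom', 'logging']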

def _categorize_handlers(self, event_handlers):
"""
Categorize handlers into 6 event lists to avoid calling empty methods
59 changes: 26 additions & 33 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -128,28 +128,28 @@ class MetricHandler(EpochBegin, BatchEnd):

Parameters
----------
train_metrics : List of EvalMetrics
Training metrics to be updated at batch end.
metrics : List of EvalMetrics
Metrics to be updated at batch end.
priority : scalar
Priority level of the MetricHandler. Priority levels are sorted in ascending
order: the lower the number, the higher the handler's priority.
"""

def __init__(self, train_metrics, priority=-1000):
self.train_metrics = _check_metrics(train_metrics)
def __init__(self, metrics, priority=-1000):
self.metrics = _check_metrics(metrics)
# order to be called among all callbacks
# metrics need to be calculated before other callbacks can access them
self.priority = priority

def epoch_begin(self, estimator, *args, **kwargs):
for metric in self.train_metrics:
for metric in self.metrics:
metric.reset()

def batch_end(self, estimator, *args, **kwargs):
pred = kwargs['pred']
label = kwargs['label']
loss = kwargs['loss']
for metric in self.train_metrics:
for metric in self.metrics:
if isinstance(metric, metric_loss):
# metric wrapper for loss values
metric.update(0, loss)
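
For context on the update(0, loss) call above: loss-wrapper metrics ignore the label argument and just average the loss arrays. A small sketch, assuming mx.metric.Loss semantics:

import mxnet as mx

loss_metric = mx.metric.Loss('training loss')
# Loss takes (labels, preds) but only averages preds, so the label
# slot is a dummy 0 and the loss NDArrays go in the preds slot
loss_metric.update(0, [mx.nd.array([0.7, 0.3])])
print(loss_metric.get())  # ('training loss', 0.5)
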
@@ -171,8 +171,6 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
eval_fn : function
A function defines how to run evaluation and
calculate loss and metrics.
val_metrics : List of EvalMetrics
Validation metrics to be updated.
epoch_period : int, default 1
How often to run validation at epoch end; by default,
:py:class:`ValidationHandler` validates every epoch.
@@ -188,15 +186,13 @@ def __init__(self,
def __init__(self,
val_data,
eval_fn,
val_metrics=None,
epoch_period=1,
batch_period=None,
priority=-1000):
self.val_data = val_data
self.eval_fn = eval_fn
self.epoch_period = epoch_period
self.batch_period = batch_period
self.val_metrics = _check_metrics(val_metrics)
self.current_batch = 0
self.current_epoch = 0
# order to be called among all callbacks
@@ -211,20 +207,12 @@ def train_begin(self, estimator, *args, **kwargs):
def batch_end(self, estimator, *args, **kwargs):
self.current_batch += 1
if self.batch_period and self.current_batch % self.batch_period == 0:
self.eval_fn(val_data=self.val_data,
val_metrics=self.val_metrics)
msg = '[Epoch %d] ValidationHandler: %d batches reached, ' \
% (self.current_epoch, self.current_batch)
for monitor in self.val_metrics:
name, value = monitor.get()
msg += '%s: %.4f, ' % (name, value)
estimator.logger.info(msg.rstrip(','))
self.eval_fn(val_data=self.val_data)

def epoch_end(self, estimator, *args, **kwargs):
self.current_epoch += 1
if self.epoch_period and self.current_epoch % self.epoch_period == 0:
self.eval_fn(val_data=self.val_data,
val_metrics=self.val_metrics)
self.eval_fn(val_data=self.val_data)
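
With val_metrics dropped from the constructor, wiring up periodic validation only needs the data and the eval function; a sketch (`est` and `val_loader` are assumed to exist):

from mxnet.gluon.contrib.estimator.event_handler import ValidationHandler

# run est.evaluate(val_data=...) every 2 epochs and every 500 batches;
# the metrics now live on the estimator (est.val_metrics), not the handler
val_handler = ValidationHandler(val_data=val_loader,
                                eval_fn=est.evaluate,
                                epoch_period=2,
                                batch_period=500)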


class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
@@ -239,32 +227,29 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
Logging interval during training.
log_interval='epoch': display metrics every epoch
log_interval=integer k: display metrics every k batches
train_metrics : list of EvalMetrics
Training metrics to be logged, logged at batch end, epoch end, train end.
val_metrics : list of EvalMetrics
Validation metrics to be logged, logged at epoch end, train end.
metrics : list of EvalMetrics
Metrics to be logged, logged at batch end, epoch end, train end.
priority : scalar, default np.Inf
Priority level of the LoggingHandler. Priority levels are sorted in
ascending order: the lower the number, the higher the handler's priority.
"""

def __init__(self, log_interval='epoch',
train_metrics=None,
val_metrics=None,
metrics=None,
priority=np.Inf):
super(LoggingHandler, self).__init__()
if not isinstance(log_interval, int) and log_interval != 'epoch':
raise ValueError("log_interval must be either an integer or string 'epoch'")
self.train_metrics = _check_metrics(train_metrics)
self.val_metrics = _check_metrics(val_metrics)
self.metrics = _check_metrics(metrics)
self.batch_index = 0
self.current_epoch = 0
self.processed_samples = 0
# the logging handler needs to be called last to make sure all states are updated
# it will also shut down logging at train end
self.priority = priority
self.log_interval = log_interval
self.log_interval_time = 0

def train_begin(self, estimator, *args, **kwargs):
self.train_start = time.time()
@@ -288,7 +273,7 @@ def train_end(self, estimator, *args, **kwargs):
train_time = time.time() - self.train_start
msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch)
# log every result in train stats including train/validation loss & metrics
for metric in self.train_metrics + self.val_metrics:
for metric in self.metrics:
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)
estimator.logger.info(msg.rstrip(', '))
@@ -307,7 +292,7 @@ def batch_end(self, estimator, *args, **kwargs):
if self.batch_index % self.log_interval == 0:
msg += 'time/interval: %.3fs ' % self.log_interval_time
self.log_interval_time = 0
for metric in self.train_metrics:
for metric in self.metrics:
# only log current training loss & metric after each interval
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)
@@ -316,15 +301,23 @@ def batch_end(self, estimator, *args, **kwargs):

def epoch_begin(self, estimator, *args, **kwargs):
if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
is_training = False
# use the name hack set up in __init__() of the Estimator class
for metric in self.metrics:
if 'training' in metric.name:
is_training = True
self.epoch_start = time.time()
estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
self.current_epoch, estimator.trainer.learning_rate)
if is_training:
estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
self.current_epoch, estimator.trainer.learning_rate)
else:
estimator.logger.info("Validation Begin")

def epoch_end(self, estimator, *args, **kwargs):
if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
epoch_time = time.time() - self.epoch_start
msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
for monitor in self.train_metrics + self.val_metrics:
for monitor in self.metrics:
name, value = monitor.get()
msg += '%s: %.4f, ' % (name, value)
estimator.logger.info(msg.rstrip(', '))