Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Fixing SSL recovery, attempt 2 #584

Merged
merged 48 commits into from
Nov 18, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
cf51128
instructions
ant0nsc Sep 22, 2021
b6201b5
Moving loading time out into a callback
ant0nsc Oct 13, 2021
7b5997a
fixing timing callback
ant0nsc Oct 13, 2021
9d55e20
docu
ant0nsc Oct 13, 2021
4254ff2
docu
ant0nsc Oct 13, 2021
0a1fc26
progress bar
ant0nsc Oct 14, 2021
f190669
docu and cleanup
ant0nsc Oct 14, 2021
f90912d
tests
ant0nsc Oct 15, 2021
9bdbd7a
test cleanup
ant0nsc Oct 18, 2021
d02ba0b
test for timers
ant0nsc Oct 19, 2021
8cbc5f1
cleanup
ant0nsc Oct 19, 2021
144698a
tests for callback
ant0nsc Oct 19, 2021
6674b70
hyperparams logging
ant0nsc Oct 19, 2021
ebf8b25
flags
ant0nsc Oct 19, 2021
fd96667
Merge branch 'antonsc/submodule_doc' into antonsc/diagnostics
ant0nsc Oct 19, 2021
b11f6dd
submodule
ant0nsc Oct 19, 2021
8ac90f2
update all usage
ant0nsc Oct 19, 2021
547626e
fix
ant0nsc Oct 20, 2021
986fba8
cleanup
ant0nsc Oct 20, 2021
a823af0
callback save and load
ant0nsc Nov 2, 2021
17564f1
find_unused
ant0nsc Nov 2, 2021
367662f
Merge remote-tracking branch 'origin/main' into antonsc/diagnostics
ant0nsc Nov 2, 2021
6064bc5
remove submodule
ant0nsc Nov 2, 2021
88ed46c
storinglogger update
ant0nsc Nov 2, 2021
857026c
head_batchsize
ant0nsc Nov 2, 2021
a71a476
using submodule
ant0nsc Nov 2, 2021
69fe247
import fix
ant0nsc Nov 2, 2021
3ef5d5d
log_on_epoch
ant0nsc Nov 3, 2021
40b6d07
cleanup of metrics
ant0nsc Nov 3, 2021
8080544
changelog
ant0nsc Nov 3, 2021
8cbafe7
removing submodule
ant0nsc Nov 3, 2021
70771d3
fix import
ant0nsc Nov 3, 2021
c964d84
changelog
ant0nsc Nov 3, 2021
a432fe2
flake fix
ant0nsc Nov 3, 2021
4ee9e8e
Merge remote-tracking branch 'origin/antonsc/diagnostics' into antons…
ant0nsc Nov 3, 2021
d605335
fixed logging
ant0nsc Nov 3, 2021
e086757
Merge remote-tracking branch 'origin/main' into antonsc/recovery2
ant0nsc Nov 12, 2021
7f75ec3
changelog
ant0nsc Nov 12, 2021
c22362e
mypy
ant0nsc Nov 12, 2021
2da000b
test fix
ant0nsc Nov 13, 2021
16f9793
Merge branch 'main' into antonsc/recovery2
ant0nsc Nov 17, 2021
588dd01
PR comments
ant0nsc Nov 17, 2021
44b8d0e
fix
ant0nsc Nov 17, 2021
ea77b47
fix
ant0nsc Nov 17, 2021
a14b301
Merge remote-tracking branch 'origin/antonsc/pathfix' into antonsc/re…
ant0nsc Nov 18, 2021
3123ed4
Merge remote-tracking branch 'origin/main' into antonsc/recovery2
ant0nsc Nov 18, 2021
5703c1c
fix
ant0nsc Nov 18, 2021
4c4f4eb
PR comments
ant0nsc Nov 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
tests for callback
  • Loading branch information
ant0nsc committed Oct 19, 2021
commit 144698a6f3701c8a00843678dfcce7e330a6bcc0
13 changes: 3 additions & 10 deletions InnerEye/ML/lightning_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,10 @@ def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
self.module = pl_module

def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
self.train_timers.reset()
self.train_timers.epoch_start()

def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
self.val_timers.reset()
self.val_timers.epoch_start()
# In Lightning, the validation epoch is running "inside" the training. If we get here, it means that training
# is done for this epoch, even though the on_training_epoch hook has not yet been called.
self.train_timers.epoch_end()
Expand Down Expand Up @@ -327,7 +327,7 @@ def write_and_log_epoch_time(self, is_training: bool) -> None:
f"for data took {timers.total_load_time:0.2f} sec total.")
if timers.num_load_time_exceeded > 0 and timers.should_warn_in_this_epoch:
logging.warning("The dataloaders were not fast enough to always supply the next batch in less than "
f"{timers.max_item_load_time_seconds}sec.")
f"{timers.max_item_load_time_seconds:0.2f}sec.")
logging.warning(
f"In this epoch, {timers.num_load_time_exceeded} out of {timers.num_batches} batches exceeded the load "
f"time threshold. Total loading time for the slow batches was {timers.total_extra_load_time:0.2f}sec.")
Expand Down Expand Up @@ -364,13 +364,6 @@ def get_timers(self, is_training: bool) -> EpochTimers:
"""
return self.train_timers if is_training else self.val_timers

def reset_timers(self) -> None:
"""
Resets all timers and counters, for both the validation and the training epoch.
"""
self.train_timers.reset()
self.val_timers.reset()


class InnerEyeLightning(LightningModule):
"""
Expand Down
21 changes: 10 additions & 11 deletions InnerEye/ML/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,6 @@ class EpochTimers:
"""
Contains all information necessary to compute the IO metrics: Epoch times, batch times, loading times.
"""
epoch_start_time: float = 0.0
epoch_end_time: float = 0.0
batch_start_time: float = 0.0
num_load_time_warnings: int = 0
num_load_time_exceeded: int = 0
total_extra_load_time: float = 0.0
total_load_time: float = 0.0
num_batches: int = 0

def __init__(self,
max_item_load_time_seconds: float = 0.5,
Expand All @@ -106,9 +98,16 @@ def __init__(self,
self.max_load_time_warnings = max_load_time_warnings
self.max_load_time_epochs = max_load_time_epochs
self.load_time_warning_epochs: Set[int] = set()
self.reset()

def reset(self) -> None:
self.epoch_start_time: float = 0.0
self.epoch_end_time: float = 0.0
self.batch_start_time: float = 0.0
self.num_load_time_warnings: int = 0
self.num_load_time_exceeded: int = 0
self.total_extra_load_time: float = 0.0
self.total_load_time: float = 0.0
self.num_batches: int = 0

def epoch_start(self) -> None:
"""
Resets all timers to the current time, and all counters to 0. The set of epochs for which warnings about
load time were produced will not be reset.
Expand Down
89 changes: 87 additions & 2 deletions Tests/ML/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
# ------------------------------------------------------------------------------------------
import logging
import math
from typing import List
from typing import Callable, Dict, List, Optional
from unittest import mock

import torch
from _pytest.logging import LogCaptureFixture

from InnerEye.Common.metrics_constants import MetricType, TRAIN_PREFIX, VALIDATION_PREFIX
from InnerEye.ML.lightning_base import BatchTimeCallback
from InnerEye.ML.lightning_loggers import (AzureMLProgressBar, PROGRESS_STAGE_PREDICT, PROGRESS_STAGE_TEST,
PROGRESS_STAGE_TRAIN, PROGRESS_STAGE_VAL)
from InnerEye.ML.metrics import EpochTimers
Expand Down Expand Up @@ -96,6 +99,9 @@ def write_message(message: str) -> None:


def test_epoch_timers(caplog: LogCaptureFixture) -> None:
"""
Test the class that measures batch and epoch times.
"""
caplog.set_level(logging.INFO)
batch_index = 123
epoch = 24
Expand Down Expand Up @@ -140,6 +146,10 @@ def test_epoch_timers(caplog: LogCaptureFixture) -> None:
assert f"prefix: Loading minibatch {batch_index} took" in message
assert f"This message will be printed at most {timer.max_load_time_warnings} times"
assert timer.num_load_time_warnings > 0
# Test if the warnings disappear after the max number of warnings
assert timer.should_warn_in_this_epoch
timer.num_load_time_warnings = timer.max_load_time_warnings + 1
assert not timer.should_warn_in_this_epoch

# Epoch end time should be stored
assert timer.total_epoch_time == 0.0
Expand All @@ -148,6 +158,81 @@ def test_epoch_timers(caplog: LogCaptureFixture) -> None:
assert timer.epoch_end_time > old_epoch_end_time
assert timer.total_epoch_time > 0.0

timer.reset()
# Test the resetting logic
timer.epoch_start()
assert timer.total_load_time == 0.0
assert timer.num_load_time_warnings == 0
# The object should keep track of all epochs in which warnings were printed
assert len(timer.load_time_warning_epochs) > 0


def test_batch_time_callback(caplog: LogCaptureFixture) -> None:
"""
Test the callback that measures data loading times.
"""
caplog.set_level(logging.INFO)
callback = BatchTimeCallback()
epoch = 1234
# This dictionary stores all metrics that are written via module.log
logged_metrics = {}

def mock_log(name: str, value: float, reduce_fx: Callable, **kwargs: Dict) -> None:
logged_metrics[name] = (value, reduce_fx)

mock_module = mock.MagicMock(current_epoch=epoch, log=mock_log)
callback.on_fit_start(trainer=None, pl_module=mock_module) # type: ignore
assert callback.module == mock_module

# Upon epoch start, the timers should be reset. We can check that by looking at epoch_start_time
assert callback.train_timers.epoch_start_time == 0.0
callback.on_train_epoch_start(None, None) # type: ignore
assert callback.train_timers.epoch_start_time > 0.0
assert callback.val_timers.epoch_start_time == 0.0
old_train_epoch_end_time = callback.train_timers.epoch_end_time
callback.on_validation_epoch_start(None, None) # type: ignore
assert callback.val_timers.epoch_start_time > 0.0
# When calling epoch_start for validation, training epoch should be ended
assert callback.train_timers.epoch_end_time > old_train_epoch_end_time

# Run 1 training batch
callback.on_train_batch_start(None, None, None, batch_idx=0, dataloader_idx=0) # type: ignore
callback.on_train_batch_end(None, None, None, None, batch_idx=0, dataloader_idx=0) # type: ignore
assert len(logged_metrics) == 2
# Upon batch end, we should see metrics being logged. Batch level timings should be logged both as averages and max
def check_batch_metrics(train_or_val: str) -> None:
for suffix in [" avg", " max"]:
name = f"timing/{train_or_val}/SecondsPerBatch" + suffix
assert name in logged_metrics
assert logged_metrics[name][1] == max if suffix == " max" else torch.mean
check_batch_metrics("train")
assert caplog.messages[-1].startswith(f"Epoch {epoch} training: Loaded the first")
# Run 2 validation batches
for batch_idx in range(2):
callback.on_validation_batch_start(None, None, None, batch_idx=batch_idx, dataloader_idx=0) # type: ignore
callback.on_validation_batch_end(None, None, None, None, batch_idx=batch_idx, dataloader_idx=0) # type: ignore
assert caplog.messages[-1].startswith(f"Epoch {epoch} validation: Loaded the first")
assert callback.train_timers.num_batches == 1
assert callback.val_timers.num_batches == 2
check_batch_metrics("val")

# Check that the metrics are written at the end of the validation epoch.
# Hack the timers to trigger the warning message for validation only
callback.val_timers.num_load_time_exceeded = 1
callback.val_timers.total_extra_load_time = 100.00
callback.val_timers.max_item_load_time_seconds = 2.0
assert callback.val_timers.should_warn_in_this_epoch
old_val_epoch_end_time = callback.train_timers.epoch_end_time
callback.on_validation_epoch_end(None, None) # type: ignore
assert callback.val_timers.epoch_end_time > old_val_epoch_end_time
assert len(logged_metrics) > 0

assert f"Epoch {epoch} training took " in caplog.messages[-4]
assert f"Epoch {epoch} validation took " in caplog.messages[-3]
assert "The dataloaders were not fast enough" in caplog.messages[-2]
assert "in less than 2.00sec" in caplog.messages[-2]
assert "1 out of 2 batches exceeded the load time threshold" in caplog.messages[-1]
assert "Total loading time for the slow batches was 100.00sec" in caplog.messages[-1]

for prefix in [TRAIN_PREFIX, VALIDATION_PREFIX]:
for metric in [MetricType.SECONDS_PER_EPOCH.value, MetricType.EXCESS_BATCH_LOADING_TIME.value]:
assert f"timing/{prefix}{metric}" in logged_metrics