This repository has been archived by the owner on Mar 21, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 141
/
test_checkpoint_handling.py
403 lines (318 loc) · 21.3 KB
/
test_checkpoint_handling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
from pathlib import Path
from urllib.parse import urlparse
import pytest
import torch
from InnerEye.Common.common_util import OTHER_RUNS_SUBDIR_NAME
from InnerEye.Common.fixed_paths import MODEL_INFERENCE_JSON_FILE_NAME
from InnerEye.ML.utils.checkpoint_handling import MODEL_WEIGHTS_DIR_NAME
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, get_recovery_checkpoint_path
from InnerEye.ML.deep_learning_config import FINAL_MODEL_FOLDER, FINAL_ENSEMBLE_MODEL_FOLDER
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.model_inference_config import read_model_inference_config
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
from Tests.AfterTraining.test_after_training import FALLBACK_ENSEMBLE_RUN, FALLBACK_SINGLE_RUN, get_most_recent_run, \
get_most_recent_run_id, get_most_recent_model_id
from Tests.ML.util import get_default_checkpoint_handler, get_default_workspace
EXTERNAL_WEIGHTS_URL_EXAMPLE = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
def create_checkpoint_file(file: Path) -> None:
    """
    Writes a minimal, valid torch checkpoint to the given location. The checkpoint only needs
    to round-trip through torch.load without errors; its contents are irrelevant to the tests.
    :param file: The path of the checkpoint file that should be written.
    """
    dummy_state = {'state_dict': {'foo': torch.ones((2, 2))}}
    file_as_str = str(file)
    torch.save(dummy_state, file_as_str)
    # Read the file back immediately as a sanity check that it is a usable checkpoint.
    reloaded = torch.load(file_as_str)
    assert reloaded, "Unable to read the checkpoint file that was just created"
def test_use_checkpoint_paths_or_urls(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test how the checkpoint handler picks up pre-trained weights: no source configured at all,
    weights supplied via a URL, and weights supplied via a local file path.
    """
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    # No checkpoint handling options set.
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    # With no source configured, neither run recovery nor trained weights should be present.
    assert not checkpoint_handler.run_recovery
    assert not checkpoint_handler.trained_weights_paths
    # weights from local_weights_path and weights_url will be modified if needed and stored at this location
    # Set a weights_path
    checkpoint_handler.azure_config.run_recovery_id = ""
    checkpoint_handler.container.checkpoint_urls = [EXTERNAL_WEIGHTS_URL_EXAMPLE]
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    # The downloaded file is expected inside the checkpoint folder, named after the final
    # path segment of the URL.
    expected_download_path = checkpoint_handler.output_params.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME /\
        os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
    assert checkpoint_handler.trained_weights_paths[0] == expected_download_path
    assert checkpoint_handler.trained_weights_paths[0].is_file()
    # set a local_weights_path
    checkpoint_handler.container.checkpoint_urls = []
    local_weights_path = test_output_dirs.root_dir / "exist.pth"
    create_checkpoint_file(local_weights_path)
    checkpoint_handler.container.local_checkpoint_paths = [local_weights_path]
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    # Local checkpoint files are used in place: the returned path is the original file.
    assert checkpoint_handler.trained_weights_paths[0] == local_weights_path
    assert checkpoint_handler.trained_weights_paths[0].is_file()
@pytest.mark.after_training_single_run
def test_download_recovery_checkpoints_from_single_run(test_output_dirs: OutputFolderForTests) -> None:
    """
    Downloading recovery checkpoints from a single (non-ensemble) AzureML run should place both
    the recovery checkpoint and the best checkpoint into the configured checkpoint folder.
    """
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    # No checkpoint handling options set.
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    # Most recent matching run, or a fixed fallback run when executed locally.
    run_recovery_id = get_most_recent_run_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
    # Set a run recovery object - non ensemble
    checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    assert checkpoint_handler.run_recovery
    expected_checkpoint_root = config.checkpoint_folder
    # Both the recovery checkpoint and the best checkpoint must be downloaded.
    expected_paths = [get_recovery_checkpoint_path(path=expected_checkpoint_root),
                      expected_checkpoint_root / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX]
    assert checkpoint_handler.run_recovery.checkpoints_roots == [expected_checkpoint_root]
    for path in expected_paths:
        assert path.is_file()
@pytest.mark.after_training_ensemble_run
def test_download_recovery_checkpoints_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
    """
    Trying to recover directly from a HyperDrive (ensemble) parent run should be rejected with
    a ValueError, because recovery must target one of the child runs.
    """
    config = ModelConfigBase(should_validate=False)
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    run_recovery_id = get_most_recent_run_id(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
    with pytest.raises(ValueError) as ex:
        checkpoint_handler.download_recovery_checkpoints_or_weights()
    # Check the actual exception message via ex.value: str() of the ExceptionInfo object itself
    # is not guaranteed to be the error message (it can include file/line information).
    assert "has child runs" in str(ex.value)
@pytest.mark.after_training_single_run
def test_download_model_from_single_run(test_output_dirs: OutputFolderForTests) -> None:
    """
    Downloading a registered model from a single (non-ensemble) run should yield exactly one
    checkpoint file, matching the paths listed in the model's inference config.
    """
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    # No checkpoint handling options set.
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
    # Set a run recovery object - non ensemble
    checkpoint_handler.azure_config.model_id = model_id
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    assert checkpoint_handler.trained_weights_paths
    expected_model_root = config.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME / FINAL_MODEL_FOLDER
    # The model's inference config lists the checkpoint files relative to the model root.
    model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
    expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
    assert len(expected_paths) == 1  # A registered model for a non-ensemble run should contain only one checkpoint
    assert len(checkpoint_handler.trained_weights_paths) == 1
    assert expected_paths[0] == checkpoint_handler.trained_weights_paths[0]
    assert expected_paths[0].is_file()
@pytest.mark.after_training_ensemble_run
def test_download_model_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
    """
    Downloading a registered ensemble model should yield all checkpoint files listed in the
    model's inference config.
    """
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    # No checkpoint handling options set.
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    # Point the handler at the registered ensemble model.
    checkpoint_handler.azure_config.model_id = model_id
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    assert checkpoint_handler.trained_weights_paths
    expected_model_root = config.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME / FINAL_ENSEMBLE_MODEL_FOLDER
    model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
    expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
    # An ensemble model can contain multiple checkpoints; all of them must have been downloaded.
    assert len(checkpoint_handler.trained_weights_paths) == len(expected_paths)
    assert set(checkpoint_handler.trained_weights_paths) == set(expected_paths)
    for path in expected_paths:
        assert path.is_file()
def test_get_recovery_path_train(test_output_dirs: OutputFolderForTests) -> None:
    """
    With no run recovery configured and no training done, there is no checkpoint to resume
    training from, so the handler must return None.
    """
    model_config = ModelConfigBase(should_validate=False)
    model_config.set_output_to(test_output_dirs.root_dir)
    handler = get_default_checkpoint_handler(model_config=model_config,
                                             project_root=test_output_dirs.root_dir)
    recovery_path = handler.get_recovery_or_checkpoint_path_train()
    assert recovery_path is None
@pytest.mark.after_training_single_run
def test_get_recovery_path_train_single_run(test_output_dirs: OutputFolderForTests) -> None:
    """
    After downloading recovery checkpoints from a single run, training should resume from the
    recovery checkpoint in the checkpoint folder.
    """
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    run_recovery_id = get_most_recent_run_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
    checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    # Run recovery with start epoch provided should succeed
    expected_path = get_recovery_checkpoint_path(path=config.checkpoint_folder)
    assert checkpoint_handler.get_recovery_or_checkpoint_path_train() == expected_path
@pytest.mark.after_training_single_run
def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test retrieval of the best checkpoint: it should fail when neither run recovery nor training
    is available, use the downloaded run recovery checkpoint otherwise, and pick up a checkpoint
    written by the current run after additional training has been flagged.
    """
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    # We have not set a run_recovery, nor have we trained, so this should fail to get a checkpoint
    with pytest.raises(ValueError) as ex:
        checkpoint_handler.get_best_checkpoints()
    assert "no run recovery object provided and no training has been done in this run" in ex.value.args[0]
    run_recovery_id = get_most_recent_run_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
    # We have set a run_recovery_id now, so this should work: Should download all checkpoints that are available
    # in the run, into a subfolder of the checkpoint folder
    checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    expected_checkpoint = config.checkpoint_folder / f"{BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX}"
    checkpoint_paths = checkpoint_handler.get_best_checkpoints()
    assert checkpoint_paths
    assert len(checkpoint_paths) == 1
    assert expected_checkpoint == checkpoint_paths[0]
    # From now on, the checkpoint handler will think that the run was started from epoch 1. We should pick up
    # the best checkpoint from the current run, or from the run recovery if the best checkpoint is there
    # and so no checkpoints have been written in the resumed run.
    checkpoint_handler.additional_training_done()
    # go back to non ensemble run recovery
    checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    # There is no checkpoint in the current run - use the one from run_recovery
    checkpoint_paths = checkpoint_handler.get_best_checkpoints()
    expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    assert checkpoint_paths
    assert len(checkpoint_paths) == 1
    assert checkpoint_paths[0] == expected_checkpoint
    # Copy over checkpoints to make it look like training has happened and a better checkpoint written
    expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    expected_checkpoint.touch()
    checkpoint_paths = checkpoint_handler.get_best_checkpoints()
    assert checkpoint_paths
    assert len(checkpoint_paths) == 1
    assert expected_checkpoint == checkpoint_paths[0]
@pytest.mark.after_training_ensemble_run
def test_download_checkpoints_from_hyperdrive_child_runs(test_output_dirs: OutputFolderForTests) -> None:
    """
    Checkpoints from the child runs of a HyperDrive run should be downloaded into per-child
    subfolders of the checkpoint folder, and all of them reported as best checkpoints.
    """
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    hyperdrive_run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    checkpoint_handler.download_checkpoints_from_hyperdrive_child_runs(hyperdrive_run)
    # Child run checkpoints are expected in OTHER_RUNS_SUBDIR_NAME/<child index>. This assumes the
    # ensemble run has exactly 2 children - TODO confirm against FALLBACK_ENSEMBLE_RUN.
    expected_checkpoints = [config.checkpoint_folder / OTHER_RUNS_SUBDIR_NAME / str(i)
                            / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX for i in range(2)]
    checkpoint_paths = checkpoint_handler.get_best_checkpoints()
    assert checkpoint_paths
    assert len(checkpoint_paths) == 2
    assert set(expected_checkpoints) == set(checkpoint_paths)
def test_get_checkpoints_to_test(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test which checkpoints are used for inference: local weights when no training has been done,
    and the best checkpoint of the current run once training has been flagged.
    """
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    # Set a local_checkpoint_paths to get checkpoint from. Model has not trained and no run recovery provided,
    # so the local weights should be used ignoring any epochs to test
    local_weights_path = test_output_dirs.root_dir / "exist.pth"
    create_checkpoint_file(local_weights_path)
    checkpoint_handler.container.local_checkpoint_paths = [local_weights_path]
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
    assert checkpoint_and_paths
    assert len(checkpoint_and_paths) == 1
    assert checkpoint_and_paths[0] == local_weights_path
    checkpoint_handler.additional_training_done()
    checkpoint_handler.container.checkpoint_folder.mkdir(parents=True)
    # Copy checkpoint to make it seem like training has happened
    expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    expected_checkpoint.touch()
    # The best checkpoint of the (simulated) training run now takes precedence over local weights.
    checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
    assert checkpoint_and_paths
    assert len(checkpoint_and_paths) == 1
    assert checkpoint_and_paths[0] == expected_checkpoint
@pytest.mark.after_training_single_run
def test_get_checkpoints_to_test_single_run(test_output_dirs: OutputFolderForTests) -> None:
    """
    After run recovery combined with additional training, inference should use the best
    checkpoint in the checkpoint folder.
    """
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    run_recovery_id = get_most_recent_run_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
    # Now set a run recovery object and set the start epoch to 1, so we get one epoch from
    # run recovery and one from the training checkpoints
    checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
    checkpoint_handler.additional_training_done()
    checkpoint_handler.download_recovery_checkpoints_or_weights()
    checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
    assert checkpoint_and_paths
    assert len(checkpoint_and_paths) == 1
    assert checkpoint_and_paths[0] == config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    # Copy checkpoint to make it seem like training has happened
    expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
    expected_checkpoint.touch()
    checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
    assert checkpoint_and_paths
    assert len(checkpoint_and_paths) == 1
    assert checkpoint_and_paths[0] == expected_checkpoint
def test_download_model_weights(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test downloading checkpoint files from URLs: a single URL, duplicate URLs mapping to the
    same local file, and re-use of an already downloaded file instead of re-downloading.
    """
    # Download a sample ResNet model from a URL given in the Pytorch docs
    result_path = CheckpointHandler.download_weights(urls=[EXTERNAL_WEIGHTS_URL_EXAMPLE],
                                                     download_folder=test_output_dirs.root_dir)
    assert len(result_path) == 1
    assert result_path[0] == test_output_dirs.root_dir / os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
    assert result_path[0].is_file()
    # Remember the modification time to later detect whether the file was re-downloaded.
    modified_time = result_path[0].stat().st_mtime
    # Passing the same URL twice should return two entries that resolve to one and the same file.
    result_path = CheckpointHandler.download_weights(urls=[EXTERNAL_WEIGHTS_URL_EXAMPLE, EXTERNAL_WEIGHTS_URL_EXAMPLE],
                                                     download_folder=test_output_dirs.root_dir)
    assert len(result_path) == 2
    assert len(list(test_output_dirs.root_dir.glob("*"))) == 1
    assert result_path[0].samefile(result_path[1])
    assert result_path[0] == test_output_dirs.root_dir / os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
    assert result_path[0].is_file()
    # This call should not re-download the files, just return the existing ones
    assert result_path[0].stat().st_mtime == modified_time
@pytest.mark.after_training_single_run
def test_get_checkpoints_from_model_single_run(test_output_dirs: OutputFolderForTests) -> None:
    """
    Downloading checkpoints directly from a registered single-run model should yield exactly one
    checkpoint file, matching the model's inference config.
    """
    model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
    downloaded_checkpoints = CheckpointHandler.get_checkpoints_from_model(model_id=model_id,
                                                                          workspace=get_default_workspace(),
                                                                          download_path=test_output_dirs.root_dir)
    # Check a single checkpoint has been downloaded
    expected_model_root = test_output_dirs.root_dir / FINAL_MODEL_FOLDER
    assert expected_model_root.is_dir()
    model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
    expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
    assert len(expected_paths) == 1  # A registered model for a non-ensemble run should contain only one checkpoint
    assert len(downloaded_checkpoints) == 1
    assert expected_paths[0] == downloaded_checkpoints[0]
    assert expected_paths[0].is_file()
@pytest.mark.after_training_ensemble_run
def test_get_checkpoints_from_model_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
    """
    Downloading checkpoints from a registered ensemble model should yield all checkpoint files
    listed in the model's inference config.
    """
    model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    downloaded_checkpoints = CheckpointHandler.get_checkpoints_from_model(model_id=model_id,
                                                                          workspace=get_default_workspace(),
                                                                          download_path=test_output_dirs.root_dir)
    # Check that all checkpoints of the ensemble have been downloaded
    expected_model_root = test_output_dirs.root_dir / FINAL_ENSEMBLE_MODEL_FOLDER
    assert expected_model_root.is_dir()
    model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
    expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
    assert len(expected_paths) == len(downloaded_checkpoints)
    assert set(expected_paths) == set(downloaded_checkpoints)
    for expected_path in expected_paths:
        assert expected_path.is_file()
def test_get_local_weights_path_or_download(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test resolution of local checkpoints: fail when no source is configured, return local paths
    unchanged, and download from a URL exactly once.
    """
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    # If the model has neither local_checkpoint_paths or checkpoint_urls set, should fail.
    with pytest.raises(ValueError) as ex:
        checkpoint_handler.get_local_checkpoints_path_or_download()
    assert "none of model_id, local_weights_path or weights_url is set in the model config." in ex.value.args[0]
    # If local_checkpoint_paths folder exists, get_local_checkpoints_path_or_download should not do anything.
    local_weights_path = test_output_dirs.root_dir / "exist.pth"
    create_checkpoint_file(local_weights_path)
    checkpoint_handler.container.local_checkpoint_paths = [local_weights_path]
    returned_weights_path = checkpoint_handler.get_local_checkpoints_path_or_download()
    assert local_weights_path == returned_weights_path[0]
    # Pointing the model to a URL should trigger a download
    checkpoint_handler.container.local_checkpoint_paths = []
    checkpoint_handler.container.checkpoint_urls = [EXTERNAL_WEIGHTS_URL_EXAMPLE]
    downloaded_weights = checkpoint_handler.get_local_checkpoints_path_or_download()
    # The download target is inside the checkpoint folder, named after the final URL path segment.
    expected_path = checkpoint_handler.output_params.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME / \
        os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
    assert len(downloaded_weights) == 1
    assert downloaded_weights[0].is_file()
    assert expected_path == downloaded_weights[0]
    # try again, should not re-download
    modified_time = downloaded_weights[0].stat().st_mtime
    downloaded_weights_new = checkpoint_handler.get_local_checkpoints_path_or_download()
    assert len(downloaded_weights_new) == 1
    assert downloaded_weights_new[0].stat().st_mtime == modified_time