-
Notifications
You must be signed in to change notification settings - Fork 240
/
eval_assert.py
106 lines (85 loc) · 3.56 KB
/
eval_assert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# pylint: skip-file
"""Reads and asserts over target values"""
from absl import app
from typing import Sequence
from math import isclose
from google.cloud import storage
import json
def read(metrics_file, target, burn_in=10):
  """Return the mean of ``target`` over a JSON-lines metrics file.

  Each non-blank line of ``metrics_file`` is parsed as a JSON object. The
  first ``burn_in`` records are skipped; the ``target`` values of the
  remaining records are averaged.

  Args:
    metrics_file: Path to a file holding one JSON object per line.
    target: Key whose numeric value is averaged.
    burn_in: Number of leading records to ignore before averaging.
      Defaults to 10.

  Returns:
    The arithmetic mean of ``target`` over the records after the burn-in.

  Raises:
    ValueError: If no records remain after the burn-in window (instead of
      dividing by zero or returning a meaningless 0.0).
    KeyError: If a record after the burn-in has no ``target`` key.
  """
  total = 0
  count = 0
  seen = 0
  with open(metrics_file, 'r', encoding='utf8') as file:
    # Stream line by line rather than loading the whole file with readlines().
    for line in file:
      # Blank lines (e.g. an extra trailing newline) are not records; parsing
      # them would raise a JSON decode error.
      if not line.strip():
        continue
      seen += 1
      # Skip the burn-in records; they are not parsed or averaged.
      if seen <= burn_in:
        continue
      total += json.loads(line)[target]
      count += 1
  if count == 0:
    raise ValueError(
        f'{metrics_file} has {seen} record(s); at least {burn_in + 1} are '
        f'needed to average {target!r} after skipping {burn_in} burn-in '
        'record(s)')
  return total / count
def assert_metric_average(metrics_file, threshold, target):
  """Assert that the average of ``target`` reaches ``threshold``.

  The average is computed by ``read`` over ``metrics_file``. This is used as
  a performance gate, e.g. on a TFLOPs metric.

  Args:
    metrics_file: Metrics file passed straight through to ``read``.
    threshold: Minimum acceptable average; anything ``float()`` accepts
      (``main`` passes it as a command-line string).
    target: Name of the metric to average.

  Raises:
    AssertionError: If the average is below ``threshold``. This is a plain
      ``assert``, so the comparison is skipped under ``python -O``.
  """
  average = read(metrics_file, target)
  print(f'avg value of target {target} is {average}')
  # The threshold is converted inside the assert so it is only parsed when
  # assertions are enabled, exactly as before.
  assert average >= float(threshold)
  print('assert metric average passed.')
def test_checkpointing(metrics_file, target):
"""Asserts over loss values from loaded checkpoint"""
metrics_file_saved = 'saved_' + metrics_file
metrics_file_restored = 'restored_' + metrics_file
with open(metrics_file_saved, 'r', encoding='utf8') as saved,\
open(metrics_file_restored, 'r', encoding='utf8') as restored:
saved_loss = json.loads(saved.readlines()[-1])[target]
restored_loss = json.loads(restored.readlines()[0])[target]
# Checks that checkpoint restore was successful by comparing loss of last
# step in saved checkpoint to loss of first step in restored checkpoint
print("saved loss: ", saved_loss)
print("restored loss: ", restored_loss)
assert isclose(saved_loss, restored_loss, rel_tol=0.1)
print('checkpointing test passed.')
def test_determinism(metrics_file, target):
"""Asserts over loss values from two runs"""
run_1 = 'run_1_' + metrics_file
run_2 = 'run_2_' + metrics_file
with open(run_1, 'r', encoding='utf8') as run_1_file,\
open(run_2, 'r', encoding='utf8') as run_2_file:
run_1_loss = json.loads(run_1_file.readlines()[-1])[target]
run_2_loss = json.loads(run_2_file.readlines()[-1])[target]
# Check that the two runs have the same loss
print(f"Run 1 loss:{run_1_loss}", flush=True)
print(f"Run 2 loss:{run_2_loss}", flush=True)
assert run_1_loss==run_2_loss
print('determinism test passed.')
def test_vocab_creation(target):
  """Check that a vocabulary object exists in Cloud Storage.

  ``target`` is split on "/": the third component is used as the bucket name
  and everything after it as the object name. This presumably expects a
  ``gs://bucket/path/to/vocab`` URI — confirm against callers.

  Raises:
    AssertionError: If the object does not exist.
  """
  segments = target.split("/")
  bucket_name = segments[2]
  vocab_path = "/".join(segments[3:])
  client = storage.Client()
  vocab_blob = storage.Blob(bucket=client.bucket(bucket_name), name=vocab_path)
  assert vocab_blob.exists(client)
  print('vocab creation test passed.')
def main(argv: Sequence[str]) -> None:
  """Run the check selected by the first command-line argument.

  Expected argv: ``[program, scenario, *args]``, where ``scenario`` is one of
  ``metrics_average``, ``checkpoint_save_restore``, ``determinism`` or
  ``vocab_creation``. The remaining arguments are passed positionally to the
  matching check function.

  Raises:
    app.UsageError: If no scenario is given or the scenario is unknown.
  """
  scenarios = {
      'metrics_average': assert_metric_average,
      'checkpoint_save_restore': test_checkpointing,
      'determinism': test_determinism,
      'vocab_creation': test_vocab_creation,
  }
  if len(argv) < 2:
    raise app.UsageError(
        f'Missing test scenario; expected one of: {", ".join(scenarios)}')
  _, test_scenario, *test_vars = argv
  check = scenarios.get(test_scenario)
  # Fail loudly on an unknown scenario: silently returning would let a typo
  # in the scenario name look like a passing check.
  if check is None:
    raise app.UsageError(
        f'Unknown test scenario {test_scenario!r}; expected one of: '
        f'{", ".join(scenarios)}')
  check(*test_vars)
if __name__ == "__main__":
  # Start the script through absl's application runner with main() as the entry point.
  app.run(main=main)