-
Notifications
You must be signed in to change notification settings - Fork 6
/
edit_model.py
122 lines (104 loc) · 3.97 KB
/
edit_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import torch
import copy
def main():
    """Parse the command-line options and pass them to ``main_func``.

    Every option is optional; the defaults (``''``, ``-1``, ``'ignore'`` or an
    unset flag) are the "leave this value alone" markers that ``main_func``
    checks for.
    """
    # (flag, keyword arguments for add_argument), in the order they are registered.
    option_table = (
        ("-model_file", dict(type=str, default='')),
        ("-num_classes", dict(type=int, default=-1)),
        ("-epoch", dict(type=int, default=-1)),
        ("-base_model", dict(choices=['bvlc', 'p365', '5h', 'ignore'], default='ignore')),
        ("-data_mean", dict(type=str, default='ignore')),
        ("-data_sd", dict(type=str, default='ignore')),
        ("-normval_format", dict(choices=['bgr', 'rgb', 'ignore'], default='ignore')),
        ("-has_branches", dict(choices=['true', 'false', 'ignore'], default='ignore')),
        ("-reverse_normvals", dict(action='store_true')),
        ("-print_vals", dict(action='store_true')),
        ("-output_name", dict(type=str, default='')),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    main_func(cli.parse_args())
def main_func(params):
    """Apply the requested metadata edits to a saved checkpoint.

    The checkpoint at ``params.model_file`` is loaded onto the CPU and copied.
    Each option that is not left at its "ignore" value (``-1``, ``'ignore'``
    or an unset flag) then overwrites the matching entry of the copy. The
    edited copy is written to ``params.output_name`` when that name is
    non-empty; otherwise nothing is saved.

    Args:
        params: Namespace providing ``model_file``, ``num_classes``, ``epoch``,
            ``base_model``, ``data_mean``, ``data_sd``, ``normval_format``,
            ``has_branches``, ``reverse_normvals``, ``print_vals`` and
            ``output_name``.

    Raises:
        AssertionError: A normalization option was given for a checkpoint
            without usable ``normalize_params`` and one of ``-data_mean``,
            ``-data_sd`` or ``-normval_format`` is missing.
        ValueError: ``data_mean`` or ``data_sd`` is not a comma-separated
            list of numbers.
        KeyError: ``reverse_normvals`` is set but the checkpoint has no
            ``normalize_params`` entry.
    """
    checkpoint = torch.load(params.model_file, map_location='cpu')
    # Edit a copy so the object returned by torch.load is never mutated.
    save_model = copy.deepcopy(checkpoint)

    # Printing happens before any edit, so it shows the values as loaded.
    if params.print_vals:
        print_model_vals(save_model)

    if params.num_classes > -1:
        save_model['num_classes'] = params.num_classes

    if params.base_model != 'ignore':
        save_model['base_model'] = params.base_model

    if params.has_branches != 'ignore':
        save_model['has_branches'] = params.has_branches == 'true'

    if params.epoch != -1:
        save_model['epoch'] = params.epoch

    if params.data_mean != 'ignore' or params.data_sd != 'ignore' or params.normval_format != 'ignore':
        save_model['normalize_params'] = _edited_normalize_params(save_model, params)

    if params.reverse_normvals:
        norm_vals = save_model['normalize_params']
        norm_vals[0].reverse()
        norm_vals[1].reverse()
        save_model['normalize_params'] = norm_vals

    if params.output_name != '':
        # The original passed an undefined name here, so saving always failed.
        torch.save(save_model, params.output_name)


def _parse_float_list(text):
    """Convert a comma-separated string such as ``'1.5,2,3'`` to a list of floats."""
    return [float(item) for item in text.split(',')]


def _edited_normalize_params(save_model, params):
    """Return the ``[mean, sd, format]`` list with the requested overrides applied."""
    try:
        current = save_model['normalize_params']
    except KeyError:
        current = None

    if current is None or len(current) < 2:
        # Nothing usable to start from: all three values must be supplied.
        # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
        required = (
            (params.data_mean, '-data_mean'),
            (params.data_sd, '-data_sd'),
            (params.normval_format, '-normval_format'),
        )
        for value, flag in required:
            if value == 'ignore':
                raise AssertionError("'%s' is required" % flag)
        # Store parsed floats, the same shape the edit path below produces.
        return [
            _parse_float_list(params.data_mean),
            _parse_float_list(params.data_sd),
            params.normval_format,
        ]

    # list() also makes a tuple-valued entry editable.
    norm_vals = list(current)
    if params.data_mean != 'ignore':
        norm_vals[0] = _parse_float_list(params.data_mean)
    if params.data_sd != 'ignore':
        norm_vals[1] = _parse_float_list(params.data_sd)
    if params.normval_format != 'ignore':
        if len(norm_vals) > 2:
            norm_vals[2] = params.normval_format
        else:
            norm_vals.append(params.normval_format)  # Add to legacy models
    return norm_vals
def print_model_vals(model):
    """Print the metadata entries that are present in a loaded checkpoint.

    Prints a ``Model Values`` header, then one line for each value that can be
    looked up in ``model``, then one line for each optional state entry that
    exists. Entries that are missing are skipped silently, so any subset of
    the keys is accepted.

    Args:
        model: Mapping loaded from a checkpoint file. ``normalize_params``,
            when present, is indexed as ``[mean, sd, format]``; the format
            slot may be absent.
    """
    # A missing key, a too-short normalize_params list, or a non-indexable
    # value all mean "not available" here; anything else should propagate.
    missing = (KeyError, IndexError, TypeError)

    # (label, lookup path into ``model``), in print order.
    labelled_values = (
        (' Num classes:', ('num_classes',)),
        (' Base model:', ('base_model',)),
        (' Model epoch:', ('epoch',)),
        (' Has branches:', ('has_branches',)),
        (' Norm value format', ('normalize_params', 2)),
        (' Mean values:', ('normalize_params', 0)),
        (' Standard deviation values:', ('normalize_params', 1)),
    )
    # (key, message printed when the key exists), in print order.
    presence_notes = (
        ('optimizer_state_dict', ' Contains saved optimizer state'),
        ('lrscheduler_state_dict', ' Contains saved learning rate scheduler state'),
        ('color_correlation_svd_sqrt', ' Contains saved color correlation matrix'),
    )

    print('Model Values')

    for label, path in labelled_values:
        try:
            value = model
            for step in path:
                value = value[step]
        except missing:
            continue
        print(label, value)

    for key, message in presence_notes:
        try:
            model[key]
        except missing:
            continue
        print(message)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()