# model_trainer.py (forked from eliberis/uNAS)
import logging
from typing import Optional

import tensorflow as tf
# Imported for its side effect of registering tensorflow_addons' custom Keras
# objects, which a saved teacher model may reference when deserialised.
import tensorflow_addons as tfa  # noqa: F401

from config import TrainingConfig
from pruning import DPFPruning
from utils import debug_mode
# Let GPU memory grow on demand instead of reserving it all at startup, which
# avoids out-of-memory failures when the GPU is shared between processes.
config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
# config.gpu_options.per_process_gpu_memory_fraction = 0.2
sess = tf.compat.v1.Session(config=config)
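# A TF2-native alternative to the compat.v1 session above (a sketch, assuming
# TF 2.x; it must run before any op initialises the GPU):
#
#   for gpu in tf.config.list_physical_devices("GPU"):
#       tf.config.experimental.set_memory_growth(gpu, True)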


class ModelTrainer:
    """Trains Keras models according to the specified config."""

    def __init__(self, training_config: TrainingConfig):
        self.log = logging.getLogger("Model trainer")
        self.config = training_config
        self.distillation = training_config.distillation
        self.pruning = training_config.pruning
        self.dataset = training_config.dataset
    def train_and_eval(self, model: tf.keras.Model,
                       epochs: Optional[int] = None, sparsity: Optional[float] = None):
        """
        Trains a Keras model and returns its validation set error (1.0 - accuracy).
        :param model: A Keras model.
        :param epochs: Overrides the duration of training.
        :param sparsity: Desired sparsity level (for unstructured sparsity).
        :returns: Smallest error on the validation set seen during training, the error
                  on the test set, and the pruned weights (if pruning was used).
        """
        dataset = self.config.dataset
        batch_size = self.config.batch_size
        sparsity = sparsity or 0.0
        train = dataset.train_dataset() \
            .shuffle(batch_size * 8) \
            .batch(batch_size) \
            .prefetch(tf.data.experimental.AUTOTUNE)
        val = dataset.validation_dataset() \
            .batch(batch_size) \
            .prefetch(tf.data.experimental.AUTOTUNE)
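        # Only the training stream is shuffled (with a buffer of 8 batches);
        # AUTOTUNE lets tf.data pick the prefetch depth at runtime.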

        # TODO: check if this works; make sure we're excluding the last layer from the student
        if self.pruning and self.distillation:
            raise NotImplementedError()

        if self.distillation:
            teacher = tf.keras.models.load_model(self.distillation.distill_from)
            teacher._name = "teacher_"
            teacher.trainable = False
            t, a = self.distillation.temperature, self.distillation.alpha

            # Assemble a parallel model that runs the teacher and the student
            # on the same input
            i = tf.keras.Input(shape=dataset.input_shape)
            cxent = tf.keras.losses.CategoricalCrossentropy()

            stud_logits = model(i)
            tchr_logits = teacher(i)

            o_stud = tf.keras.layers.Softmax()(stud_logits / t)
            o_tchr = tf.keras.layers.Softmax()(tchr_logits / t)
            teaching_loss = (a * t * t) * cxent(o_tchr, o_stud)

            model = tf.keras.Model(inputs=i, outputs=stud_logits)
            model.add_loss(teaching_loss, inputs=True)
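        # The added term is Hinton-style distillation: both logit sets are
        # softened with temperature t before taking the cross-entropy between
        # the teacher's and student's distributions, and the a * t^2 factor
        # compensates for the 1/t^2 gradient scaling that soft targets
        # introduce.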

        if self.dataset.num_classes == 2:
            loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
            accuracy = tf.keras.metrics.BinaryAccuracy(name="accuracy")
        else:
            loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")
        model.compile(optimizer=self.config.optimizer(), loss=loss, metrics=[accuracy])
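        # Both losses are built with from_logits=True, so the model is expected
        # to end in a linear layer; the sigmoid/softmax is applied inside the
        # loss for numerical stability.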

        # TODO: adjust metrics by class weight?
        class_weight = {k: v for k, v in enumerate(self.dataset.class_weight())} \
            if self.config.use_class_weight else None

        epochs = epochs or self.config.epochs
        callbacks = self.config.callbacks()
        check_logs_from_epoch = 0

        pruning_cb = None
        # Skip pruning set-up for single-epoch runs (skip added by ntk)
        if epochs != 1:
            if self.pruning and sparsity > 0.0:
                assert 0.0 < sparsity <= 1.0
                self.log.info(f"Target sparsity: {sparsity:.4f}")
                pruning_cb = DPFPruning(target_sparsity=sparsity,
                                        structured=self.pruning.structured,
                                        start_pruning_at_epoch=self.pruning.start_pruning_at_epoch,
                                        finish_pruning_by_epoch=self.pruning.finish_pruning_by_epoch,
                                        update_iterations=self.pruning.update_iterations)
                check_logs_from_epoch = self.pruning.finish_pruning_by_epoch
                callbacks.append(pruning_cb)
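                # DPFPruning appears to implement dynamic pruning with feedback
                # (Lin et al., 2020): the sparsity mask is recomputed every
                # `update_iterations` steps while the dense weights keep
                # receiving gradients, ramping sparsity up between the start
                # and finish epochs. Validation accuracy is only meaningful
                # once pruning has finished, hence check_logs_from_epoch.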

        log = model.fit(train, epochs=epochs, validation_data=val,
                        verbose=1 if debug_mode() else 2,
                        callbacks=callbacks, class_weight=class_weight)

        test = dataset.test_dataset() \
            .batch(batch_size) \
            .prefetch(tf.data.experimental.AUTOTUNE)
        _, test_acc = model.evaluate(test, verbose=0)

        return {
            "val_error": 1.0 - max(log.history["val_accuracy"][check_logs_from_epoch:]),
            "test_error": 1.0 - test_acc,
            "pruned_weights": pruning_cb.weights if pruning_cb else None
        }
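
# Minimal usage sketch (hypothetical names: assumes `training_config` is a
# populated TrainingConfig and `net` is an uncompiled Keras model):
#
#   trainer = ModelTrainer(training_config)
#   results = trainer.train_and_eval(net, epochs=30, sparsity=0.5)
#   print(results["val_error"], results["test_error"])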