-
-
Notifications
You must be signed in to change notification settings - Fork 84
/
base.py
279 lines (236 loc) · 9.79 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
"""
The base classes for PyPOTS classification models.
"""
# Created by Wenjie Du <[email protected]>
# License: GPL-v3
import copy
import warnings
from abc import abstractmethod
from typing import Optional, Union

import numpy as np
import torch
from torch.utils.data import DataLoader

from pypots.base import BaseModel, BaseNNModel
from pypots.utils.logging import logger
class BaseClassifier(BaseModel):
    """Abstract interface shared by all PyPOTS classification models.

    Concrete classifiers have to implement :meth:`fit` and :meth:`classify`.

    Parameters
    ----------
    device : str or torch.device, optional
        Forwarded unchanged to ``BaseModel``.

    saving_path : str, optional
        Forwarded unchanged to ``BaseModel``.

    model_saving_strategy : str, optional, default = "best"
        Forwarded unchanged to ``BaseModel``.
    """

    def __init__(
        self,
        device: Optional[Union[str, torch.device]] = None,
        saving_path: str = None,
        model_saving_strategy: Optional[str] = "best",
    ):
        super().__init__(device, saving_path, model_saving_strategy)

    @abstractmethod
    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "h5py",
    ) -> None:
        """Fit the classifier to a labelled dataset.

        Parameters
        ----------
        train_set : dict or str,
            Training data, given either as a dictionary or as a path string pointing to a data file.
            A dictionary must hold the keys 'X' and 'y': X is array-like with shape
            [n_samples, sequence length (time steps), n_features] (time-series data, missing values allowed),
            and y is array-like with shape [n_samples] (the class label of each sample in X).
            A path string must locate a data file (e.g. an h5 file) that stores key-value pairs like a dict
            and includes the keys 'X' and 'y'.

        val_set : dict or str, optional
            Validation data, in exactly the same format as ``train_set`` (a dictionary with the keys
            'X' and 'y', or a path string locating a file that contains them).

        file_type : str, default = "h5py",
            Format of the data file(s); only relevant when train_set and val_set are path strings.
        """

    @abstractmethod
    def classify(
        self,
        X: Union[dict, str],
        file_type: str = "h5py",
    ) -> np.ndarray:
        """Predict class labels for the given samples with the trained model.

        Parameters
        ----------
        X : array-like or str,
            Samples to classify: array-like data with shape [n_samples, sequence length (time steps),
            n_features], or a path string locating a data file, e.g. an h5 file.
            NOTE(review): the type annotation says ``dict`` or ``str`` -- confirm which container
            concrete implementations expect.

        file_type : str, default = "h5py",
            Format of the data file; only relevant when X is a path string.

        Returns
        -------
        array-like, shape [n_samples],
            The classification result for each given sample.
        """
class BaseNNClassifier(BaseNNModel, BaseClassifier):
    """Abstract base class for neural-network classification models.

    It supplies the shared training loop (:meth:`_train_model`) and leaves the assembling of
    dataloader batches into model inputs to subclasses.

    Parameters
    ----------
    n_classes : int
        The number of classes; stored on the instance as ``self.n_classes``.

    batch_size, epochs, patience, learning_rate, weight_decay, num_workers, device, saving_path, \
model_saving_strategy
        Forwarded positionally, in this order, to the parent constructor.

    Notes
    -----
    :meth:`_train_model` reads ``self.model``, ``self.lr``, ``self.weight_decay``, ``self.epochs``,
    ``self.patience``, ``self.original_patience`` and ``self.summary_writer``. None of them is set in this
    class -- presumably the parent classes or the concrete subclass provide them; verify there.
    """

    def __init__(
        self,
        n_classes: int,
        batch_size: int,
        epochs: int,
        patience: int,
        learning_rate: float,
        weight_decay: float,
        num_workers: int = 0,
        device: Optional[Union[str, torch.device]] = None,
        saving_path: Optional[str] = None,
        model_saving_strategy: Optional[str] = "best",
    ):
        super().__init__(
            batch_size,
            epochs,
            patience,
            learning_rate,
            weight_decay,
            num_workers,
            device,
            saving_path,
            model_saving_strategy,
        )
        self.n_classes = n_classes

    @abstractmethod
    def _assemble_input_for_training(self, data) -> dict:
        """Assemble the given data into a dictionary for training input.

        Parameters
        ----------
        data : list,
            Input data from dataloader, should be list.

        Returns
        -------
        dict,
            A python dictionary contains the input data for model training.
        """
        pass

    @abstractmethod
    def _assemble_input_for_validating(self, data) -> dict:
        """Assemble the given data into a dictionary for validating input.

        Parameters
        ----------
        data : list,
            Data output from dataloader, should be list.

        Returns
        -------
        dict,
            A python dictionary contains the input data for model validating.
        """
        pass

    @abstractmethod
    def _assemble_input_for_testing(self, data) -> dict:
        """Assemble the given data into a dictionary for testing input.

        Notes
        -----
        The processing functions of train/val/test stages are separated for the situation that the input of
        the three stages are different, and this situation usually happens when the Dataset/Dataloader classes
        used in the train/val/test stages are not the same, e.g. the training data and validating data in a
        classification task contains labels, but the testing data (from the production environment) generally
        doesn't have labels.

        Parameters
        ----------
        data : list,
            Data output from dataloader, should be list.

        Returns
        -------
        dict,
            A python dictionary contains the input data for model testing.
        """
        pass

    def _train_model(
        self,
        training_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
    ) -> None:
        """Run the training loop with optional validation, best-model tracking and early stopping.

        A fresh Adam optimizer is created for ``self.model`` and ``self.best_loss`` /
        ``self.best_model_dict`` are reset. Each epoch trains on every batch of ``training_loader``. If
        ``val_loader`` is given, the mean validation loss of the epoch is the selection metric, otherwise
        the mean training loss is. Whenever the metric improves, a copy of the model's state dict is kept
        in ``self.best_model_dict``, the patience counter is reset and
        ``self._auto_save_model_if_necessary`` is called. Otherwise the patience counter is decremented
        and training stops once it reaches 0.

        Parameters
        ----------
        training_loader : DataLoader
            Yields the training batches, each of which is passed to ``_assemble_input_for_training``.

        val_loader : DataLoader, optional
            Yields the validation batches, each of which is passed to ``_assemble_input_for_validating``.

        Raises
        ------
        RuntimeError
            If an exception occurs during training before any best model has been recorded.

        ValueError
            If ``self.best_loss`` is still infinite after training.
        """
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )

        # each training starts from the very beginning, so reset the loss and model dict here
        self.best_loss = float("inf")
        self.best_model_dict = None

        try:
            training_step = 0
            for epoch in range(self.epochs):
                self.model.train()
                epoch_train_loss_collector = []
                for data in training_loader:
                    training_step += 1
                    inputs = self._assemble_input_for_training(data)
                    self.optimizer.zero_grad()
                    results = self.model.forward(inputs)
                    results["loss"].backward()
                    self.optimizer.step()
                    epoch_train_loss_collector.append(results["loss"].item())

                    # save training loss logs into the tensorboard file for every step if in need
                    if self.summary_writer is not None:
                        self._save_log_into_tb_file(training_step, "training", results)

                # mean training loss of the current epoch
                mean_train_loss = np.mean(epoch_train_loss_collector)

                if val_loader is not None:
                    self.model.eval()
                    epoch_val_loss_collector = []
                    with torch.no_grad():
                        for data in val_loader:
                            inputs = self._assemble_input_for_validating(data)
                            results = self.model.forward(inputs)
                            epoch_val_loss_collector.append(results["loss"].item())

                    mean_val_loss = np.mean(epoch_val_loss_collector)

                    # save validating loss logs into the tensorboard file for every epoch if in need
                    if self.summary_writer is not None:
                        val_loss_dict = {
                            "classification_loss": mean_val_loss,
                        }
                        self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

                    logger.info(
                        f"epoch {epoch}: "
                        f"training loss {mean_train_loss:.4f}, "
                        f"validating loss {mean_val_loss:.4f}"
                    )
                    mean_loss = mean_val_loss
                else:
                    logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
                    mean_loss = mean_train_loss

                if mean_loss < self.best_loss:
                    self.best_loss = mean_loss
                    # state_dict() hands back references to the live parameter tensors, so the snapshot
                    # has to be deep-copied; otherwise later optimizer steps would overwrite the "best" weights
                    self.best_model_dict = copy.deepcopy(self.model.state_dict())
                    self.patience = self.original_patience
                    # save the model if necessary
                    self._auto_save_model_if_necessary(
                        training_finished=False,
                        saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
                    )
                else:
                    self.patience -= 1
                    # the equality check is kept on purpose: a negative patience never hits 0
                    # NOTE(review): presumably that is how early stopping gets disabled -- confirm in BaseNNModel
                    if self.patience == 0:
                        logger.info(
                            "Exceeded the training patience. Terminating the training procedure..."
                        )
                        break

        except Exception as e:
            logger.error(f"Exception: {e}")
            if self.best_model_dict is None:
                # nothing usable was learned, so surface the failure and keep the original cause attached
                raise RuntimeError(
                    "Training got interrupted. Model was not get trained. Please try fit() again."
                ) from e
            else:
                # a best snapshot exists, so only warn; the warning has to be issued explicitly,
                # merely constructing a RuntimeWarning instance has no effect
                warnings.warn(
                    "Training got interrupted. "
                    "Model will load the best parameters so far for testing. "
                    "If you don't want it, please try fit() again.",
                    RuntimeWarning,
                    stacklevel=2,
                )

        # best_loss stays at its initial value if no epoch produced a loss smaller than infinity
        # (e.g. every loss was NaN, or no epoch ran at all)
        if np.equal(self.best_loss, float("inf")):
            raise ValueError("Something is wrong. best_loss is Nan after training.")

        logger.info("Finished training.")