-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_classes.py
65 lines (54 loc) · 2.87 KB
/
test_classes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch.nn as nn
import os
import numpy as np
import torch
import warnings
from dataset_loader import load_a_dataset
from General_model import Basic_CNN, CNN_DeepJDOT, CNNMAD
from MAD import MAD
# Load the train / validation / test splits of two TinyTimeMatch domains:
# the source domain (domain_id=3, Denmark) and the target domain
# (domain_id=4, Austria). Each call returns six arrays in the order
# X_train, y_train, X_valid, y_valid, X_test, y_test.
_TINY_TIME_MATCH = "Dataset/TinyTimeMatch/TinyTimeMatch"
(X_source, y_source,
 X_valid_source, y_valid_source,
 X_test_source, y_test_source) = load_a_dataset(dataset_name=_TINY_TIME_MATCH, domain_id=3)  # 3: Denmark
(X_target, y_target,
 X_valid_target, y_valid_target,
 X_test_target, y_test_target) = load_a_dataset(dataset_name=_TINY_TIME_MATCH, domain_id=4)  # 4: Austria

# Model sizes are derived from the source training split: the number of
# distinct labels, and the size of the last axis of X_source (fed to the
# CNN as its input channel count).
n_features = X_source.shape[-1]
n_classes = len(set(y_source.tolist()))
def make_feature_extractor(n_features):
    """Build a fresh 1-D convolutional feature extractor.

    Three Conv1d -> BatchNorm1d(affine=False) -> ReLU stages with 128, 256 and
    128 output channels and kernel sizes 8, 5 and 3 ("same" padding, stride 1,
    no conv bias). A new ``nn.Sequential`` is returned on every call, so each
    model built from it owns its own parameters.

    Args:
        n_features: number of input channels of the first convolution.

    Returns:
        nn.Sequential whose final stage outputs 128 channels.
    """
    return nn.Sequential(
        nn.Conv1d(in_channels=n_features, out_channels=128, kernel_size=8, stride=1, padding="same", bias=False),
        nn.BatchNorm1d(num_features=128, affine=False),
        nn.ReLU(),
        nn.Conv1d(in_channels=128, out_channels=256, kernel_size=5, stride=1, padding="same", bias=False),
        nn.BatchNorm1d(num_features=256, affine=False),
        nn.ReLU(),
        nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding="same", bias=False),
        nn.BatchNorm1d(num_features=128, affine=False),
        nn.ReLU(),
    )


def make_classifier(n_classes, in_features=128):
    """Build a fresh single-layer linear classification head.

    Args:
        n_classes: number of output logits.
        in_features: input width; defaults to 128, the channel count produced
            by ``make_feature_extractor``. NOTE(review): this presumes CNNMAD
            reduces the time axis before calling the classifier — confirm in
            General_model.

    Returns:
        nn.Sequential wrapping one nn.Linear(in_features, n_classes).
    """
    return nn.Sequential(nn.Linear(in_features, n_classes))


# Module-level handles kept for backward compatibility; they are the modules
# of the first deep model only.
feature_extractor = make_feature_extractor(n_features)
classifier = make_classifier(n_classes)

# Each CNNMAD model gets its own feature extractor / classifier. Passing the
# same module objects to both (as before) would make them share parameters
# unless CNNMAD copies them internally (not visible here). The second model
# would then start from, and keep overwriting, the first model's weights.
model_MAD_Deep = CNNMAD(batchsize=128,
                        feature_extractor=feature_extractor,
                        classifier=classifier,
                        alpha=0.1,
                        beta=0.01,
                        MAD_class=True,  # True -> C-MAD, False -> MAD
                        lr=1e-4,
                        X_target=X_target,
                        max_iterations=2)
# Same configuration, but asking CNNMAD to keep the DTW matrices and OT plan
# of each iteration ("avec sauvegardes" = "with saves").
model_MAD_Deep_avec_sauvegardes = CNNMAD(batchsize=128,
                                         feature_extractor=make_feature_extractor(n_features),
                                         classifier=make_classifier(n_classes),
                                         alpha=0.1,
                                         beta=0.01,
                                         MAD_class=True,  # True -> C-MAD, False -> MAD
                                         lr=1e-4,
                                         X_target=X_target,
                                         save_DTW_matrices=True,
                                         save_OT_plan=True,
                                         max_iterations=2)
# Non-deep MAD baseline with its default settings.
model_MAD = MAD()
# Smoke-test every model: fit on the source domain, predict and evaluate on
# the target-domain test split, then report any per-iteration histories the
# model was asked to keep.
for model in [model_MAD_Deep, model_MAD_Deep_avec_sauvegardes, model_MAD]:
    # fit is called twice, presumably to exercise both call signatures
    # (without and with validation data) — NOTE(review): confirm intent, since
    # the second call may continue from the first call's fitted state.
    model.fit(X_source, y_source)
    model.fit(X_source, y_source, X_valid_source, y_valid_source)
    model.predict(X_test_target)
    print(model.evaluate(X_test_target, y_test_target))
    # MAD() is constructed without the save options, and nothing here shows it
    # defines these attributes. getattr avoids an AttributeError on models that
    # lack them and treats them as "not saving".
    if getattr(model, "save_DTW_matrices", False):
        print(len(model.history_DTW_matrices), model.history_DTW_matrices[0][0].shape)
    if getattr(model, "save_OT_plan", False):
        print(len(model.history_OT_plan), model.history_OT_plan[0].shape)