-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
94 lines (77 loc) · 2.86 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import numpy as np
from torch import optim, nn
from torch.utils.data import DataLoader, Dataset
from learnware.utils import choose_device
@torch.no_grad()
def evaluate(model, evaluate_set: Dataset, device=None, distribution=True, batch_size=1024):
    """Evaluate a classifier on ``evaluate_set``.

    Args:
        model: An ``nn.Module`` (invoked as ``model(X)``) or any object with a
            ``predict(X)`` method. Outputs that are not tensors (e.g. NumPy
            arrays or lists) are converted with ``torch.as_tensor``.
        evaluate_set: Dataset yielding ``(X, y)`` pairs.
        device: Device the batches are moved to; ``choose_device(0)`` when None.
        distribution: When True, model outputs are treated as per-class
            scores: the summed ``CrossEntropyLoss`` is accumulated and the
            predicted label is the arg-max along dim 1. When False, outputs
            are compared directly with ``y`` as labels and the returned loss
            is 0.0.
        batch_size: Number of samples per evaluation batch.

    Returns:
        Tuple ``(loss, acc)``: the mean per-sample loss as a Python float and
        the accuracy as a percentage.

    Raises:
        ValueError: If ``evaluate_set`` yields no samples.
    """
    device = choose_device(0) if device is None else device
    is_module = isinstance(model, nn.Module)
    if is_module:
        # Remember the caller's mode so it can be restored afterwards rather
        # than unconditionally forcing the model into training mode.
        was_training = model.training
        model.eval()
        predict = model
    else:
        predict = model.predict

    criterion = nn.CrossEntropyLoss(reduction="sum")
    total, correct = 0, 0
    loss = torch.as_tensor(0.0, dtype=torch.float32, device=device)
    # Aggregate metrics do not depend on sample order; not shuffling keeps
    # the result deterministic.
    dataloader = DataLoader(evaluate_set, batch_size=batch_size, shuffle=False)
    try:
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            out = predict(X)
            if not torch.is_tensor(out):
                out = torch.as_tensor(out, device=device)
            if distribution:
                loss += criterion(out, y)
                predicted = out.argmax(dim=1)
            else:
                predicted = out
            total += y.size(0)
            correct += (predicted == y).sum().item()
    finally:
        # Restore the caller's mode even if evaluation raised.
        if is_module:
            model.train(was_training)

    if total == 0:
        raise ValueError("evaluate_set is empty; cannot compute loss/accuracy")
    acc = correct / total * 100
    return (loss / total).item(), acc
def train_model(
    model: nn.Module,
    train_set: Dataset,
    valid_set: Dataset,
    save_path: str,
    epochs=35,
    batch_size=128,
    device=None,
    verbose=True,
    lr=1e-2,
    momentum=0.9,
):
    """Train ``model`` with SGD and cross-entropy, checkpointing on validation loss.

    After every epoch the model is evaluated on ``valid_set``. Whenever the
    validation loss improves, ``model.state_dict()`` is written to
    ``save_path``. Training stops early once validation accuracy exceeds 99%.

    Args:
        model: Classifier to train; it is moved to ``device`` in place.
        train_set: Dataset yielding ``(X, y)`` training pairs.
        valid_set: Dataset yielding ``(X, y)`` validation pairs.
        save_path: File path the best ``state_dict`` is saved to.
        epochs: Maximum number of passes over ``train_set``.
        batch_size: Training mini-batch size.
        device: Training device; ``choose_device(0)`` when None.
        verbose: Print progress messages.
        lr: SGD learning rate.
        momentum: SGD momentum.
    """
    device = choose_device(0) if device is None else device
    # Batches are moved to ``device`` below, so the parameters must live
    # there too; do it before building the optimizer.
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    best_loss = float("inf")
    for epoch in range(epochs):
        running_loss = []
        # Re-enter training mode each epoch; evaluation may change the mode.
        model.train()
        for X, y in dataloader:
            X, y = X.to(device=device), y.to(device=device)
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            optimizer.step()
            running_loss.append(loss.item())

        valid_loss, valid_acc = evaluate(model, valid_set, device=device)
        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), save_path)
            if verbose:
                print("Epoch: {}, Valid Best Accuracy: {:.3f}% ({:.3f})".format(epoch + 1, valid_acc, valid_loss))

        if valid_acc > 99.0:
            if verbose:
                print("Early Stopping at 99% !")
            break

        if verbose and (epoch + 1) % 5 == 0:
            # A full pass over the training set is expensive; only do it when
            # the result is actually reported.
            _, train_acc = evaluate(model, train_set, device=device)
            print(
                "Epoch: {}, Train Average Loss: {:.3f}, Accuracy {:.3f}%, Valid Average Loss: {:.3f}".format(
                    epoch + 1, np.mean(running_loss), train_acc, valid_loss
                )
            )