-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
58 lines (43 loc) · 1.73 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
# Training split of MNIST; ToTensor() turns each image into a float tensor
# scaled to [0, 1]. download=True fetches the data into ./data on first run;
# pass download=False instead to require an existing local copy (offline use).
train = datasets.MNIST(root='data', train=True, transform=ToTensor(), download=True)
# Reshuffle every epoch so mini-batches are not drawn in the same fixed order,
# which is the usual setting for a training loader.
dataset = DataLoader(train, batch_size=32, shuffle=True)
class ImageClassifier(nn.Module):
    """Small convolutional classifier: three 3x3 conv + ReLU layers and a linear head.

    The defaults reproduce the original MNIST configuration (1-channel 28x28
    inputs, 10 classes) with the same submodule layout, so ``state_dict``
    checkpoints saved from the original model still load.

    Args:
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        image_size: Input ``(height, width)``, or a single int for square images.

    Raises:
        ValueError: If ``image_size`` is too small to survive the three
            unpadded convolutions.
    """

    def __init__(self, *args, in_channels=1, num_classes=10, image_size=(28, 28), **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        height, width = image_size

        kernel = 3
        # Each unpadded, stride-1 conv trims (kernel - 1) pixels from every
        # spatial dimension, so three of them shrink 28x28 down to 22x22
        # (the original hard-coded "28 - 6").
        shrink = 3 * (kernel - 1)
        out_h, out_w = height - shrink, width - shrink
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f'image_size must be at least {shrink + 1} in each dimension, '
                f'got {image_size}'
            )

        self.model = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=(kernel, kernel)),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(kernel, kernel)),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(kernel, kernel)),
            nn.ReLU(),
            nn.Flatten(),
            # No softmax here: the head emits raw logits, which is what
            # nn.CrossEntropyLoss expects (it applies log-softmax internally).
            nn.Linear(in_features=64 * out_h * out_w, out_features=num_classes),
        )

    def forward(self, x):
        """Return unnormalized class logits of shape ``(batch, num_classes)``.

        ``x`` is expected to be ``(batch, in_channels, height, width)`` with the
        spatial size given at construction; other sizes will not match the
        linear head.
        """
        return self.model(x)
if __name__ == '__main__':
    # Train on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    clf = ImageClassifier().to(device)
    opt = Adam(params=clf.parameters(), lr=1e-3)
    # CrossEntropyLoss takes the model's raw logits plus integer class labels.
    loss_func = nn.CrossEntropyLoss()
    EPOCHS = 7

    clf.train()  # explicit: put every submodule in training mode
    for epoch in range(1, EPOCHS + 1):
        # Accumulate on-device so we don't force a host sync (.item()) per batch.
        loss_sum = torch.zeros((), device=device)
        n_samples = 0
        for batch in dataset:
            x, y = (t.to(device) for t in batch)
            y_pred = clf(x)
            loss = loss_func(y_pred, y)

            # Apply Backpropagation
            opt.zero_grad()
            loss.backward()
            opt.step()

            # The default reduction makes `loss` a per-batch mean; weight it by
            # the batch size so a smaller final batch doesn't skew the average.
            loss_sum += loss.detach() * x.size(0)
            n_samples += x.size(0)

        # Report the mean loss over the whole epoch instead of only the last
        # batch's loss, which is a noisy single-batch estimate.
        epoch_loss = (loss_sum / max(n_samples, 1)).item()
        print(f'Epoch: {epoch}, Loss: {epoch_loss}')

    torch.save(clf.state_dict(), 'mnist_classifier_torch.pth')