-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
107 lines (88 loc) · 3.12 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from pathlib import Path
import torch
import torch.cuda
import torch.optim
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import MNIST
from ignite.engine.events import Events
import numpy as np
from vae_trainer import VAETrainer
from model import FullyConnectedVAE, CNNVAE
from evaluation import VAEEvaluator
from handlers import EvaluationRunner, ModelSaver, StateLogger, Plotter, LogPrinter
def parse_args(argv=None):
    """Parse the command-line options and run training with them.

    Despite its name this is the script's driver: once the options are
    parsed it calls ``main(args)``, which performs the whole training run.

    Args:
        argv: Argument strings to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour. Passing a list makes the CLI scriptable and
            testable.

    Returns:
        The parsed ``argparse.Namespace``, returned after ``main`` finishes.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-e", type=int, default=20, help="number of training epochs")
    parser.add_argument("-b", type=int, default=64, help="batch size")
    parser.add_argument(
        "--zdim",
        type=int,
        default=20,
        help="number of dimensions of latent space")
    parser.add_argument(
        "-m",
        required=True,
        choices=["fc", "cnn"],
        help="model architecture")
    parser.add_argument(
        "-o",
        default="outputs",
        help="output directory")
    args = parser.parse_args(argv)
    main(args)
    return args
def main(args):
if torch.cuda.is_available():
device = torch.device("cuda")
print("GPU mode")
else:
device = torch.device("cpu")
print("CPU mode")
Path(args.o).mkdir(parents=True, exist_ok=True)
mnist_transform = lambda x: np.asarray(
x, dtype=np.float32).reshape(1, 28, 28) / 255
train_dataset = MNIST(
root=".",
download=True,
train=True,
transform=mnist_transform)
test_dataset = MNIST(
root=".",
download=True,
train=False,
transform=mnist_transform)
train_loader = DataLoader(train_dataset, args.b, shuffle=True)
test_loader = DataLoader(test_dataset, args.b)
match args.m:
case "fc":
net = FullyConnectedVAE(28 * 28, args.zdim).to(device)
case "cnn":
net = CNNVAE(1, args.zdim).to(device)
case _:
raise ValueError(f"Invalid model: {args.m!r}")
opt = torch.optim.Adam(net.parameters())
trainer = VAETrainer(net, opt, device)
evaluator = VAEEvaluator(net, device)
train_logger = StateLogger(trainer)
test_logger = StateLogger(evaluator)
trainer.add_event_handler(Events.EPOCH_COMPLETED, train_logger)
trainer.add_event_handler(
Events.EPOCH_COMPLETED,
EvaluationRunner(evaluator, test_loader))
metric_keys = ("kl_div", "recon_loss")
trainer.add_event_handler(Events.EPOCH_COMPLETED, test_logger)
trainer.add_event_handler(Events.EPOCH_COMPLETED, Plotter(
train_logger, metric_keys, Path(args.o, "train_loss.pdf")))
trainer.add_event_handler(Events.EPOCH_COMPLETED, Plotter(
test_logger, metric_keys, Path(args.o, "test_loss.pdf")))
trainer.add_event_handler(
Events.EPOCH_COMPLETED, LogPrinter(train_logger, metric_keys))
trainer.add_event_handler(
Events.EPOCH_COMPLETED, LogPrinter(test_logger, metric_keys))
trainer.add_event_handler(
Events.COMPLETED, ModelSaver(net, Path(args.o, "model.pt")))
trainer.run(train_loader, max_epochs=args.e)
if __name__ == "__main__":
    # Script entry point only; importing this module does not start training.
    # parse_args() reads the command line and then hands off to main().
    parse_args()