-
Notifications
You must be signed in to change notification settings - Fork 260
/
__init__.py
162 lines (141 loc) · 5.39 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from argparse import Namespace
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torchvision import models
from typing import Tuple
from .moco.builder import MoCo
from .main_moco import adjust_learning_rate
from ...util.model import BenchmarkModel
from torchbenchmark.tasks import OTHER
# Global cuDNN configuration for this benchmark: allow non-deterministic
# kernels and enable cuDNN benchmark mode, trading bit-for-bit
# reproducibility for throughput.
cudnn.deterministic, cudnn.benchmark = False, True
class Model(BenchmarkModel):
    """TorchBench wrapper that trains and evaluates MoCo on synthetic images.

    The backbone is selected by ``opt.arch`` (``"resnet50"``) from
    ``torchvision.models``, wrapped in ``MoCo`` and then in
    ``DistributedDataParallel`` over a single-process NCCL group. Inputs are
    four random ``(batch_size, 3, 224, 224)`` tensors created once in
    ``__init__`` and served as two (im_q, im_k) pairs by ``example_inputs``.
    """

    task = OTHER.OTHER_TASKS

    # Original train batch size: 32
    # Paper and code uses batch size of 256 for 8 GPUs.
    # Source: https://arxiv.org/pdf/1911.05722.pdf
    DEFAULT_TRAIN_BSIZE = 32
    DEFAULT_EVAL_BSIZE = 32

    def __init__(self, test, device, batch_size=None, extra_args=None):
        """Build the model, loss, optimizer and synthetic input pipeline.

        Args:
            test: Forwarded unchanged to ``BenchmarkModel.__init__``.
            device: Target device. ``"cpu"`` is rejected.
            batch_size: Forwarded to ``BenchmarkModel.__init__``. Presumably
                the base class resolves it into ``self.batch_size``, which
                sizes the synthetic batches (verify against the base class).
            extra_args: Forwarded to ``BenchmarkModel.__init__``. ``None``
                (the default) is passed on as an empty list.

        Raises:
            NotImplementedError: If ``device == "cpu"``. The NCCL process
                group and DistributedDataParallel used here need CUDA.
        """
        super().__init__(
            test=test,
            device=device,
            batch_size=batch_size,
            extra_args=[] if extra_args is None else extra_args,
        )
        self.opt = Namespace(
            arch="resnet50",
            epochs=2,
            start_epoch=0,
            lr=0.03,
            schedule=[120, 160],
            momentum=0.9,
            weight_decay=1e-4,
            gpu=None,
            moco_dim=128,
            moco_k=32000,
            moco_m=0.999,
            moco_t=0.07,
            mlp=False,
            aug_plus=False,
            cos=False,
            fake_data=True,
            distributed=True,
        )

        # Fail fast, before attempting NCCL initialization, which cannot work
        # on a CPU-only setup anyway.
        if device == "cpu":
            raise NotImplementedError("DistributedDataParallel/allgather requires cuda")

        # Single-process "distributed" group so the model can be wrapped in
        # DDP. Reuse an already-initialized default group (e.g. one created by
        # an earlier instance) instead of re-initializing it.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="nccl",
                init_method="tcp://localhost:10001",
                world_size=1,
                rank=0,
            )

        self.model = MoCo(
            models.__dict__[self.opt.arch],
            self.opt.moco_dim,
            self.opt.moco_k,
            self.opt.moco_m,
            self.opt.moco_t,
            self.opt.mlp,
        )
        self.model.to(self.device)
        # NOTE(review): device_ids is hard-coded to GPU 0; this assumes
        # self.device resolves to cuda:0 — TODO confirm.
        self.model = torch.nn.parallel.DistributedDataParallel(
            self.model, device_ids=[0]
        )

        # Define loss function (criterion) and optimizer
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            self.opt.lr,
            momentum=self.opt.momentum,
            weight_decay=self.opt.weight_decay,
        )

        # Four synthetic image batches, created (and moved to self.device)
        # once here so train()/eval() do no tensor creation or host-to-device
        # copies.
        batches = [
            torch.randn(self.batch_size, 3, 224, 224).to(self.device)
            for _ in range(4)
        ]

        def collate_train_fn(data):
            # With the DataLoader's default batch_size=1, ``data`` is a
            # one-element list holding an index k from range(2). Item k is
            # the pair (batches[2k], batches[2k + 1]) plus a dummy target 0
            # that callers ignore.
            ind = data[0]
            return [batches[2 * ind], batches[2 * ind + 1]], 0

        # Each pass yields two ([im_q, im_k], 0) items. The tensors already
        # live on self.device, so no per-iteration .cuda() transfer is needed.
        self.example_inputs = torch.utils.data.DataLoader(
            range(2), collate_fn=collate_train_fn
        )

    def get_module(self):
        """Recommended
        Returns model, example_inputs
        Both model and example_inputs should be on self.device properly.
        `model(*example_inputs)` should execute one step of model forward.

        ``example_inputs`` is the (im_q, im_k) tensor pair taken from the last
        item yielded by ``self.example_inputs``.
        """
        # Star-unpacking exhausts the loader and keeps only its last item.
        *_, (last_pair, _) = self.example_inputs
        return self.model, (last_pair[0], last_pair[1])

    def get_optimizer(self):
        """Returns the current optimizer"""
        return self.optimizer

    def set_optimizer(self, optimizer) -> None:
        """Sets the optimizer for future training"""
        self.optimizer = optimizer

    def train(self):
        """Recommended
        Runs training on model for one epoch.
        Avoid unnecessary benchmark noise by keeping any tensor creation, memcopy operations in __init__.
        Leave warmup to the caller (e.g. don't do it inside)
        """
        self.model.train()
        n_epochs = 1
        for epoch in range(n_epochs):
            adjust_learning_rate(self.optimizer, epoch, self.opt)
            for images, _ in self.example_inputs:
                # compute output: the model returns the logits together with
                # the targets fed to the cross-entropy criterion.
                output, target = self.model(im_q=images[0], im_k=images[1])
                loss = self.criterion(output, target)
                # compute gradient and do SGD step
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def eval(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Recommended
        Run evaluation on model for one iteration. One iteration should be sufficient
        to warm up the model for the purpose of profiling.
        In most cases this can use the `get_module` API but in some cases libraries
        do not have a single Module object used for inference. In these case, you can
        write a custom eval function.
        Avoid unnecessary benchmark noise by keeping any tensor creation, memcopy operations in __init__.
        Leave warmup to the caller (e.g. don't do it inside)

        Returns the (output, target) pair produced by the forward pass on the
        last synthetic batch.

        NOTE(review): this neither switches the model to eval mode nor wraps
        the forward in ``torch.no_grad()``; confirm that is intended.
        """
        out = None
        for images, _ in self.example_inputs:
            out = self.model(im_q=images[0], im_k=images[1])
        return out