-
Notifications
You must be signed in to change notification settings - Fork 14
/
imagenet_dali.py
138 lines (125 loc) · 6.34 KB
/
imagenet_dali.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# https://github.com/NVIDIA/DALI/blob/master/docs/examples/use_cases/pytorch/resnet50/main.py
import argparse
import os
import shutil
import time
import math
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import numpy as np
try:
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
except ImportError:
raise ImportError("Please install DALI from https://www.github.com/NVIDIA/DALI to run this example.")
class HybridTrainPipe(Pipeline):
    """DALI pipeline used for training data.

    The graph reads (file, label) pairs from ``data_dir``, decodes each file
    with a random crop, resizes it to ``crop`` x ``crop``, and finally applies
    a random horizontal mirror together with mean/std normalization on the GPU.
    """

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                 shard_id, num_shards, dali_cpu=False):
        """Instantiate every operator the training graph needs.

        Args:
            batch_size: forwarded positionally to ``Pipeline.__init__``.
            num_threads: forwarded positionally to ``Pipeline.__init__``.
            device_id: forwarded positionally to ``Pipeline.__init__``; it is
                also added to 12 to form the pipeline seed.
            data_dir: handed to the file reader as ``file_root``.
            crop: edge length used both as the resize target
                (``resize_x``/``resize_y``) and as the crop window.
            shard_id: shard index handed to the file reader.
            num_shards: shard count handed to the file reader.
            dali_cpu: when truthy, decoding and resizing use the 'cpu'
                backend; otherwise decoding is 'mixed' and resizing is 'gpu'.
        """
        super(HybridTrainPipe, self).__init__(
            batch_size, num_threads, device_id, seed=12 + device_id)

        reader_options = dict(
            file_root=data_dir,
            shard_id=shard_id,
            num_shards=num_shards,
            shuffle_after_epoch=True,
            pad_last_batch=True,
        )
        self.input = ops.FileReader(**reader_options)

        # Let the user pick whichever pipeline variant works best for the
        # network being trained.
        if dali_cpu:
            resize_backend = 'cpu'
            decode_backend = 'cpu'
            # No nvJPEG buffers to pre-size when decoding on the CPU.
            gpu_padding = 0
            pinned_padding = 0
        else:
            resize_backend = 'gpu'
            decode_backend = 'mixed'
            # These paddings size the internal nvJPEG buffers so that every
            # image of full-sized ImageNet fits without further reallocation.
            gpu_padding = 211025920
            pinned_padding = 140544512

        self.decode = ops.ImageDecoderRandomCrop(
            device=decode_backend,
            output_type=types.RGB,
            device_memory_padding=gpu_padding,
            host_memory_padding=pinned_padding,
            random_aspect_ratio=[0.8, 1.25],
            random_area=[0.1, 1.0],
            num_attempts=100,
        )
        self.res = ops.Resize(
            device=resize_backend,
            resize_x=crop,
            resize_y=crop,
            interp_type=types.INTERP_TRIANGULAR,
        )

        # Per-channel statistics given on a 0-1 scale, multiplied by 255.
        channel_mean = [value * 255 for value in (0.485, 0.456, 0.406)]
        channel_std = [value * 255 for value in (0.229, 0.224, 0.225)]
        # Normalization always runs on the GPU, whatever ``dali_cpu`` says.
        self.cmnp = ops.CropMirrorNormalize(
            device="gpu",
            output_dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(crop, crop),
            mean=channel_mean,
            std=channel_std,
        )
        self.coin = ops.CoinFlip(probability=0.5)
        print('DALI "{0}" variant'.format(resize_backend))

    def define_graph(self):
        """Connect the operators and return ``[images, labels]``."""
        # The coin flip is invoked before the reader, matching the order in
        # which the operators were originally added to the graph.
        mirror_flags = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        decoded = self.decode(self.jpegs)
        resized = self.res(decoded)
        # ``.gpu()`` hands the resized images to the GPU normalization op
        # (needed when the CPU variant is selected).
        normalized = self.cmnp(resized.gpu(), mirror=mirror_flags)
        return [normalized, self.labels]
class HybridValPipe(Pipeline):
    """DALI pipeline used for validation data.

    The graph reads (file, label) pairs from ``data_dir`` without shuffling,
    decodes them with the 'mixed' decoder, resizes with ``resize_shorter``,
    and then crops and normalizes on the GPU.
    """

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                 size, shard_id, num_shards):
        """Instantiate every operator the validation graph needs.

        Args:
            batch_size: forwarded positionally to ``Pipeline.__init__``.
            num_threads: forwarded positionally to ``Pipeline.__init__``.
            device_id: forwarded positionally to ``Pipeline.__init__``; it is
                also added to 12 to form the pipeline seed.
            data_dir: handed to the file reader as ``file_root``.
            crop: edge length of the square crop window.
            size: value handed to the resize operator as ``resize_shorter``.
            shard_id: shard index handed to the file reader.
            num_shards: shard count handed to the file reader.
        """
        super(HybridValPipe, self).__init__(
            batch_size, num_threads, device_id, seed=12 + device_id)

        reader_options = dict(
            file_root=data_dir,
            shard_id=shard_id,
            num_shards=num_shards,
            random_shuffle=False,
            pad_last_batch=True,
        )
        self.input = ops.FileReader(**reader_options)
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.res = ops.Resize(
            device="gpu",
            resize_shorter=size,
            interp_type=types.INTERP_TRIANGULAR,
        )

        # Per-channel statistics given on a 0-1 scale, multiplied by 255.
        channel_mean = [value * 255 for value in (0.485, 0.456, 0.406)]
        channel_std = [value * 255 for value in (0.229, 0.224, 0.225)]
        self.cmnp = ops.CropMirrorNormalize(
            device="gpu",
            output_dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(crop, crop),
            mean=channel_mean,
            std=channel_std,
        )

    def define_graph(self):
        """Connect the operators and return ``[images, labels]``."""
        self.jpegs, self.labels = self.input(name="Reader")
        decoded = self.decode(self.jpegs)
        resized = self.res(decoded)
        normalized = self.cmnp(resized)
        return [normalized, self.labels]
from dataset.imagenet import get_data_folder
def get_dali_data_loader(args):
    """Build the DALI training and validation iterators.

    Attributes read from ``args``: ``dataset``, ``batch_size``,
    ``num_workers``, ``rank``, ``world_size`` and ``dali``.

    The training images are taken from ``<data folder>/train`` and the
    validation images from ``<data folder>/val``, where the data folder is
    whatever ``get_data_folder(args.dataset)`` returns.

    Returns:
        A ``(train_loader, val_loader)`` tuple of
        ``DALIClassificationIterator`` objects.
    """
    crop_size = 224
    val_size = 256

    dataset_root = get_data_folder(args.dataset)
    train_dir = os.path.join(dataset_root, 'train')
    val_dir = os.path.join(dataset_root, 'val')

    # Settings shared by both pipelines.
    # NOTE(review): ``args.rank`` is used both as the GPU index and as the
    # shard index; if ``rank`` is a global rank spanning several nodes it may
    # not be a valid local device id -- confirm against the launcher.
    common = dict(
        batch_size=args.batch_size,
        num_threads=args.num_workers,
        device_id=args.rank,
        crop=crop_size,
        shard_id=args.rank,
        num_shards=args.world_size,
    )
    # Both iterators look up the reader by name and do not fill the last batch.
    iterator_options = dict(reader_name="Reader", fill_last_batch=False)

    use_cpu = args.dali == 'cpu'
    train_pipe = HybridTrainPipe(data_dir=train_dir, dali_cpu=use_cpu, **common)
    train_pipe.build()
    train_loader = DALIClassificationIterator(train_pipe, **iterator_options)

    val_pipe = HybridValPipe(data_dir=val_dir, size=val_size, **common)
    val_pipe.build()
    val_loader = DALIClassificationIterator(val_pipe, **iterator_options)

    return train_loader, val_loader