-
Notifications
You must be signed in to change notification settings - Fork 19
/
imagenet_dali.py
118 lines (106 loc) · 5.05 KB
/
imagenet_dali.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# https://github.com/NVIDIA/DALI/blob/master/docs/examples/use_cases/pytorch/resnet50/main.py
import argparse
import os
import shutil
import time
import math
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import numpy as np
try:
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.types as types
import nvidia.dali.fn as fn
except ImportError:
raise ImportError("Please install DALI from https://www.github.com/NVIDIA/DALI to run this example.")
@pipeline_def
def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=False, is_training=True):
    """Define the DALI graph that reads, decodes, resizes and normalizes images.

    Args:
        data_dir: Root directory handed to the file reader as ``file_root``.
        crop: Side length of the square output; used as the resize target
            when training and as the crop window of the final normalize step.
        size: Resize target for the evaluation branch ("not_smaller" mode).
            Unused when ``is_training`` is true.
        shard_id: Index of the shard this pipeline reads.
        num_shards: Total number of shards the dataset is split into.
        dali_cpu: When truthy, decode and resize on the CPU; otherwise decode
            with the 'mixed' backend and resize on the GPU.
        is_training: Selects the training branch (shuffled reader, random
            crop decode, random mirror) or the evaluation branch (plain
            decode, aspect-preserving resize, no mirror).

    Returns:
        A ``(images, labels)`` pair. Images come out of
        ``crop_mirror_normalize`` as FLOAT data in "CHW" layout; labels are
        moved to the GPU.
    """
    encoded, targets = fn.readers.file(
        file_root=data_dir,
        shard_id=shard_id,
        num_shards=num_shards,
        random_shuffle=is_training,
        pad_last_batch=True,
        name="Reader",
    )

    # Pick the backends and the decoder buffer padding in one place.
    # NOTE(review): the padding constants are presumably sized so the mixed
    # decoder can preallocate its buffers up front -- confirm against DALI docs.
    if dali_cpu:
        resize_backend = 'cpu'
        decode_backend = 'cpu'
        gpu_padding = 0
        cpu_padding = 0
    else:
        resize_backend = 'gpu'
        decode_backend = 'mixed'
        gpu_padding = 211025920
        cpu_padding = 140544512

    if is_training:
        # Decode only a randomly chosen crop, then squash it to crop x crop.
        pixels = fn.decoders.image_random_crop(
            encoded,
            device=decode_backend,
            output_type=types.RGB,
            device_memory_padding=gpu_padding,
            host_memory_padding=cpu_padding,
            random_aspect_ratio=[0.8, 1.25],
            random_area=[0.1, 1.0],
            num_attempts=100,
        )
        pixels = fn.resize(
            pixels,
            device=resize_backend,
            resize_x=crop,
            resize_y=crop,
            interp_type=types.INTERP_TRIANGULAR,
        )
        flip = fn.random.coin_flip(probability=0.5)
    else:
        # Decode the full image and resize in "not_smaller" mode; the crop
        # itself is taken by crop_mirror_normalize below.
        pixels = fn.decoders.image(
            encoded,
            device=decode_backend,
            output_type=types.RGB,
        )
        pixels = fn.resize(
            pixels,
            device=resize_backend,
            size=size,
            mode="not_smaller",
            interp_type=types.INTERP_TRIANGULAR,
        )
        flip = False

    # Per-channel statistics, scaled by 255 as in the original expressions.
    channel_mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    channel_std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    normalized = fn.crop_mirror_normalize(
        pixels.gpu(),
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(crop, crop),
        mean=channel_mean,
        std=channel_std,
        mirror=flip,
    )
    return normalized, targets.gpu()
def _build_dali_loader(args, data_dir, crop_size, val_size, is_training):
    """Build one DALI pipeline over ``data_dir`` and wrap it in an iterator.

    Args:
        args: Namespace providing ``batch_size``, ``num_workers``, ``rank``,
            ``dali`` and ``world_size``.
        data_dir: Directory passed to the pipeline as its data root.
        crop_size: Output crop side length forwarded as ``crop``.
        val_size: Evaluation resize target forwarded as ``size``.
        is_training: Whether to build the training or evaluation graph.

    Returns:
        A ``DALIClassificationIterator`` bound to the reader named "Reader"
        with ``LastBatchPolicy.PARTIAL``.
    """
    # NOTE(review): device_id uses args.rank; on multi-node runs this
    # presumably needs to be the node-local GPU index -- confirm with callers.
    pipe = create_dali_pipeline(batch_size=args.batch_size,
                                num_threads=args.num_workers,
                                device_id=args.rank,
                                seed=12 + args.rank,
                                data_dir=data_dir,
                                crop=crop_size,
                                size=val_size,
                                dali_cpu=args.dali == 'cpu',
                                shard_id=args.rank,
                                num_shards=args.world_size,
                                is_training=is_training)
    pipe.build()
    return DALIClassificationIterator(pipe,
                                      reader_name="Reader",
                                      last_batch_policy=LastBatchPolicy.PARTIAL)


def get_dali_data_loader(args):
    """Create the DALI training and validation loaders for ``args.dataset``.

    The dataset is looked up under ``../data/<args.dataset>`` and must
    contain ``train`` and ``val`` subdirectories. If the dataset directory is
    missing, a hint naming the expected directory is printed and construction
    proceeds regardless.

    Args:
        args: Namespace providing ``dataset``, ``batch_size``,
            ``num_workers``, ``rank``, ``dali`` and ``world_size``.

    Returns:
        A ``(train_loader, val_loader)`` tuple of
        ``DALIClassificationIterator`` objects.
    """
    crop_size = 224
    val_size = 256
    path = '../data'
    data_folder = os.path.join(path, args.dataset)
    if not os.path.isdir(data_folder):
        # Report the directory that was actually checked, not its parent.
        print('Please place the ImageNet dataset at: ', data_folder)

    # Train first, then val, so pipelines are created and built in the same
    # order as before.
    train_loader = _build_dali_loader(args,
                                      os.path.join(data_folder, 'train'),
                                      crop_size,
                                      val_size,
                                      is_training=True)
    val_loader = _build_dali_loader(args,
                                    os.path.join(data_folder, 'val'),
                                    crop_size,
                                    val_size,
                                    is_training=False)
    return train_loader, val_loader