# https://github.com/NVIDIA/DALI/blob/master/docs/examples/use_cases/pytorch/resnet50/main.py
import argparse
import os
import shutil
import time
import math

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import numpy as np

try:
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator
    from nvidia.dali.pipeline import Pipeline
    import nvidia.dali.ops as ops
    import nvidia.dali.types as types
except ImportError:
    raise ImportError("Please install DALI from https://www.github.com/NVIDIA/DALI to run this example.")

from dataset.imagenet import get_data_folder


class HybridTrainPipe(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                 shard_id, num_shards, dali_cpu=False):
        super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id,
                                              seed=12 + device_id)
        self.input = ops.FileReader(file_root=data_dir,
                                    shard_id=shard_id,
                                    num_shards=num_shards,
                                    shuffle_after_epoch=True,
                                    pad_last_batch=True)
        # Let the user decide which pipeline works best for the ResNet version they run.
        dali_device = 'cpu' if dali_cpu else 'gpu'
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        # This padding sets the size of the internal nvJPEG buffers so they can handle
        # all images from full-sized ImageNet without additional reallocations.
        device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
        host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
        self.decode = ops.ImageDecoderRandomCrop(device=decoder_device,
                                                 output_type=types.RGB,
                                                 device_memory_padding=device_memory_padding,
                                                 host_memory_padding=host_memory_padding,
                                                 random_aspect_ratio=[0.8, 1.25],
                                                 random_area=[0.1, 1.0],
                                                 num_attempts=100)
        self.res = ops.Resize(device=dali_device,
                              resize_x=crop,
                              resize_y=crop,
                              interp_type=types.INTERP_TRIANGULAR)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
        self.coin = ops.CoinFlip(probability=0.5)
        print('DALI "{0}" variant'.format(dali_device))

    def define_graph(self):
        rng = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images.gpu(), mirror=rng)
        return [output, self.labels]


class HybridValPipe(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                 size, shard_id, num_shards):
        super(HybridValPipe, self).__init__(batch_size, num_threads, device_id,
                                            seed=12 + device_id)
        self.input = ops.FileReader(file_root=data_dir,
                                    shard_id=shard_id,
                                    num_shards=num_shards,
                                    random_shuffle=False,
                                    pad_last_batch=True)
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.res = ops.Resize(device="gpu",
                              resize_shorter=size,
                              interp_type=types.INTERP_TRIANGULAR)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

    def define_graph(self):
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images)
        return [output, self.labels]


def get_dali_data_loader(args):
    crop_size = 224
    val_size = 256

    data_folder = get_data_folder(args.dataset)
    train_folder = os.path.join(data_folder, 'train')
    val_folder = os.path.join(data_folder, 'val')

    pipe = HybridTrainPipe(batch_size=args.batch_size,
                           num_threads=args.num_workers,
                           device_id=args.rank,
                           data_dir=train_folder,
                           crop=crop_size,
                           dali_cpu=args.dali == 'cpu',
                           shard_id=args.rank,
                           num_shards=args.world_size)
    pipe.build()
    train_loader = DALIClassificationIterator(pipe, reader_name="Reader", fill_last_batch=False)

    pipe = HybridValPipe(batch_size=args.batch_size,
                         num_threads=args.num_workers,
                         device_id=args.rank,
                         data_dir=val_folder,
                         crop=crop_size,
                         size=val_size,
                         shard_id=args.rank,
                         num_shards=args.world_size)
    pipe.build()
    val_loader = DALIClassificationIterator(pipe, reader_name="Reader", fill_last_batch=False)

    return train_loader, val_loader
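

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream example). It fabricates the
# argparse namespace that get_dali_data_loader() reads, assuming a
# single-process, single-GPU run (rank 0 of a world size of 1). The attribute
# names mirror the ones accessed above; the values are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    args = argparse.Namespace(dataset='imagenet',  # resolved by get_data_folder()
                              batch_size=256,
                              num_workers=4,
                              rank=0,              # serves as device_id and shard_id here
                              world_size=1,        # total number of dataset shards
                              dali='gpu')          # 'cpu' selects the CPU decoding path
    train_loader, val_loader = get_dali_data_loader(args)

    # DALIClassificationIterator yields a list with one dict per pipeline:
    # "data" is an NCHW float tensor already normalized on the GPU, and
    # "label" arrives with a trailing unit dimension.
    for i, data in enumerate(train_loader):
        images = data[0]["data"]
        labels = data[0]["label"].squeeze(-1).long()
        print(images.shape, labels.shape)
        if i == 2:
            break
    train_loader.reset()  # DALI iterators must be reset at the end of each epoch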