"""ImageNet training on Cloud TPU via torch_xla, streaming WebDataset shards
from GCS.

Each XLA process (one per TPU core) builds its own WebDataset pipeline: shards
are partitioned across processes by ``my_node_splitter`` and across DataLoader
workers by ``wds.split_by_worker``.  Batching happens inside webdataset, so the
DataLoader is created with ``batch_size=None``.
"""

import os
import sys
from itertools import islice

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import webdataset as wds

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import torch_xla.utils.utils as xu

# Make the pytorch/xla test helpers (schedulers, args_parse) importable when
# this script runs outside the source tree.
for extra in ('/usr/share/torch-xla-1.7/pytorch/xla/test', '/pytorch/xla/test'):
  if os.path.exists(extra):
    sys.path.insert(0, extra)

import schedulers
import args_parse

SUPPORTED_MODELS = [
    'alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
    'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34',
    'resnet50', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13',
    'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn'
]

MODEL_OPTS = {
    '--model': {
        'choices': SUPPORTED_MODELS,
        'default': 'resnet50',
    },
    '--test_set_batch_size': {
        'type': int,
    },
    '--lr_scheduler_type': {
        'type': str,
    },
    '--lr_scheduler_divide_every_n_epochs': {
        'type': int,
    },
    '--lr_scheduler_divisor': {
        'type': int,
    },
    '--dataset': {
        'choices': ['gcsdataset', 'torchdataset'],
        'default': 'gcsdataset',
        'type': str,
    },
}

FLAGS = args_parse.parse_common_options(
    datadir='/tmp/imagenet',
    batch_size=None,
    num_epochs=None,
    momentum=None,
    lr=None,
    target_accuracy=None,
    opts=MODEL_OPTS.items(),
)

DEFAULT_KWARGS = dict(
    batch_size=128,
    test_set_batch_size=64,
    num_epochs=18,
    momentum=0.9,
    lr=0.1,
    target_accuracy=0.0,
)

MODEL_SPECIFIC_DEFAULTS = {
    # Override some of the args in DEFAULT_KWARGS, or add them to the dict
    # if they don't exist.
    'resnet50':
        dict(
            DEFAULT_KWARGS, **{
                'lr': 0.5,
                'lr_scheduler_divide_every_n_epochs': 20,
                'lr_scheduler_divisor': 5,
                'lr_scheduler_type': 'WarmupAndExponentialDecayScheduler',
            })
}

# Set any args that were not explicitly given by the user.
default_value_dict = MODEL_SPECIFIC_DEFAULTS.get(FLAGS.model, DEFAULT_KWARGS)
for arg, value in default_value_dict.items():
  if getattr(FLAGS, arg) is None:
    setattr(FLAGS, arg, value)


def get_model_property(key):
  """Return the per-model property ('img_dim' or 'model_fn') for FLAGS.model.

  inception_v3 needs a larger crop (299) and aux_logits disabled; every other
  supported model uses the torchvision constructor of the same name at 224.
  """
  default_model_property = {
      'img_dim': 224,
      'model_fn': getattr(torchvision.models, FLAGS.model)
  }
  model_properties = {
      'inception_v3': {
          'img_dim': 299,
          'model_fn': lambda: torchvision.models.inception_v3(aux_logits=False)
      },
  }
  return model_properties.get(FLAGS.model, default_model_property)[key]


def _train_update(device, step, loss, tracker, epoch, writer):
  """Step closure: print a periodic training-progress line for this replica."""
  test_utils.print_training_update(
      device,
      step,
      loss.item(),
      tracker.rate(),
      tracker.global_rate(),
      epoch,
      summary_writer=writer)


# ImageNet-1k cardinalities.  WebDataset streams carry no intrinsic length, so
# per-instance epoch sizes are derived from these constants.
trainsize = 1281167
testsize = 50000


def identity(x):
  """No-op transform applied to the label element of each sample tuple."""
  return x


def repeatedly(loader, nepochs=999999999, nbatches=999999999999):
  """Repeatedly returns batches from a DataLoader.

  Yields an effectively infinite stream of batches so the caller can take
  exactly the number it wants with ``islice``.
  """
  for epoch in range(nepochs):
    for sample in islice(loader, nbatches):
      yield sample


def my_node_splitter(urls):
  """Split shard urls correctly per accelerator node.

  :param urls: full list of shard urls
  :return: the slice of urls this XLA replica should read
  """
  rank = xm.get_ordinal()
  num_replicas = xm.xrt_world_size()
  return urls[rank::num_replicas]


normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


def make_train_loader(img_dim, shuffle=10000, batch_size=FLAGS.batch_size):
  """Build the training DataLoader over the GCS WebDataset train shards.

  Args:
    img_dim: crop size for RandomResizedCrop.
    shuffle: webdataset in-memory sample shuffle buffer size.
    batch_size: per-replica batch size (default bound at import time).

  Returns:
    A DataLoader yielding pre-batched (image_batch, label_batch) tuples.
  """
  num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
  epoch_size = trainsize // num_dataset_instances
  num_batches = epoch_size // batch_size

  image_transform = transforms.Compose([
      transforms.RandomResizedCrop(img_dim),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      normalize,
  ])

  dataset = (
      wds.WebDataset(
          "pipe:gsutil cat gs://tpu-demo-eu-west/imagenet-wds/wds-data/shards/imagenet-train-{000000..001281}.tar",
          splitter=wds.split_by_worker,
          nodesplitter=my_node_splitter,
          shardshuffle=True,
          length=num_batches)
      .shuffle(shuffle)
      .decode("pil")
      .to_tuple("ppm;jpg;jpeg;png", "cls")
      .map_tuple(image_transform, identity)
      .batched(batch_size))

  # Batching is done by webdataset, so the DataLoader only parallelizes I/O.
  loader = torch.utils.data.DataLoader(
      dataset, batch_size=None, shuffle=False, num_workers=FLAGS.num_workers)
  return loader


def make_val_loader(img_dim, resize_dim, batch_size=FLAGS.test_set_batch_size):
  """Build the validation DataLoader over the GCS WebDataset val shards.

  Args:
    img_dim: center-crop size.
    resize_dim: resize applied before the center crop.
    batch_size: per-replica eval batch size (default bound at import time).

  Returns:
    A DataLoader yielding pre-batched (image_batch, label_batch) tuples.
  """
  num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
  epoch_test_size = testsize // num_dataset_instances
  num_test_batches = epoch_test_size // batch_size

  val_transform = transforms.Compose([
      transforms.Resize(resize_dim),
      transforms.CenterCrop(img_dim),
      transforms.ToTensor(),
      normalize,
  ])

  val_dataset = (
      wds.WebDataset(
          "pipe:gsutil cat gs://tpu-demo-eu-west/imagenet-wds/wds-data/shards/imagenet-val-{000000..000049}.tar",
          # BUGFIX: these keywords were previously scrambled
          # (splitter=None, nodesplitter=wds.split_by_worker,
          # shardshuffle=my_node_splitter), so validation shards were split
          # per-DataLoader-worker across nodes and shard shuffling was
          # enabled (a function is truthy).  Mirror the train loader's
          # wiring, but keep eval shard order deterministic.
          splitter=wds.split_by_worker,
          nodesplitter=my_node_splitter,
          shardshuffle=False,
          length=num_test_batches)
      .decode("pil")
      .to_tuple("ppm;jpg;jpeg;png", "cls")
      .map_tuple(val_transform, identity)
      .batched(batch_size))

  val_loader = torch.utils.data.DataLoader(
      val_dataset,
      batch_size=None,
      shuffle=False,
      num_workers=FLAGS.num_workers,
      pin_memory=False)
  return val_loader


def train_imagenet():
  """Train FLAGS.model on ImageNet and return the best reduced test accuracy."""
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  resize_dim = max(img_dim, 256)
  train_loader = make_train_loader(
      img_dim, batch_size=FLAGS.batch_size, shuffle=10000)
  test_loader = make_val_loader(
      img_dim, resize_dim, batch_size=FLAGS.test_set_batch_size)

  torch.manual_seed(42)

  device = xm.xla_device()
  model = get_model_property('model_fn')().to(device)
  # Only the master ordinal writes TensorBoard summaries.
  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(FLAGS.logdir)
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr,
      momentum=FLAGS.momentum,
      weight_decay=1e-4)
  num_training_steps_per_epoch = trainsize // (
      FLAGS.batch_size * xm.xrt_world_size())
  lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
      optimizer,
      scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
      scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
      scheduler_divide_every_n_epochs=getattr(
          FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
      num_steps_per_epoch=num_training_steps_per_epoch,
      summary_writer=writer)
  loss_fn = nn.CrossEntropyLoss()

  def train_loop_fn(loader, epoch):
    # Cap the epoch at num_batches steps: the webdataset stream has no hard
    # end, so wrap it with repeatedly() and slice off exactly one epoch.
    num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
    epoch_size = trainsize // num_dataset_instances
    num_batches = epoch_size // FLAGS.batch_size

    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(
        islice(repeatedly(loader), 0, num_batches)):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if lr_scheduler:
        lr_scheduler.step()
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            _train_update, args=(device, step, loss, tracker, epoch, writer))

  def test_loop_fn(loader, epoch):
    num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
    epoch_test_size = testsize // num_dataset_instances
    num_test_batches = epoch_test_size // FLAGS.test_set_batch_size

    total_samples, correct = 0, 0
    model.eval()
    for step, (data, target) in enumerate(
        islice(loader, 0, num_test_batches)):
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            test_utils.print_test_update, args=(device, None, epoch, step))
    correct_val = correct.item()  # pytype: disable=attribute-error
    accuracy_replica = 100.0 * correct_val / total_samples
    # Average the per-replica accuracies across the mesh.
    accuracy = xm.mesh_reduce('test_accuracy', accuracy_replica, np.mean)
    return accuracy, accuracy_replica

  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train_loop_fn(train_device_loader, epoch)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
    accuracy, accuracy_replica = test_loop_fn(test_device_loader, epoch)
    xm.master_print(
        'Epoch {} test end {}, Reduced Accuracy={:.2f}, Replica Accuracy={:.2f}'
        .format(epoch, test_utils.now(), accuracy, accuracy_replica))
    max_accuracy = max(accuracy, max_accuracy)
    test_utils.write_to_summary(
        writer,
        epoch,
        dict_to_write={'Accuracy/test': accuracy},
        write_xla_metrics=True)
    if FLAGS.metrics_debug:
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy


def _mp_fn(index, flags):
  """Per-process entry point spawned by xmp.spawn."""
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy = train_imagenet()
  if accuracy < FLAGS.target_accuracy:
    print('Accuracy {} is below target {}'.format(accuracy,
                                                  FLAGS.target_accuracy))
    sys.exit(21)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)