distributed save&load, imagenette, new kd.sh, fix validation error
triomino committed Jul 15, 2020
1 parent ed14d11 commit 4db1cd4
Showing 6 changed files with 74 additions and 41 deletions.
28 changes: 28 additions & 0 deletions dataset/base.py
@@ -0,0 +1,28 @@
# https://github.com/tanglang96/DataLoaders_DALI/blob/master/base.py

from nvidia.dali.plugin.pytorch import DALIGenericIterator

class DALIDataloader(DALIGenericIterator):
    def __init__(self, pipeline, size, batch_size, output_map=["data", "label"], auto_reset=True, onehot_label=False):
        self.size = size
        self.batch_size = batch_size
        self.onehot_label = onehot_label
        self.output_map = output_map
        super().__init__(pipelines=pipeline, size=size, auto_reset=auto_reset, output_map=output_map)

    def __next__(self):
        # DALIGenericIterator prefetches the first batch; hand it out once
        # before delegating to the parent iterator.
        if self._first_batch is not None:
            batch = self._first_batch
            self._first_batch = None
            return batch
        data = super().__next__()[0]
        if self.onehot_label:
            return [data[self.output_map[0]], data[self.output_map[1]].squeeze().long()]
        else:
            return [data[self.output_map[0]], data[self.output_map[1]]]

    def __len__(self):
        # Number of batches per epoch, rounding up for a partial final batch.
        if self.size % self.batch_size == 0:
            return self.size // self.batch_size
        else:
            return self.size // self.batch_size + 1
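For context, a minimal usage sketch of this wrapper (not part of the commit; HybridTrainPipe, its constructor arguments, and the 'Reader' name are illustrative stand-ins for the DALI pipelines defined in dataset/imagenet_dali.py):

# Hypothetical usage; the pipeline class and its arguments are assumptions.
from dataset.base import DALIDataloader
from dataset.imagenet_dali import HybridTrainPipe  # assumed pipeline class

pipe = HybridTrainPipe(batch_size=256, num_threads=4, device_id=0,
                       data_dir='./data/imagenet/train', crop=224)
pipe.build()
train_loader = DALIDataloader(pipeline=pipe,
                              size=pipe.epoch_size('Reader'),  # samples per epoch
                              batch_size=256)
for images, labels in train_loader:  # each batch is a [data, label] pair
    pass  # training step goes here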
17 changes: 9 additions & 8 deletions dataset/imagenet.py
@@ -12,11 +12,12 @@

 from dataset.folder2lmdb import ImageFolderLMDB

-def get_data_folder():
+imagenet_list = ['imagenet', 'imagenette']
+def get_data_folder(dataset='imagenet'):
     """
     return the path to store the data
     """
-    data_folder = './data/imagenet'
+    data_folder = os.path.join('./data', dataset)

     if not os.path.isdir(data_folder):
         os.makedirs(data_folder)
@@ -107,8 +108,8 @@ def __getitem__(self, index):
 def get_test_loader(dataset='imagenet', batch_size=128, num_workers=8):
     """get the test data loader"""

-    if dataset == 'imagenet':
-        data_folder = get_data_folder()
+    if dataset in imagenet_list:
+        data_folder = get_data_folder(dataset)
     else:
         raise NotImplementedError('dataset not supported: {}'.format(dataset))

@@ -135,8 +136,8 @@ def get_test_loader(dataset='imagenet', batch_size=128, num_workers=8):
 def get_dataloader_sample(dataset='imagenet', batch_size=128, num_workers=8, is_sample=False, k=4096):
     """Data Loader for ImageNet"""

-    if dataset == 'imagenet':
-        data_folder = get_data_folder()
+    if dataset in imagenet_list:
+        data_folder = get_data_folder(dataset)
     else:
         raise NotImplementedError('dataset not supported: {}'.format(dataset))

@@ -182,8 +183,8 @@ def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16,
"""
Data Loader for imagenet
"""
if dataset == 'imagenet':
data_folder = get_data_folder()
if dataset in imagenet_list:
data_folder = get_data_folder(dataset)
else:
raise NotImplementedError('dataset not supported: {}'.format(dataset))

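A usage sketch of the reworked loader (not from the commit; the keyword arguments follow the signatures visible above, and a ./data/imagenette/{train,val} folder layout is assumed):

# Hypothetical call: 'imagenette' now routes to ./data/imagenette
# via get_data_folder(dataset).
from dataset.imagenet import get_imagenet_dataloader

train_loader, val_loader, train_sampler = get_imagenet_dataloader(
    dataset='imagenette', batch_size=128, num_workers=8,
    multiprocessing_distributed=False)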
2 changes: 1 addition & 1 deletion dataset/imagenet_dali.py
@@ -109,7 +109,7 @@ def get_dali_data_loader(args):
     crop_size = 224
     val_size = 256

-    data_folder = get_data_folder()
+    data_folder = get_data_folder(args.dataset)
     train_folder = os.path.join(data_folder, 'train')
     val_folder = os.path.join(data_folder, 'val')

4 changes: 2 additions & 2 deletions kd.sh
@@ -1,4 +1,4 @@
 python train_student.py --path-t ./save/models/ResNet34_vanilla/resnet34_transformed.pth \
-    --batch_size 256 --epochs 90 --dataset imagenet --gpu_id 4,5,6,7 --dist-url tcp://127.0.0.1:23333 \
-    --print-freq 100 --num_workers 16 --distill kd --model_s ResNet18 -r 1 -a 1 -b 0 --trial 0 \
+    --batch_size 256 --epochs 90 --dataset imagenet --gpu_id 0,1,2,3,4,5,6,7 --dist-url tcp://127.0.0.1:23333 \
+    --print-freq 100 --num_workers 32 --distill kd --model_s ResNet18 -r 1 -a 1 -b 0 --trial test \
     --multiprocessing-distributed --learning_rate 0.1 --lr_decay_epochs 30,60 --weight_decay 1e-4 --dali gpu
40 changes: 22 additions & 18 deletions train_student.py
@@ -22,7 +22,7 @@
 from models.util import Embed, ConvReg, LinearEmbed, SelfA

 from dataset.cifar100 import get_cifar100_dataloaders, get_cifar100_dataloaders_sample
-from dataset.imagenet import get_imagenet_dataloader, get_dataloader_sample
+from dataset.imagenet import get_imagenet_dataloader, get_dataloader_sample, imagenet_list
 from dataset.imagenet_dali import get_dali_data_loader

 from helper.util import adjust_learning_rate, save_dict_to_json, reduce_tensor
@@ -51,7 +51,7 @@ def parse_option():
     parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

     # dataset
-    parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100', 'imagenet'], help='dataset')
+    parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100', 'imagenet', 'imagenette'], help='dataset')

     # model
     parser.add_argument('--model_s', type=str, default='resnet8',
@@ -144,7 +144,7 @@ def get_teacher_name(model_path):
     return segments[0] + '_' + segments[1] + '_' + segments[2]


-def load_teacher(model_path, n_cls):
+def load_teacher(model_path, n_cls, gpu=None):
     print('==> loading teacher model')
     model_t = get_teacher_name(model_path)
     model = model_dict[model_t](num_classes=n_cls)
@@ -155,7 +155,9 @@ def load_teacher(model_path, n_cls):
     #model = model.module
     # pre-trained model saved from train_teacher.py
     #else:
-    model.load_state_dict(torch.load(model_path)['model'])
+    # TODO: reduce size of the teacher saved in train_teacher.py
+    map_location = None if gpu is None else {'cuda:0': 'cuda:%d' % gpu}
+    model.load_state_dict(torch.load(model_path, map_location=map_location)['model'])
     print('==> done')
     return model
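The map_location remap is what makes distributed loading work: a checkpoint saved from GPU 0 stores its tensors tagged cuda:0, so without a remap every worker process would deserialize the teacher onto GPU 0. A minimal illustration (file name and device index are placeholders):

# Load a checkpoint saved on cuda:0 directly onto this worker's own GPU.
import torch

ckpt = torch.load('teacher.pth', map_location={'cuda:0': 'cuda:3'})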

@@ -199,18 +201,17 @@ def main_worker(gpu, ngpus_per_node, opt):
         opt.batch_size = int(opt.batch_size / ngpus_per_node)
         opt.num_workers = int((opt.num_workers + ngpus_per_node - 1) / ngpus_per_node)

-    # tensorboard logger
-    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
-
-    if opt.dataset == 'cifar100':
-        n_cls = 100
-    elif opt.dataset == 'imagenet':
-        n_cls = 1000
-    else:
-        raise NotImplementedError(opt.dataset)
+    class_num_map = {
+        'cifar100': 100,
+        'imagenet': 1000,
+        'imagenette': 10,
+    }
+    if opt.dataset not in class_num_map:
+        raise NotImplementedError(opt.dataset)
+    n_cls = class_num_map[opt.dataset]

     # model
-    model_t = load_teacher(opt.path_t, n_cls)
+    model_t = load_teacher(opt.path_t, n_cls, opt.gpu)
     model_s = model_dict[opt.model_s](num_classes=n_cls)

     data = torch.randn(2, 3, 32, 32)
@@ -303,7 +304,7 @@ def main_worker(gpu, ngpus_per_node, opt):
     module_list.cuda(opt.gpu)
     distributed_modules = []
     for module in module_list:
-        # TODO: test whether apex is faster
+        # TODO: test apex.amp
         # DDP = torch.nn.parallel.DistributedDataParallel if opt.dali is None else apex.parallel.DistributedDataParallel
         # distributed_modules.append(DDP(module, delay_allreduce=True))
         DDP = torch.nn.parallel.DistributedDataParallel
@@ -327,13 +328,13 @@ def main_worker(gpu, ngpus_per_node, opt):
         else:
             train_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size,
                                                                 num_workers=opt.num_workers)
-    elif opt.dataset == 'imagenet':
+    elif opt.dataset in imagenet_list:
         if opt.dali is None:
             if opt.distill in ['crd']:
                 train_loader, val_loader, n_data = get_dataloader_sample(batch_size=opt.batch_size, num_workers=opt.num_workers,
                                                                          k=opt.nce_k, is_sample=False)
             else:
-                train_loader, val_loader, train_sampler = get_imagenet_dataloader(batch_size=opt.batch_size,
+                train_loader, val_loader, train_sampler = get_imagenet_dataloader(dataset=opt.dataset, batch_size=opt.batch_size,
                                                                                   num_workers=opt.num_workers,
                                                                                   multiprocessing_distributed=opt.multiprocessing_distributed)
     else:
@@ -350,6 +351,9 @@ def main_worker(gpu, ngpus_per_node, opt):
         reduced = reduce_tensor(teacher_acc, opt.world_size)
         teacher_acc = reduced.item()

+    if opt.dali is not None:
+        val_loader.reset()
+
     if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
         print('teacher accuracy: ', teacher_acc)
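This reset is the "fix validation error" from the commit title: a DALI iterator is exhausted after one pass and raises StopIteration on reuse, and the teacher-accuracy check above consumes val_loader before the first training epoch. A sketch of the general pattern (function names are placeholders, not the repo's code):

# Every full pass over a DALI loader must be followed by reset()
# before the loader is iterated again.
def run_epochs(train_loader, val_loader, epochs, train_one_epoch, validate):
    for epoch in range(epochs):
        train_one_epoch(train_loader)
        validate(val_loader)
        train_loader.reset()
        val_loader.reset()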

@@ -369,14 +373,14 @@ def main_worker(gpu, ngpus_per_node, opt):
         metrics = torch.tensor([train_acc, train_acc_top5, train_loss, data_time]).cuda(opt.gpu, non_blocking=True)
         reduced = reduce_tensor(metrics, opt.world_size)
         train_acc, train_acc_top5, train_loss, data_time = reduced.tolist()
-
     if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
         print(' * Epoch {}, GPU {}, Acc@1 {:.3f}, Acc@5 {:.3f}, Time {:.2f}, Data {:.2f}'.format(epoch, opt.gpu, train_acc, train_acc_top5, time2 - time1, data_time))

         logger.log_value('train_acc', train_acc, epoch)
         logger.log_value('train_loss', train_loss, epoch)

-        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s, criterion_cls, opt)
+    print('GPU %d validating' % (opt.gpu))
+    test_acc, test_acc_top5, test_loss = validate(val_loader, model_s, criterion_cls, opt)

     if opt.dali is not None:
         train_loader.reset()
24 changes: 12 additions & 12 deletions train_teacher.py
@@ -17,7 +17,7 @@
 from models import model_dict

 from dataset.cifar100 import get_cifar100_dataloaders
-from dataset.imagenet import get_imagenet_dataloader
+from dataset.imagenet import get_imagenet_dataloader, imagenet_list
 from helper.util import adjust_learning_rate, accuracy, AverageMeter, save_dict_to_json, reduce_tensor
 from helper.loops import train_vanilla as train, validate
 from dataset.imagenet_dali import get_dali_data_loader
@@ -47,9 +47,9 @@ def parse_option():
                             'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2',
                             'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19',
                             'MobileNetV2', 'ShuffleV1', 'ShuffleV2', 'ResNet50'])
-    parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100', 'imagenet'], help='dataset')
+    parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100', 'imagenet', 'imagenette'], help='dataset')

-    parser.add_argument('-t', '--trial', type=int, default=0, help='the experiment id')
+    parser.add_argument('-t', '--trial', type=str, default='0', help='the experiment id')

     parser.add_argument('--use-lmdb', action='store_true')  # default=false
@@ -132,13 +132,12 @@ def main_worker(gpu, ngpus_per_node, opt):
                             world_size=opt.world_size, rank=opt.rank)

     # model
-    if opt.dataset == 'cifar100':
-        n_cls = 100
-    elif opt.dataset == 'imagenet':
-        n_cls = 1000
-    else:
-        n_cls = None
+    n_cls = {
+        'cifar100': 100,
+        'imagenet': 1000,
+        'imagenette': 10,
+    }.get(opt.dataset, None)

     model = model_dict[opt.model](num_classes=n_cls)

     # optimizer
@@ -181,9 +180,10 @@ def main_worker(gpu, ngpus_per_node, opt):
     # dataloader
     if opt.dataset == 'cifar100':
         train_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size, num_workers=opt.num_workers)
-    elif opt.dataset == 'imagenet':
+    elif opt.dataset in imagenet_list:
         if opt.dali is None:
             train_loader, val_loader, train_sampler = get_imagenet_dataloader(
+                dataset=opt.dataset,
                 batch_size=opt.batch_size, num_workers=opt.num_workers, use_lmdb=opt.use_lmdb,
                 multiprocessing_distributed=opt.multiprocessing_distributed)
     else:
@@ -241,7 +241,7 @@ def main_worker(gpu, ngpus_per_node, opt):
             best_acc = test_acc
             state = {
                 'epoch': epoch,
-                'model': model.state_dict(),
+                'model': model.module.state_dict() if opt.multiprocessing_distributed else model.state_dict(),
                 'best_acc': best_acc,
                 'optimizer': optimizer.state_dict(),
             }
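The model.module indirection is the "distributed save" half of the commit title: DistributedDataParallel wraps the model and prefixes every state-dict key with 'module.', so saving the unwrapped module keeps checkpoints loadable by load_teacher without key surgery. A self-contained sketch of the same idea (the Linear model and file name are placeholders):

import torch

net = torch.nn.Linear(4, 2)  # stand-in for the trained model
# Unwrap DDP before saving so checkpoint keys carry no 'module.' prefix.
if isinstance(net, torch.nn.parallel.DistributedDataParallel):
    state = net.module.state_dict()
else:
    state = net.state_dict()
torch.save({'model': state}, 'teacher.pth')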
