
Commit

lmdb, retrieve time in each batch
triomino committed Jul 7, 2020
1 parent 2088f54 commit cd4ee0e
Showing 6 changed files with 211 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -118,3 +118,6 @@ venv.bak/

# trained models
save/*

# log
log
180 changes: 180 additions & 0 deletions dataset/folder2lmdb.py
@@ -0,0 +1,180 @@
import os
import os.path as osp
import sys

from PIL import Image
import six

import lmdb
import pickle
import msgpack
import tqdm
import pyarrow as pa

import torch
import torch.utils.data as data
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder


class ImageFolderLMDB(data.Dataset):
def __init__(self, db_path, transform=None, target_transform=None):
self.db_path = db_path
self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
readonly=True, lock=False,
readahead=False, meminit=False)
with self.env.begin(write=False) as txn:
# self.length = txn.stat()['entries'] - 1
            self.length = pa.deserialize(txn.get(b'__len__'))
            self.keys = pa.deserialize(txn.get(b'__keys__'))

self.transform = transform
self.target_transform = target_transform

def __getitem__(self, index):
img, target = None, None
env = self.env
with env.begin(write=False) as txn:
byteflow = txn.get(self.keys[index])
unpacked = pa.deserialize(byteflow)

# load image
imgbuf = unpacked[0]
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)
img = Image.open(buf).convert('RGB')

# load label
target = unpacked[1]

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
return self.length

def __repr__(self):
return self.__class__.__name__ + ' (' + self.db_path + ')'


class ImageFolderLMDB_old(data.Dataset):
def __init__(self, db_path, transform=None, target_transform=None):
self.db_path = db_path
self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
readonly=True, lock=False,
readahead=False, meminit=False)
with self.env.begin(write=False) as txn:
self.length = txn.stat()['entries'] - 1
self.keys = msgpack.loads(txn.get(b'__keys__'))
# cache_file = '_cache_' + db_path.replace('/', '_')
# if os.path.isfile(cache_file):
# self.keys = pickle.load(open(cache_file, "rb"))
# else:
# with self.env.begin(write=False) as txn:
# self.keys = [key for key, _ in txn.cursor()]
# pickle.dump(self.keys, open(cache_file, "wb"))
self.transform = transform
self.target_transform = target_transform

def __getitem__(self, index):
img, target = None, None
env = self.env
with env.begin(write=False) as txn:
byteflow = txn.get(self.keys[index])
unpacked = msgpack.loads(byteflow)
imgbuf = unpacked[0][b'data']
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)
img = Image.open(buf).convert('RGB')
target = unpacked[1]

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
return self.length

def __repr__(self):
return self.__class__.__name__ + ' (' + self.db_path + ')'


def raw_reader(path):
with open(path, 'rb') as f:
bin_data = f.read()
return bin_data


def dumps_pyarrow(obj):
    """
    Serialize an object with pyarrow.
    Returns:
        a pyarrow.Buffer holding the serialized bytes
    """
    return pa.serialize(obj).to_buffer()
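
# Note: pyarrow.serialize()/deserialize() has been deprecated since pyarrow 2.0.
# A pickle-based replacement (a sketch, not what this commit uses) would be:
#
#     def dumps_pickle(obj):
#         return pickle.dumps(obj)
#
# with pickle.loads(byteflow) replacing pa.deserialize(byteflow) on the read side.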


def folder2lmdb(dpath, name="train", write_frequency=5000, num_workers=16):
directory = osp.expanduser(osp.join(dpath, name))
print("Loading dataset from %s" % directory)
dataset = ImageFolder(directory, loader=raw_reader)
data_loader = DataLoader(dataset, num_workers=num_workers, collate_fn=lambda x: x)

lmdb_path = osp.join(dpath, "%s.lmdb" % name)
isdir = os.path.isdir(lmdb_path)

print("Generate LMDB to %s" % lmdb_path)
db = lmdb.open(lmdb_path, subdir=isdir,
map_size=1099511627776 * 2, readonly=False,
meminit=False, map_async=True)

print(len(dataset), len(data_loader))
txn = db.begin(write=True)
for idx, data in enumerate(data_loader):
# print(type(data), data)
image, label = data[0]
txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow((image, label)))
if idx % write_frequency == 0:
print("[%d/%d]" % (idx, len(data_loader)))
txn.commit()
txn = db.begin(write=True)

# finish iterating through dataset
txn.commit()
keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
with db.begin(write=True) as txn:
txn.put(b'__keys__', dumps_pyarrow(keys))
txn.put(b'__len__', dumps_pyarrow(len(keys)))

print("Flushing database ...")
db.sync()
db.close()


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--folder", type=str)
parser.add_argument('-s', '--split', type=str, default="val")
parser.add_argument('--out', type=str, default=".")
parser.add_argument('-p', '--procs', type=int, default=20)

args = parser.parse_args()

folder2lmdb(args.folder, num_workers=args.procs, name=args.split)
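
A minimal end-to-end sketch of the new module (paths are placeholders and the
transform is illustrative): build the LMDB with the CLI above, then read it
back through ImageFolderLMDB.

    # python dataset/folder2lmdb.py -f data/imagenet -s val
    from torch.utils.data import DataLoader
    from torchvision import transforms

    from dataset.folder2lmdb import ImageFolderLMDB

    val_set = ImageFolderLMDB('data/imagenet/val.lmdb',
                              transform=transforms.ToTensor())
    val_loader = DataLoader(val_set, batch_size=128, num_workers=4)
    images, targets = next(iter(val_loader))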
16 changes: 13 additions & 3 deletions dataset/imagenet.py
@@ -9,6 +9,7 @@
from torchvision import datasets
from torchvision import transforms

from dataset.folder2lmdb import ImageFolderLMDB

def get_data_folder():
"""
@@ -176,7 +177,7 @@ def get_dataloader_sample(dataset='imagenet', batch_size=128, num_workers=8, is_
return train_loader, test_loader, len(train_set), len(train_set.classes)


def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16, is_instance=False):
def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16, is_instance=False, use_lmdb=False):
"""
Data Loader for imagenet
"""
@@ -202,14 +203,23 @@ def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16,

train_folder = os.path.join(data_folder, 'train')
test_folder = os.path.join(data_folder, 'val')
    if use_lmdb:
train_lmdb_path = os.path.join(data_folder, 'train.lmdb')
test_lmdb_path = os.path.join(data_folder, 'val.lmdb')

if is_instance:
train_set = ImageFolderInstance(train_folder, transform=train_transform)
n_data = len(train_set)
else:
train_set = datasets.ImageFolder(train_folder, transform=train_transform)
if use_lmdb:
train_set = ImageFolderLMDB(train_lmdb_path, transform=train_transform)
else:
train_set = datasets.ImageFolder(train_folder, transform=train_transform)

test_set = datasets.ImageFolder(test_folder, transform=test_transform)
if use_lmdb:
test_set = ImageFolderLMDB(test_lmdb_path, transform=test_transform)
else:
test_set = datasets.ImageFolder(test_folder, transform=test_transform)

train_loader = DataLoader(train_set,
batch_size=batch_size,
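
With the flag plumbed through, switching a training run to LMDB is a
one-argument change. A short sketch (batch size and worker count are
illustrative):

    from dataset.imagenet import get_imagenet_dataloader

    # reads train.lmdb and val.lmdb from the data folder instead of raw images
    train_loader, val_loader = get_imagenet_dataloader(
        batch_size=128, num_workers=16, use_lmdb=True)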
13 changes: 11 additions & 2 deletions helper/loops.py
@@ -10,11 +10,16 @@ def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt):
"""vanilla training"""
model.train()

batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()

end = time.time()
for idx, (input, target) in enumerate(train_loader):
data_time.update(time.time() - end)

# input = input.float()
if torch.cuda.is_available():
input = input.cuda()
@@ -29,6 +34,8 @@ def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt):
metrics = accuracy(output, target, topk=(1, 5))
top1.update(metrics[0].item(), input.size(0))
top5.update(metrics[1].item(), input.size(0))
        # measured here, batch_time covers data loading and the forward pass;
        # the backward pass below is not included in this timing
        batch_time.update(time.time() - end)
        end = time.time()

# ===================backward=====================
optimizer.zero_grad()
@@ -38,11 +45,13 @@ def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt):
# print info
if idx % opt.print_freq == 0:
print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.avg:.4f}\t'
'Acc@1 {top1.avg:.3f}\t'
'Acc@5 {top5.avg:.3f}'.format(
epoch, idx, len(train_loader),
loss=losses, top1=top1, top5=top5))
epoch, idx, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1, top5=top5))
sys.stdout.flush()

return top1.avg, top5.avg, losses.avg
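
The timing additions lean on the repository's AverageMeter helper (defined
elsewhere in the repo, presumably helper/util.py, and not shown in this diff).
A minimal sketch of the interface the loop assumes (val and avg attributes
plus update(value, n)):

    class AverageMeter(object):
        """Tracks the latest value and the running average of a metric."""
        def __init__(self):
            self.val = 0
            self.sum = 0
            self.count = 0
            self.avg = 0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count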
2 changes: 1 addition & 1 deletion train_student.py
@@ -83,7 +83,7 @@ def parse_option():
parser.add_argument('--transform_layer_s', nargs='+', type=int, default=[])

# switch for edge transformation
parser.add_argument('--no_edge_transform', action='store_true')
parser.add_argument('--no_edge_transform', action='store_true')  # defaults to False

opt = parser.parse_args()

4 changes: 3 additions & 1 deletion train_teacher.py
@@ -46,6 +46,8 @@ def parse_option():

parser.add_argument('-t', '--trial', type=int, default=0, help='the experiment id')

parser.add_argument('--use_lmdb', action='store_true')  # defaults to False

opt = parser.parse_args()

# set different learning rate from these 4 models
@@ -89,7 +91,7 @@ def main():
train_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size, num_workers=opt.num_workers)
n_cls = 100
elif opt.dataset == 'imagenet':
train_loader, val_loader = get_imagenet_dataloader(batch_size=opt.batch_size, num_workers=opt.num_workers)
train_loader, val_loader = get_imagenet_dataloader(batch_size=opt.batch_size, num_workers=opt.num_workers, use_lmdb=opt.use_lmdb)
n_cls = 1000
else:
raise NotImplementedError(opt.dataset)
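
Taken together, a teacher run on the LMDB copy of ImageNet looks like the
following (assuming the script's existing --dataset flag from the elided part
of parse_option; model and hyperparameter flags are omitted):

    python dataset/folder2lmdb.py -f /path/to/imagenet -s train
    python dataset/folder2lmdb.py -f /path/to/imagenet -s val
    python train_teacher.py --dataset imagenet --use_lmdb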
