Skip to content

Commit

Permalink
fix some dataloader bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
ohyeat committed Feb 19, 2020
1 parent 9cda8d0 commit c9a6ba2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion cls/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def main():
if not args.test_only:
assert os.path.exists(args.train_dir)
args.train_dataloader = get_train_dataloader(args.train_dir, \
args.batch_size//args.gpu_num,args.local_rank)
args.batch_size//args.gpu_num, args.total_epoch,args.local_rank)

assert os.path.exists(args.val_dir)
if args.local_rank == 0:
Expand Down
6 changes: 3 additions & 3 deletions cls/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def __call__(self, pred:torch.Tensor, target:torch.Tensor):

return loss

def get_train_dataloader(train_dir, batch_size, total_iters,local_rank):
def get_train_dataloader(train_dir, batch_size, total_epochs, local_rank):
eigvec = np.array([
[-0.5836, -0.6948, 0.4203],
[-0.5808, -0.0045, -0.8140],
Expand All @@ -260,15 +260,15 @@ def get_train_dataloader(train_dir, batch_size, total_iters,local_rank):

datasampler = Random_Batch_Sampler(
train_dataset, batch_size=batch_size,
total_iters=total_iters*50000, rank=local_rank)
total_iters=total_epochs*5000, rank=local_rank)
train_loader = torch.utils.data.DataLoader(
train_dataset, num_workers=8,
pin_memory=True, batch_sampler=datasampler)

return train_loader

def get_val_dataloader(val_dir):
val_dataset = datasets.ImageFolder(train_dir,
val_dataset = datasets.ImageFolder(val_dir,
transforms.Compose([
OpencvResize(256),
transforms.CenterCrop(224),
Expand Down

0 comments on commit c9a6ba2

Please sign in to comment.