Skip to content

Commit

Permalink
Re-upload data load
Browse files Browse the repository at this point in the history
  • Loading branch information
jhoon-oh committed May 21, 2021
1 parent 76c37dd commit b3bf421
Showing 1 changed file with 58 additions and 48 deletions.
106 changes: 58 additions & 48 deletions utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,63 +22,73 @@
transforms.Normalize(mean=[0.507, 0.487, 0.441],
std=[0.267, 0.256, 0.276])])

def get_data(args):

if args.unbalanced:
def get_data(args, env='fed'):
if env == 'single':
if args.dataset == 'cifar10':
dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar10_train)
dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar10_val)
if args.iid:
dict_users_train = iid_unbalanced(dataset_train, args.num_users, args.num_batch_users, args.moved_data_size)
dict_users_test = iid_unbalanced(dataset_test, args.num_users, args.num_batch_users, args.moved_data_size)
else:
dict_users_train, rand_set_all = noniid_unbalanced(dataset_train, args.num_users, args.num_batch_users, args.moved_data_size, args.shard_per_user)
dict_users_test, rand_set_all = noniid_unbalanced(dataset_test, args.num_users, args.num_batch_users, args.moved_data_size, args.shard_per_user, rand_set_all=rand_set_all)
elif args.dataset == 'cifar100':
dataset_train = datasets.CIFAR100('data/cifar100', train=True, download=True, transform=trans_cifar100_train)
dataset_test = datasets.CIFAR100('data/cifar100', train=False, download=True, transform=trans_cifar100_val)
if args.iid:
dict_users_train = iid_unbalanced(dataset_train, args.num_users, args.num_batch_users, args.moved_data_size)
dict_users_test = iid_unbalanced(dataset_test, args.num_users, args.num_batch_users, args.moved_data_size)
else:
dict_users_train, rand_set_all = noniid_unbalanced(dataset_train, args.num_users, args.num_batch_users, args.moved_data_size, args.shard_per_user)
dict_users_test, rand_set_all = noniid_unbalanced(dataset_test, args.num_users, args.num_batch_users, args.moved_data_size, args.shard_per_user, rand_set_all=rand_set_all)
else:
exit('Error: unrecognized dataset')

else:
if args.dataset == 'mnist':
dataset_train = datasets.MNIST('data/mnist/', train=True, download=True, transform=trans_mnist)
dataset_test = datasets.MNIST('data/mnist/', train=False, download=True, transform=trans_mnist)
# sample users
if args.iid:
dict_users_train = iid(dataset_train, args.num_users)
dict_users_test = iid(dataset_test, args.num_users)
else:
dict_users_train, rand_set_all = noniid(dataset_train, args.num_users, args.shard_per_user)
dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, rand_set_all=rand_set_all)
elif args.dataset == 'cifar10':
dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar10_train)
dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar10_val)
if args.iid:
dict_users_train = iid(dataset_train, args.num_users)
dict_users_test = iid(dataset_test, args.num_users)
else:
dict_users_train, rand_set_all = noniid(dataset_train, args.num_users, args.shard_per_user)
dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, rand_set_all=rand_set_all)
elif args.dataset == 'cifar100':
dataset_train = datasets.CIFAR100('data/cifar100', train=True, download=True, transform=trans_cifar100_train)
dataset_test = datasets.CIFAR100('data/cifar100', train=False, download=True, transform=trans_cifar100_val)
if args.iid:
dict_users_train = iid(dataset_train, args.num_users)
dict_users_test = iid(dataset_test, args.num_users)
return dataset_train, dataset_test

elif env == 'fed':
if args.unbalanced:
if args.dataset == 'cifar10':
dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar10_train)
dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar10_val)
if args.iid:
dict_users_train = iid_unbalanced(dataset_train, args.num_users, args.num_batch_users, args.moved_data_size)
dict_users_test = iid_unbalanced(dataset_test, args.num_users, args.num_batch_users, args.moved_data_size)
else:
dict_users_train, rand_set_all = noniid_unbalanced(dataset_train, args.num_users, args.num_batch_users, args.moved_data_size, args.shard_per_user)
dict_users_test, rand_set_all = noniid_unbalanced(dataset_test, args.num_users, args.num_batch_users, args.moved_data_size, args.shard_per_user, rand_set_all=rand_set_all)
elif args.dataset == 'cifar100':
dataset_train = datasets.CIFAR100('data/cifar100', train=True, download=True, transform=trans_cifar100_train)
dataset_test = datasets.CIFAR100('data/cifar100', train=False, download=True, transform=trans_cifar100_val)
if args.iid:
dict_users_train = iid_unbalanced(dataset_train, args.num_users, args.num_batch_users, args.moved_data_size)
dict_users_test = iid_unbalanced(dataset_test, args.num_users, args.num_batch_users, args.moved_data_size)
else:
dict_users_train, rand_set_all = noniid_unbalanced(dataset_train, args.num_users, args.num_batch_users, args.moved_data_size, args.shard_per_user)
dict_users_test, rand_set_all = noniid_unbalanced(dataset_test, args.num_users, args.num_batch_users, args.moved_data_size, args.shard_per_user, rand_set_all=rand_set_all)
else:
dict_users_train, rand_set_all = noniid(dataset_train, args.num_users, args.shard_per_user)
dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, rand_set_all=rand_set_all)
exit('Error: unrecognized dataset')

else:
exit('Error: unrecognized dataset')

return dataset_train, dataset_test, dict_users_train, dict_users_test
if args.dataset == 'mnist':
dataset_train = datasets.MNIST('data/mnist/', train=True, download=True, transform=trans_mnist)
dataset_test = datasets.MNIST('data/mnist/', train=False, download=True, transform=trans_mnist)
# sample users
if args.iid:
dict_users_train = iid(dataset_train, args.num_users, args.server_data_ratio)
dict_users_test = iid(dataset_test, args.num_users, args.server_data_ratio)
else:
dict_users_train, rand_set_all = noniid(dataset_train, args.num_users, args.shard_per_user, args.server_data_ratio)
dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, args.server_data_ratio, rand_set_all=rand_set_all)
elif args.dataset == 'cifar10':
dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar10_train)
dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar10_val)
if args.iid:
dict_users_train = iid(dataset_train, args.num_users, args.server_data_ratio)
dict_users_test = iid(dataset_test, args.num_users, args.server_data_ratio)
else:
dict_users_train, rand_set_all = noniid(dataset_train, args.num_users, args.shard_per_user, args.server_data_ratio)
dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, args.server_data_ratio, rand_set_all=rand_set_all)
elif args.dataset == 'cifar100':
dataset_train = datasets.CIFAR100('data/cifar100', train=True, download=True, transform=trans_cifar100_train)
dataset_test = datasets.CIFAR100('data/cifar100', train=False, download=True, transform=trans_cifar100_val)
if args.iid:
dict_users_train = iid(dataset_train, args.num_users, args.server_data_ratio)
dict_users_test = iid(dataset_test, args.num_users, args.server_data_ratio)
else:
dict_users_train, rand_set_all = noniid(dataset_train, args.num_users, args.shard_per_user, args.server_data_ratio)
dict_users_test, rand_set_all = noniid(dataset_test, args.num_users, args.shard_per_user, args.server_data_ratio, rand_set_all=rand_set_all)
else:
exit('Error: unrecognized dataset')

return dataset_train, dataset_test, dict_users_train, dict_users_test

def get_model(args):
if args.model == 'cnn' and args.dataset in ['cifar10', 'cifar100']:
Expand Down

0 comments on commit b3bf421

Please sign in to comment.