Skip to content

Commit

Permalink
Add Ditto
Browse files Browse the repository at this point in the history
  • Loading branch information
ElvinKim committed Oct 1, 2021
1 parent d87c297 commit ea83186
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 0 deletions.
128 changes: 128 additions & 0 deletions main_ditto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import pickle
import numpy as np
import pandas as pd
import torch

from utils.options import args_parser
from utils.train_utils import get_data, get_model
from models.Update import LocalUpdate, LocalUpdateDitto
from models.test import test_img, test_img_local, test_img_local_all
from models.Fed import FedAvg
import os

import pdb

if __name__ == '__main__':
    # Entry point for Ditto-style federated training.
    # Every round: sample users, train each sampled user's local model with
    # LocalUpdateDitto against a snapshot of the current global weights,
    # aggregate the returned weights with FedAvg, then (every `test_freq`
    # rounds) evaluate the local models and append a row to results.csv.

    # parse args
    args = args_parser()

    # set seed
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(args.seed)

    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    if args.unbalanced:
        base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}_unbalanced_bu{}_md{}/{}/'.format(
            args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.num_batch_users, args.moved_data_size, args.results_save)
    else:
        base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/{}/'.format(
            args.dataset, args.model, args.iid, args.num_users, args.frac, args.local_ep, args.shard_per_user, args.results_save)
    algo_dir = "ditto"

    # exist_ok=True already tolerates an existing directory, so no separate
    # os.path.exists() check is needed.
    os.makedirs(os.path.join(base_dir, algo_dir), exist_ok=True)

    dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
    # Persist the per-user data split so a run can be reproduced/inspected.
    dict_save_path = os.path.join(base_dir, algo_dir, 'dict_users.pkl')
    with open(dict_save_path, 'wb') as handle:
        pickle.dump((dict_users_train, dict_users_test), handle)

    # build a global model
    net_glob = get_model(args)
    net_glob.train()

    # build local models: one independent copy of the global model per user
    net_local_list = [copy.deepcopy(net_glob) for _ in range(args.num_users)]

    # training
    results_save_path = os.path.join(base_dir, algo_dir, 'results.csv')

    loss_train = []
    net_best = None
    best_loss = None
    best_acc = None
    best_epoch = None

    lr = args.lr
    results = []
    w_glob = copy.deepcopy(net_glob.state_dict())
    lam = 0.75  # follows the setting of FedRep

    for round_idx in range(args.epochs):
        loss_locals = []
        w_locals = []

        # Sample at least one user per round.
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = list(np.random.choice(range(args.num_users), m, replace=False))

        # net_glob is only modified at aggregation below, so one snapshot of
        # its weights per round serves as the Ditto reference for every user.
        w_glob_k = copy.deepcopy(net_glob.state_dict())

        # send all parameters to the selected users
        for idx in idxs_users:
            print(idx)
            local = LocalUpdateDitto(args=args, dataset=dataset_train, idxs=dict_users_train[idx])

            net_local = net_local_list[idx]

            # Pass the scheduled `lr` (not args.lr) so the decay applied
            # further down actually reaches the local optimizer.
            w, loss = local.train(net=net_local.to(args.device), idx=idx, lr=lr, w_ditto=w_glob_k, lam=lam)

            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))

        # update global weights
        w_glob = FedAvg(w_locals)
        net_glob.load_state_dict(w_glob)

        # Step schedule: divide the learning rate by 10 at 1/2 and 3/4 of training.
        if (round_idx + 1) in [args.epochs // 2, (args.epochs * 3) // 4]:
            lr *= 0.1

        # record the average training loss of this round
        loss_avg = sum(loss_locals) / len(loss_locals)
        loss_train.append(loss_avg)

        if (round_idx + 1) % args.test_freq == 0:
            acc_test, loss_test = test_img_local_all(net_local_list, args, dataset_test, dict_users_test, return_all=False)

            print('Round {:3d}, Average loss {:.3f}, Test loss {:.3f}, Test accuracy: {:.2f}'.format(
                round_idx, loss_avg, loss_test, acc_test))

            if best_acc is None or acc_test > best_acc:
                net_best = copy.deepcopy(net_glob)
                best_acc = acc_test
                best_epoch = round_idx

                # Checkpoint every user's local model at the best round so far.
                for user_idx in range(args.num_users):
                    best_save_path = os.path.join(base_dir, algo_dir, 'best_local_{}.pt'.format(user_idx))
                    torch.save(net_local_list[user_idx].state_dict(), best_save_path)

            # Rewrite the full CSV each time so partial runs leave usable results.
            results.append(np.array([round_idx, loss_avg, loss_test, acc_test, best_acc]))
            final_results = np.array(results)
            final_results = pd.DataFrame(final_results, columns=['epoch', 'loss_avg', 'loss_test', 'acc_test', 'best_acc'])
            final_results.to_csv(results_save_path, index=False)

        # rollback global model: every local model restarts the next round
        # from the aggregated weights (strict=False tolerates key mismatches).
        for user_idx in range(args.num_users):
            net_local_list[user_idx].load_state_dict(w_glob, strict=False)

    print('Best model, iter: {}, acc: {}'.format(best_epoch, best_acc))
98 changes: 98 additions & 0 deletions models/Update.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,5 +264,103 @@ def train(self, net, lr):
optimizer.step()

return net.state_dict()


class LocalUpdateFedProx(object):
    """Local trainer that adds a FedProx proximal penalty to the task loss.

    The penalty is ``mu / 2 * ||w - w_start||^2`` summed over all parameters,
    where ``w_start`` are the parameters of ``net`` at the moment ``train`` is
    called and ``mu`` is read from ``args.mu``.
    """

    def __init__(self, args, dataset=None, idxs=None, pretrain=False):
        """Store the config and build a shuffled loader over ``idxs`` of ``dataset``."""
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
        self.pretrain = pretrain

    def train(self, net, body_lr, head_lr):
        """Train ``net`` in place for ``args.local_ep`` epochs.

        Args:
            net: model to train; parameters whose name contains ``'linear'``
                are treated as the head, all others as the body.
            body_lr: SGD learning rate for the body parameters.
            head_lr: SGD learning rate for the head parameters.

        Returns:
            Tuple ``(state_dict, mean_loss)`` where ``mean_loss`` is the mean
            over epochs of the per-epoch mean batch loss (proximal term
            included).
        """
        net.train()

        # Detached snapshot of the starting weights. Using plain tensors
        # instead of a deep-copied module keeps the reference out of the
        # autograd graph, so backward() does not accumulate unused gradients
        # on it.
        global_params = [p.detach().clone() for p in net.parameters()]

        body_params = [p for name, p in net.named_parameters() if 'linear' not in name]
        head_params = [p for name, p in net.named_parameters() if 'linear' in name]

        optimizer = torch.optim.SGD([{'params': body_params, 'lr': body_lr},
                                     {'params': head_params, 'lr': head_lr}],
                                    momentum=self.args.momentum,
                                    weight_decay=self.args.wd)

        epoch_loss = []

        for _ in range(self.args.local_ep):
            batch_loss = []
            for images, labels in self.ldr_train:
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                logits = net(images)

                loss = self.loss_func(logits, labels)

                # for fedprox: penalize drift from the starting weights
                fed_prox_reg = 0.0
                for l_param, g_param in zip(net.parameters(), global_params):
                    fed_prox_reg += (self.args.mu / 2 * torch.norm(l_param - g_param) ** 2)
                loss += fed_prox_reg

                loss.backward()
                optimizer.step()

                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)


class LocalUpdateDitto(object):
    """Local trainer that follows each SGD step with a Ditto proximal step.

    After every optimizer step the weights are additionally moved by
    ``-lr * lam * (w_before_step - w_ditto)``, i.e. pulled towards the
    reference weights ``w_ditto``.
    """

    def __init__(self, args, dataset=None, idxs=None):
        """Store the config and build a shuffled loader over ``idxs`` of ``dataset``."""
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []

        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net, w_ditto=None, lam=0, idx=-1, lr=0.1, last=False, momentum=0.9):
        """Train ``net`` in place for ``args.local_ep`` epochs.

        Args:
            net: model to train.
            w_ditto: reference state dict for the proximal step; when ``None``
                only plain SGD is performed.
            lam: strength of the proximal pull towards ``w_ditto``.
            idx: unused; accepted for interface compatibility.
            lr: step size for both the SGD step and the proximal step.
            last: unused; accepted for interface compatibility.
            momentum: SGD momentum.

        Returns:
            Tuple ``(state_dict, mean_loss)`` where ``mean_loss`` is the mean
            over epochs of the per-epoch mean batch loss.
        """
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)

        epoch_loss = []

        for _ in range(self.args.local_ep):
            batch_loss = []
            for images, labels in self.ldr_train:
                if w_ditto is not None:
                    # state_dict() tensors alias the live parameters, so the
                    # pre-step weights must be copied before they are updated.
                    w_0 = copy.deepcopy(net.state_dict())

                images, labels = images.to(self.args.device), labels.to(self.args.device)
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if w_ditto is not None:
                    # Use the same `lr` as the optimizer (not args.lr) so the
                    # proximal step follows whatever step size the caller passed.
                    # The subtraction allocates new tensors, so rebinding the
                    # entries of this fresh dict does not touch the model until
                    # load_state_dict() copies them in.
                    w_net = net.state_dict()
                    for key in w_net:
                        w_net[key] = w_net[key] - lr * lam * (w_0[key] - w_ditto[key])
                    net.load_state_dict(w_net)
                    optimizer.zero_grad()

                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)


25 changes: 25 additions & 0 deletions scripts/run_ditto.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash

# Ditto sweep on CIFAR-100 with the "mobile" model.
# Runs every combination of shards-per-user, client fraction, and
# (rounds, local epochs) schedule. All three schedules give 320 local
# epochs in total (320x1, 80x4, 32x10).
for shard in 10 50 100; do
    for frac in 1.0 0.1; do
        for schedule in "320 1" "80 4" "32 10"; do
            # Split the "<epochs> <local_ep>" pair into two variables.
            read -r epochs local_ep <<< "$schedule"
            python main_ditto.py --dataset cifar100 --model mobile --num_classes 100 \
                --shard_per_user "$shard" --epochs "$epochs" --lr 0.1 --num_users 100 \
                --frac "$frac" --local_ep "$local_ep" --local_bs 50 --results_save ditto
        done
    done
done

0 comments on commit ea83186

Please sign in to comment.