Implementation code for the GATTA algorithm
The code implements the GATTA algorithm from the paper "Distributed Learning Over Networks With Graph-Attention-Based Personalization":
@ARTICLE{GATTA,
author={Tian, Zhuojun and Zhang, Zhaoyang and Yang, Zhaohui and Jin, Richeng and Dai, Huaiyu},
journal={IEEE Transactions on Signal Processing},
title={Distributed Learning Over Networks With Graph-Attention-Based Personalization},
year={2023},
volume={71},
pages={2071-2086},
doi={10.1109/TSP.2023.3282071}}
- [Python 3.6 + TensorFlow 2.4.1 + CUDA 11.0] or [Python 2.7 + TensorFlow 1.15.0 + CUDA 10.0] (modify the import of `tensorflow.compat.v1` to `tensorflow`)
- numpy
- tqdm
- Run `Main_GATTA.py` to get the training result of GATTA.
- `adj_matrix.txt` is the adjacency matrix of the communication network topology, generated and stored by `generate_network.py`.
- `CategoryToClients3.txt` stores the classes held by each client and is generated by `Sample_parti_noiid3.py`.
- `LocalDist_niid3.txt` and `LocalDist_niid_test3.txt` store the number of data samples of each class in every client, for the training data and the testing data respectively. They are generated by `Sample_parti_noiid3.py`.
- `Model_GATTA.py` is the core of the algorithm. It defines the network architecture, in which the last fully-connected layer is the node-specific layer with the attention mechanism.
- `Client_GATTA.py` generates the clients.
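To make the role of `adj_matrix.txt` more concrete, here is a minimal sketch of how a symmetric adjacency matrix for a random communication graph could be produced and stored. It is not the content of `generate_network.py`. The helper names, the reading of p as an independent edge probability, the connectivity check, and the text layout written by `save_adjacency` are all assumptions made for this illustration.

```python
import numpy as np


def random_adjacency(n, p, seed=0):
    """Symmetric 0/1 adjacency matrix without self-loops (hypothetical helper).

    Assumption: every undirected edge exists independently with probability p.
    """
    rng = np.random.RandomState(seed)
    upper = np.triu(rng.rand(n, n) < p, k=1)  # keep the strict upper triangle
    return (upper | upper.T).astype(int)      # mirror it to make it symmetric


def is_connected(adj):
    """Return True if every node is reachable from node 0."""
    n = adj.shape[0]
    seen = {0}
    stack = [0]
    while stack:
        node = stack.pop()
        for nb in np.flatnonzero(adj[node]):
            nb = int(nb)
            if nb not in seen:
                seen.add(nb)
                stack.append(nb)
    return len(seen) == n


def save_adjacency(adj, path):
    """Store the matrix as plain text (assumed layout: one row per line)."""
    np.savetxt(path, adj, fmt="%d")
```

For example, `random_adjacency(100, 0.6)` gives a 100-node graph that can be checked with `is_connected` before it is written to disk.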
- The uploaded code only implements the basic setting of Fig. 3(a), with N=100, p=0.6.
- To generate a different communication network topology, modify `generate_network.py` and store the new adjacency matrix.
- To consider a different data partition, modify `Sample_parti_noiid3.py` to generate new `CategoryToClients3.txt`, `LocalDist_niid3.txt` and `LocalDist_niid_test3.txt`.
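As a rough picture of what a class-based non-i.i.d. partition involves, the sketch below gives each client a few classes and records how many samples of each class it receives. This corresponds loosely to the information kept in `CategoryToClients3.txt` and `LocalDist_niid3.txt`. It is not the logic of `Sample_parti_noiid3.py`. The function name, the default of three classes per client, the uniform class sampling, and the even split of a class among its holders are assumptions.

```python
import numpy as np


def partition_by_class(labels, num_clients, classes_per_client=3, seed=0):
    """Assign classes to clients and split sample indices accordingly.

    Hypothetical helper; all sampling choices here are assumptions.
    """
    rng = np.random.RandomState(seed)
    labels = np.asarray(labels, dtype=int)
    num_classes = int(labels.max()) + 1

    # classes_of[c] = classes held by client c
    classes_of = np.stack([
        rng.choice(num_classes, size=classes_per_client, replace=False)
        for _ in range(num_clients)
    ])

    client_indices = [[] for _ in range(num_clients)]
    for k in range(num_classes):
        holders = np.flatnonzero((classes_of == k).any(axis=1))
        if holders.size == 0:
            continue  # no client drew this class; its samples stay unused
        idx = rng.permutation(np.flatnonzero(labels == k))
        # split the samples of class k evenly among the clients holding it
        for client, chunk in zip(holders, np.array_split(idx, holders.size)):
            client_indices[client].extend(chunk.tolist())

    # local_dist[c, k] = number of samples of class k at client c
    local_dist = np.zeros((num_clients, num_classes), dtype=int)
    for c, idx in enumerate(client_indices):
        local_dist[c] = np.bincount(labels[np.asarray(idx, dtype=int)],
                                    minlength=num_classes)
    return classes_of, client_indices, local_dist
```

The same routine can be applied to the test labels to obtain a matching test-side count table.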
- Run `Main_GATTA.py` to get the training result of GATTA.
- `adj_matrix.txt` is the adjacency matrix of the communication network topology, generated and stored by `generate_network.py`.
- `Categories2.txt` stores the number of agents having the corresponding user and is generated by `Sample_parti_noiid2.py`.
- `LocalDist_niid2.txt` and `LocalDist_niid_test2.txt` store the number of data samples corresponding to the two local users in every client, for the training data and the testing data respectively. They are generated by `Sample_parti_noiid2.py`.
- `UsersToClients2.txt` stores the IDs of the local users ($e_i=2$ here) in each agent and is generated by `Sample_parti_noiid2.py`.
- `Model_GATTA.py` is the core of the algorithm. It defines the network architecture, in which the last fully-connected layer is the node-specific layer with the attention mechanism.
- `Client_GATTA.py` generates the clients.
- The uploaded code only implements the basic setting of Fig. 4(a), with N=100, p=0.6.
- To generate a different communication network topology, modify `generate_network.py` and store the new adjacency matrix.
- To consider a different data partition, modify `Sample_parti_noiid2.py` to generate new `Categories2.txt`, `UsersToClients2.txt`, `LocalDist_niid2.txt` and `LocalDist_niid_test2.txt`.
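The relation between the user-to-agent assignment and the per-user agent counts can be sketched as follows. This is an illustration only and does not reproduce `Sample_parti_noiid2.py`. The function name, the uniform sampling of users, and the possibility that one user is shared by several agents are assumptions.

```python
import numpy as np


def assign_users(num_agents, num_users, users_per_agent=2, seed=0):
    """Pick `users_per_agent` distinct user IDs for every agent.

    Hypothetical helper. Assumption: users are drawn uniformly at random,
    so the same user may end up in several agents.
    """
    rng = np.random.RandomState(seed)
    # users_to_agents[i] = IDs of the local users held by agent i
    users_to_agents = np.stack([
        rng.choice(num_users, size=users_per_agent, replace=False)
        for _ in range(num_agents)
    ])
    # agents_per_user[u] = number of agents that hold user u
    agents_per_user = np.bincount(users_to_agents.ravel(),
                                  minlength=num_users)
    return users_to_agents, agents_per_user
```

With `assign_users(100, 3596)`, the first array has one row of two user IDs per agent, and the second array counts how many agents hold each user.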
- FEMNIST needs to be preprocessed according to the official introduction and stored in `femnist/data/train` and `femnist/data/test`. Specifically, for the preprocessing, we shuffle the data and delete the users with fewer than 10 training samples. Note that 3596 users should be left; otherwise, modify the value on line 82 of `Dataset.py` in `FEMNIST_code/Model`. Then we split the data of each user into 75% for training and 25% for testing, and the results are stored in `FEMNIST/femnist/data/train` and `femnist/data/test`.
- The data assignments for the clients are already finished and stored. To reassign the non-i.i.d. data, run `Sample_parti_noiid3.py` in `CIFAR10_code/Main` after uncommenting the code, or `Sample_parti_noiid2.py` in `FEMNIST_code/Main`.
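The FEMNIST preprocessing described above (removing small users, then a 75%/25% split per user) can be pictured with the following sketch. It does not replace the official preprocessing scripts. The in-memory dictionary layout, the function name, and the choice to apply the 10-sample threshold to a user's samples before splitting are assumptions.

```python
import random


def filter_and_split(user_data, min_samples=10, train_frac=0.75, seed=0):
    """Drop small users, then split each remaining user's samples.

    user_data: dict mapping a user ID to a list of samples. This layout is
    hypothetical; the real files follow the official FEMNIST format.
    """
    rng = random.Random(seed)
    train, test = {}, {}
    for user, samples in user_data.items():
        if len(samples) < min_samples:
            continue  # too few samples, the user is removed
        samples = list(samples)
        rng.shuffle(samples)                    # shuffle before splitting
        cut = int(train_frac * len(samples))    # 75% of the samples by default
        train[user] = samples[:cut]
        test[user] = samples[cut:]
    return train, test
```

After filtering, `len(train)` gives the number of remaining users, which can be compared with the 3596 users expected by this code.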