
Implementation code for the GATTA algorithm in "Distributed Learning over Networks with Graph-Attention-Based Personalization"


ZhuoJTian/GATTA


GATTA

Implementation code for the GATTA algorithm

Introduction

The code implements the GATTA algorithm from the paper "Distributed Learning over Networks with Graph-Attention-Based Personalization":

@ARTICLE{GATTA,
  author={Tian, Zhuojun and Zhang, Zhaoyang and Yang, Zhaohui and Jin, Richeng and Dai, Huaiyu},
  journal={IEEE Transactions on Signal Processing}, 
  title={Distributed Learning Over Networks With Graph-Attention-Based Personalization}, 
  year={2023},
  volume={71},
  pages={2071-2086},
  doi={10.1109/TSP.2023.3282071}}

Requirements

  1. [Python 3.6 + TensorFlow 2.4.1 + CUDA 11.0] or [Python 2.7 + TensorFlow 1.15.0 + CUDA 10.0] (for the latter, change the import of tensorflow.compat.v1 to tensorflow)
  2. numpy
  3. tqdm

Usage

Part 1: CIFAR-10 on different numbers of labels

  1. Run Main_GATTA.py to obtain the training results of GATTA.
  2. adj_matrix.txt is the adjacency matrix of the communication network topology, generated and stored by generate_network.py.
  3. CategoryToClients3.txt stores the classes held by each client and is generated by Sample_parti_noiid3.py.
  4. LocalDist_niid3.txt and LocalDist_niid_test3.txt store the number of data samples of each class at every client, for the training and testing data respectively. They are generated by Sample_parti_noiid3.py.
  5. Model_GATTA.py is the core of the algorithm. It defines the network architecture, in which the last fully-connected layer is the node-specific layer with the attention mechanism.
  6. Client_GATTA.py generates the clients.
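The exact attention mechanism is defined in the paper and implemented in Model_GATTA.py. Purely to illustrate the general idea, the NumPy sketch below aggregates a node's last-layer parameters with those of its graph neighbors using softmax attention weights. The dot-product scoring rule, the temperature parameter, and the helper name attention_aggregate are assumptions made for this sketch, not the repository's code.

```python
import numpy as np


def attention_aggregate(weights, adj, node, temperature=1.0):
    """Aggregate the node-specific parameters of `node` with its neighbors'.

    Illustrative sketch only (hypothetical helper, not from the repository).
    weights: (N, D) array, one flattened parameter vector per node.
    adj: (N, N) 0/1 adjacency matrix of the communication network.
    Returns the aggregated vector and a {node_id: attention_weight} dict.
    """
    # Neighborhood of `node`, including the node itself (sorted, unique IDs)
    idx = np.union1d(np.flatnonzero(adj[node]), [node])
    # Assumed scoring rule: dot product with the node's own parameters
    scores = weights[idx] @ weights[node] / temperature
    # Softmax over the neighborhood (shift by the max for numerical stability)
    alpha = np.exp(scores - scores.max())
    alpha /= alpha.sum()
    # Convex combination of the neighborhood's parameter vectors
    aggregated = alpha @ weights[idx]
    return aggregated, dict(zip(idx.tolist(), alpha.tolist()))
```

For example, attention_aggregate(weights, adj, 0) returns node 0's aggregated parameters together with the weight it assigns to each member of its neighborhood; neighbors with a larger dot-product score receive larger weights, and an isolated node simply keeps its own parameters.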

Note

- The uploaded code implements only the basic setting of Fig. 3(a), with N=100, p=0.6, and $c_i=3$.

- To generate a different communication network topology, modify generate_network.py and store the resulting adjacency matrix.

- To consider a different $c_i$, modify and run Sample_parti_noiid3.py to regenerate CategoryToClients3.txt, LocalDist_niid3.txt, and LocalDist_niid_test3.txt.
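When changing $c_i$, it may help to see what a label-based non-i.i.d. split looks like. The sketch below is not Sample_parti_noiid3.py; it assumes each client draws $c_i$ distinct CIFAR-10 classes uniformly at random and holds a fixed, illustrative number of samples per class. The helper names and the samples_per_class value are hypothetical, and the on-disk formats of CategoryToClients3.txt and LocalDist_niid3.txt are not reproduced.

```python
import numpy as np


def assign_classes(n_clients=100, n_classes=10, c=3, seed=0):
    """Hypothetical sketch: give each client c distinct classes at random.

    Returns one sorted class list per client (compare CategoryToClients3.txt).
    """
    rng = np.random.default_rng(seed)
    return [sorted(rng.choice(n_classes, size=c, replace=False).tolist())
            for _ in range(n_clients)]


def local_distribution(class_lists, n_classes=10, samples_per_class=50):
    """Per-client sample counts for each class (rows: clients, cols: classes).

    samples_per_class is an illustrative assumption, not the repository's value.
    Classes a client does not hold get a count of zero.
    """
    counts = np.zeros((len(class_lists), n_classes), dtype=int)
    for i, classes in enumerate(class_lists):
        counts[i, classes] = samples_per_class
    return counts
```

Calling local_distribution twice on the same class lists, once with a training-set count and once with a smaller testing-set count, gives a pair of matrices analogous in spirit to LocalDist_niid3.txt and LocalDist_niid_test3.txt.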

Part 2: FEMNIST on different numbers of writers

  1. Run Main_GATTA.py to obtain the training results of GATTA.
  2. adj_matrix.txt is the adjacency matrix of the communication network topology, generated and stored by generate_network.py.
  3. Categories2.txt stores, for each user, the number of agents holding that user's data, and is generated by Sample_parti_noiid2.py.
  4. LocalDist_niid2.txt and LocalDist_niid_test2.txt store the number of data samples belonging to the two local users at every client, for the training and testing data respectively. They are generated by Sample_parti_noiid2.py.
  5. UsersToClients2.txt stores the IDs of the local users ($e_i=2$ here) at each agent and is generated by Sample_parti_noiid2.py.
  6. Model_GATTA.py is the core of the algorithm. It defines the network architecture, in which the last fully-connected layer is the node-specific layer with the attention mechanism.
  7. Client_GATTA.py generates the clients.
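Both parts read the topology from adj_matrix.txt, produced by generate_network.py. As a hedged illustration of what generating such a matrix involves, the sketch below builds a random undirected graph, assuming an Erdős–Rényi model in which p is the probability that any two nodes are linked. That interpretation of p, and the helper names, are assumptions; the repository's script may use a different model.

```python
import numpy as np


def generate_random_network(n_nodes=100, p=0.6, seed=0):
    """Random undirected topology; hypothetical stand-in for generate_network.py.

    Assumes p is the probability that each pair of nodes is linked
    (Erdos-Renyi model); the repository's script may use a different model.
    """
    rng = np.random.default_rng(seed)
    # Sample only the strict upper triangle, then mirror it for symmetry
    upper = np.triu(rng.random((n_nodes, n_nodes)) < p, k=1)
    return (upper | upper.T).astype(int)  # zero diagonal: no self-loops


def is_connected(adj):
    """Breadth-first search from node 0; True if every node is reachable."""
    seen = {0}
    frontier = [0]
    while frontier:
        next_frontier = []
        for u in frontier:
            for v in np.flatnonzero(adj[u]).tolist():
                if v not in seen:
                    seen.add(v)
                    next_frontier.append(v)
        frontier = next_frontier
    return len(seen) == adj.shape[0]
```

One way to store the result would be np.savetxt("adj_matrix.txt", adj, fmt="%d"), although whether this matches the format Main_GATTA.py expects is an assumption. Checking is_connected first guards against agents that could never exchange models with the rest of the network.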

Note

- The uploaded code implements only the basic setting of Fig. 4(a), with N=100, p=0.6, and $e_i=2$.

- To generate a different communication network topology, modify generate_network.py and store the resulting adjacency matrix.

- To consider a different $e_i$, modify and run Sample_parti_noiid2.py to regenerate Categories2.txt, UsersToClients2.txt, LocalDist_niid2.txt, and LocalDist_niid_test2.txt.
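The relationship between UsersToClients2.txt and Categories2.txt can be sketched as follows. This is not Sample_parti_noiid2.py: it assumes each agent draws $e_i$ distinct FEMNIST users uniformly at random from the 3596 retained users, then counts how many agents hold each user. The helper names are hypothetical and the files' on-disk formats are not reproduced.

```python
import numpy as np


def assign_users(n_agents=100, n_users=3596, e=2, seed=0):
    """Hypothetical sketch: give each agent e distinct FEMNIST users at random.

    Row i holds the user IDs of agent i (compare UsersToClients2.txt).
    """
    rng = np.random.default_rng(seed)
    return np.stack([rng.choice(n_users, size=e, replace=False)
                     for _ in range(n_agents)])


def agents_per_user(users_to_clients, n_users=3596):
    """Number of agents holding each user's data (compare Categories2.txt)."""
    return np.bincount(users_to_clients.ravel(), minlength=n_users)
```

Because a user can be drawn by several agents, the per-user counts are what tell you how many agents share that writer's data; most entries are zero when only 100 x 2 draws are made from 3596 users.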

Note

  • FEMNIST needs to be preprocessed according to the official instructions and stored in femnist/data/train and femnist/data/test. Specifically, during preprocessing we shuffle the data and remove users with fewer than 10 training samples; 3596 users should remain. If a different number remains, modify the value on line 82 of Dataset.py in FEMNIST_code/Model. We then split each user's data into 75% for training and 25% for testing, and store the results in FEMNIST/femnist/data/train and femnist/data/test.

  • The data assignments for clients have already been generated and stored. To reassign the non-i.i.d. data, run Sample_parti_noiid3.py in CIFAR10_code/Main (after uncommenting the relevant code) or Sample_parti_noiid2.py in FEMNIST_code/Main.
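The FEMNIST preprocessing described above (shuffle, drop users with fewer than 10 samples, then split each user 75/25) can be sketched in plain Python. This sketch assumes the data is already grouped as a mapping from user ID to a list of samples and applies the size threshold to each user's full sample list; the function name and data layout are hypothetical, and the official preprocessing scripts may differ in detail.

```python
import random


def preprocess(user_data, min_samples=10, train_frac=0.75, seed=0):
    """Hypothetical sketch: shuffle, drop small users, split each user 75/25.

    user_data maps a user ID to a list of samples. Returns (train, test)
    dicts keyed by the user IDs that survive the size filter.
    """
    rng = random.Random(seed)
    train, test = {}, {}
    for user, samples in user_data.items():
        if len(samples) < min_samples:
            continue  # drop users with too few samples
        samples = list(samples)
        rng.shuffle(samples)
        n_train = int(train_frac * len(samples))  # floor of the 75% share
        train[user] = samples[:n_train]
        test[user] = samples[n_train:]
    return train, test
```

After running it on the full dataset, len(train) is the number of retained users, which is the quantity the note above expects to be 3596.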
