Federated Learning with Uncertainty and Personalization via Efficient Second-order Optimization

Shivam Pal Department of Computer Science and Engineering, IIT Kanpur, India Aishwarya Gupta Department of Computer Science and Engineering, IIT Kanpur, India Saqib Sarwar Department of Computer Science and Engineering, IIT Kanpur, India Piyush Rai Department of Computer Science and Engineering, IIT Kanpur, India
Abstract

Federated Learning (FL) has emerged as a promising method to collaboratively learn from decentralized and heterogeneous data available at different clients without the data ever leaving the clients. Recent works on FL have advocated a Bayesian approach to FL because it offers a principled way to account for model and predictive uncertainty by learning a posterior distribution for the client and/or server models. Moreover, Bayesian FL naturally enables personalization, handling data heterogeneity across clients by having each client learn its own distinct personalized model. In particular, the hierarchical Bayesian approach enables all the clients to learn their personalized models while also taking into account their commonalities via a prior distribution provided by the server. However, despite their promise, Bayesian approaches for FL can be computationally expensive and can also incur high communication costs because posterior distributions must be computed and transmitted. We present a novel Bayesian FL method based on an efficient second-order optimization approach whose computational cost is similar to that of first-order optimization methods like Adam. Our method retains the benefits of the Bayesian approach for FL (e.g., uncertainty estimates and personalization) while being significantly more efficient and accurate than SOTA Bayesian FL methods in both standard and personalized FL settings. It achieves higher predictive accuracy as well as better uncertainty estimates than the baselines, which include both optimization-based and Bayesian FL methods.

Keywords: Bayesian Federated Learning, Variational Inference, Second-order Optimization

1 Introduction

Federated Learning (FL) [1] aims to learn a global model collaboratively across clients without compromising their privacy. It involves multiple client-server communication rounds: in each round, the selected clients send their local models (trained on their private datasets) to the server, which aggregates the received models and broadcasts the result to all clients. The global model, an approximation to the model that would be obtained if all the data were accessible, therefore depends significantly on both the quality of the received client models and the aggregation strategy used at the server. As a result, a straightforward approach like FedAvg [1] can yield a high-performing global model if the data are i.i.d. across clients, but it performs suboptimally when the data distribution is non-i.i.d. These challenges are compounded if each client has only a limited private dataset.

The limitations of standard FL become even more apparent under data heterogeneity, where clients have distinct data distributions. A single global model might fail to represent all clients well, leading to poor performance. This motivates personalized FL (pFL) [2], which aims to adapt models to individual clients while leveraging shared global knowledge.

In such settings, learning a posterior distribution instead of a point estimate at each client improves performance and yields uncertainty estimates, as demonstrated in several recent works (e.g., [3, 4, 5, 6]) that advocate a Bayesian approach to FL. Moreover, Bayesian FL is a natural fit for personalization because the server model can serve as a prior distribution in a hierarchical Bayesian framework, enabling easy personalization of client models via their respective client-specific likelihoods. However, existing Bayesian FL and pFL methods usually rely on computationally expensive routines at the clients (e.g., expensive MCMC sampling [3], Laplace's approximation, which requires Hessian computations on the clients [4], or learning deep ensembles [7]), as well as expensive client-server communication [8] and aggregation at the server (note that, unlike standard FL, Bayesian FL requires sending the whole client posterior to the server). Due to such computational bottlenecks and communication overhead, Bayesian approaches lack scalability, especially for clients with limited resources and bandwidth.

Figure 1: Illustration of FedIVON.

To bridge this gap, we propose FedIvon, a novel Bayesian FL algorithm (its high-level idea is illustrated in Fig. 1) that balances the benefits of Bayesian inference (enhanced performance and quantification of predictive uncertainty) with a minimal increase in computational and communication overhead. In particular, we leverage the IVON (Improved Variational Online Newton) algorithm [9] to perform highly efficient variational inference (VI) on each client by approximating its local posterior with a Gaussian with diagonal covariance. IVON uses natural gradients to capture the geometry of the loss function for faster convergence. Moreover, it estimates the Hessian implicitly, which makes our method computationally cheaper than existing Bayesian FL and pFL methods that rely on explicit Hessian computation (e.g., Laplace's approximation [4]), expensive MCMC sampling [3, 5], or even standard VI [8] at the clients. These local posteriors can be sent efficiently to the server, where they are aggregated into the global posterior; we also present strategies for this aggregation. Our main contributions are:

  • We introduce a Bayesian FL method FedIvon that uses an efficient second-order optimization approach for variational inference, maintaining computational costs similar to first-order methods like Adam.

  • Our method demonstrates improvements in predictive accuracy and uncertainty estimation compared to state-of-the-art (SOTA) Bayesian and non-Bayesian FL methods.

  • Our method also supports client-level model personalization naturally by leveraging a hierarchical Bayesian framework. Clients can use the server's posterior as a prior to learn their private models, effectively balancing local adaptation with global knowledge sharing.

2 Related Work

FedAvg [1], the foundational federated learning algorithm, approximates the global model as the weighted aggregation of locally trained client models, performing effectively with i.i.d. data distributions. Since then, numerous sophisticated and efficient algorithms have been proposed to handle more realistic challenges such as non-i.i.d. data distribution, heterogeneous and resource-constrained clients, and multi-modal data as explored in recent survey works [10, 11, 12, 13, 14]. However, here, we will restrict our discussion to Bayesian FL and personalized FL algorithms as they are most relevant to our work.

Bayesian Federated Learning A key limitation of point-estimate-based approaches is their susceptibility to overfitting in limited-data settings and their lack of predictive uncertainty estimates. To address this, Bayesian approaches have been advocated for federated learning: the clients compute local posterior distributions, which are then aggregated at the server into a global posterior distribution, offering enhanced performance and quantification of predictive uncertainty. Unfortunately, computing the full posterior distribution is intractable, and communicating it incurs significant overhead. FedBE [15] mitigates the communication overhead by leveraging SWAG [16] to learn each client's posterior but communicating only its mean. The server then fits a Gaussian/Dirichlet distribution to the clients' posterior means and distills it into a single model to be communicated in the next round. However, FedBE does not incorporate the clients' covariances, thereby ignoring the uncertainty in their models during aggregation. FedPA [3] addresses this by learning a Gaussian distribution for each client and computing the mean of the global posterior at the server. However, it eventually discards the covariance of the global posterior and computes a point estimate to limit communication costs. Similarly, FedLaplace [4] approximates each client's posterior as a Gaussian distribution and models the global posterior as a mixture of Gaussians, though it eventually simplifies this mixture to a single Gaussian by minimizing the KL divergence.

Second-order Optimization for Federated Learning Second-order methods show promise for improving convergence in FL but are often limited by computational efficiency and communication overhead. Methods such as FedNL [17], which uses privacy-preserving Hessian learning and compression, and second-order approaches that incorporate a global line search [18] offer potential solutions to these challenges.

Personalized Federated Learning When the data distribution among clients is non-i.i.d., a single global model represents the average data distribution and can diverge substantially from each client's local distribution. Consequently, the global model, although it benefits from collaborative learning, performs suboptimally for individual clients. Personalized federated learning addresses this challenge by explicitly adapting part or all of the model to the local data distribution. A typical approach is to split the model into two parts: a base model for global representation learning and a head model for personalized learning. FedPer [19] and FedRep [20] use this strategy, applying FedAvg to learn the base model collaboratively, while the head builds on the shared representation to adapt to local data. Similarly, FedLG [21] splits the model into local and global components to learn local and shared representations, respectively. It shares the global parameters with the server while further refining the local parameters using unsupervised or self-supervised learning. PerFedAvg [22] applies a framework inspired by Model-Agnostic Meta-Learning (MAML) [23] to learn a shared model that adapts quickly to each client's data. pFedME [24] decouples personalized adaptation from shared learning by regularizing each client's loss function using Moreau envelopes. pFedBayes [25] is a Bayesian approach that aims to learn the personalized posterior distribution of each client. In each round, pFedBayes computes the clients' posteriors using the global model as the prior and sends them to the server to update the global model. pFedVEM [26] also computes each client's posterior, restricting it to the Gaussian family. However, it leverages the collaborative knowledge of other clients by assuming conditional independence among clients' models given the global model.

3 Bayesian FL via Improved Variational Online Newton

The standard formulation of FL is similar to distributed optimization, except for some additional constraints, such as no data sharing between the clients and the server and a limited communication budget. Assuming $K$ clients, let $\mathcal{D}=\bigcup_{k\in[K]}\mathcal{D}_k$ be the total available data, where $\mathcal{D}_k$ denotes the private data of client $k$. The objective of standard FL is to solve $\bm{\theta}^*=\operatorname*{arg\,min}_{\bm{\theta}}\sum_{k\in[K]}-\log p(\mathcal{D}_k\mid\bm{\theta})$. However, this optimization problem is not trivial, as it requires access to every client's data, which is not permitted in the federated setting. Thus, a multi-round approach is usually taken: clients learn their local models and send them to a central server, which aggregates them into a global model and sends the global model back to the clients for the next round of learning.

Unlike standard FL, which only learns a point estimate of $\bm{\theta}$, an alternative is to learn a distribution over $\bm{\theta}$. The posterior distribution of $\bm{\theta}$ can be written as

$$p(\bm{\theta}\mid\mathcal{D})\propto p(\bm{\theta})\prod_{k\in[K]}p(\mathcal{D}_k\mid\bm{\theta}) \tag{1}$$

where $p(\bm{\theta})$ is the prior distribution on $\bm{\theta}$ and $p(\mathcal{D}_k\mid\bm{\theta})$ is the data likelihood of client $k$. Assuming a uniform prior $p(\bm{\theta})$, it is straightforward to show that optimizing the standard FL objective is equivalent to finding the mode of the posterior $p(\bm{\theta}\mid\mathcal{D})$, i.e., $\bm{\theta}^*=\operatorname*{arg\,max}_{\bm{\theta}}\log p(\bm{\theta}\mid\mathcal{D})$.
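The equivalence follows in one line from Eq. (1), which already assumes that the client datasets are conditionally independent given $\bm{\theta}$. Taking logarithms and noting that a uniform prior and the evidence $p(\mathcal{D})$ do not depend on $\bm{\theta}$:

```latex
\begin{aligned}
\log p(\bm{\theta}\mid\mathcal{D})
  &= \log p(\bm{\theta}) + \sum_{k\in[K]} \log p(\mathcal{D}_k\mid\bm{\theta}) - \log p(\mathcal{D}) \\
  &= \text{const} + \sum_{k\in[K]} \log p(\mathcal{D}_k\mid\bm{\theta})
  \quad\Longrightarrow\quad
  \operatorname*{arg\,max}_{\bm{\theta}} \log p(\bm{\theta}\mid\mathcal{D})
  = \operatorname*{arg\,min}_{\bm{\theta}} \sum_{k\in[K]} -\log p(\mathcal{D}_k\mid\bm{\theta})
  = \bm{\theta}^*.
\end{aligned}
```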

Computing the full posterior $p(\bm{\theta}\mid\mathcal{D})$ is more useful than computing just the point estimate $\bm{\theta}^*$ because the posterior accounts for model uncertainty. However, computing the posterior exactly is intractable. Directly approximating $p(\bm{\theta}\mid\mathcal{D})$ using approximate inference methods such as MCMC or variational inference [27] is also non-trivial, as it requires evaluating every client's likelihood, which in turn requires global access to all clients' data.

Claim 1.

The global posterior $p(\bm{\theta}\mid\mathcal{D})$ can be approximated at the server by the product of local client posteriors without requiring access to any client's local data.
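A short derivation supports this claim. It uses the conditional independence already implicit in Eq. (1) and assumes that each local posterior is formed with the same prior, i.e., $p(\bm{\theta}\mid\mathcal{D}_k)\propto p(\bm{\theta})\,p(\mathcal{D}_k\mid\bm{\theta})$, so that $p(\mathcal{D}_k\mid\bm{\theta})\propto p(\bm{\theta}\mid\mathcal{D}_k)/p(\bm{\theta})$:

```latex
p(\bm{\theta}\mid\mathcal{D})
  \propto p(\bm{\theta}) \prod_{k\in[K]} \frac{p(\bm{\theta}\mid\mathcal{D}_k)}{p(\bm{\theta})}
  = p(\bm{\theta})^{1-K} \prod_{k\in[K]} p(\bm{\theta}\mid\mathcal{D}_k)
  \;\propto\; \prod_{k\in[K]} p(\bm{\theta}\mid\mathcal{D}_k)
  \qquad \text{(uniform prior)}.
```

Each factor $p(\bm{\theta}\mid\mathcal{D}_k)$ depends only on client $k$'s data, so the server needs only the local posteriors; the approximation enters once these are replaced by tractable approximations.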

If the local posteriors $p(\bm{\theta}\mid\mathcal{D}_k)$ are themselves approximated, multiple rounds of optimization are needed to reduce the aggregation error in the global posterior [3]. Another challenge in FL is to make the computation of the local posteriors, their aggregation at the server, and the client-server communication efficient, which in general can be difficult even for simple models [3].

3.1 Client’s posterior approximation

Assuming client $k$ has $N_k$ training examples, its local loss can be defined as $\bar{\ell}_k(\bm{\theta})=\frac{1}{N_k}\sum_{i=1}^{N_k}\ell_i(\bm{\theta})$, and we can compute the point estimate of the parameters as $\bm{\theta}^*_k=\operatorname*{arg\,min}_{\bm{\theta}}\bar{\ell}_k(\bm{\theta})$. However, in our Bayesian FL setting, we will compute the (approximate) posterior distribution for each client using variational inference, which amounts to solving the following optimization problem

$$q^*_k(\bm{\theta}) = \operatorname*{arg\,min}_{q_k(\bm{\theta})} \mathcal{L}_k(q) \tag{2}$$
$$\mathcal{L}_k(q) = \mathbb{E}_{q_k(\bm{\theta})}\big[\bar{\ell}_k(\bm{\theta})\big] + \mathbb{D}_{KL}\big(q_k(\bm{\theta})\,\|\,p_k(\bm{\theta})\big) \tag{3}$$

where $p_k(\bm{\theta})$ is the prior and $\mathbb{D}_{KL}$ is the Kullback-Leibler divergence. If we use the Gaussian variational family for $q_k(\bm{\theta})$ with diagonal covariance, then $q_k(\bm{\theta})=\mathcal{N}(\bm{\theta}\mid\bm{m}_k,\mathrm{diag}(\bm{\sigma}_k^2))$, where $\bm{m}_k$ and $\bm{\sigma}_k^2$ denote the variational parameters to be optimized. Optimizing the objective in Equation 3 w.r.t. these variational parameters requires the following updates

$$\bm{m}^{t+1}_k = \bm{m}^{t}_k - \alpha\,\hat{\nabla}_{\bm{m}_k}\mathcal{L}_k(q) \tag{4}$$
$$\bm{\sigma}^{t+1}_k = \bm{\sigma}^{t}_k - \alpha\,\hat{\nabla}_{\bm{\sigma}_k}\mathcal{L}_k(q) \tag{5}$$

where $\alpha>0$ is the learning rate.
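To make the objective in Eq. (3) concrete, the sketch below estimates $\mathcal{L}_k(q)$ for a diagonal Gaussian $q_k$ by Monte Carlo, using reparameterized samples $\bm{\theta}=\bm{m}_k+\bm{\sigma}_k\odot\bm{\epsilon}$ with $\bm{\epsilon}\sim\mathcal{N}(\bm{0},\mathbf{I})$, and evaluates the KL term in closed form. It assumes the prior $p_k(\bm{\theta})$ is also a diagonal Gaussian (as when a server posterior serves as the prior); the function names and the toy quadratic loss are illustrative assumptions, not part of FedIvon. Differentiating such a sampled estimate w.r.t. $\bm{m}_k$ and $\bm{\sigma}_k$ gives the naïve stochastic gradients for Eqs. (4) and (5).

```python
import numpy as np


def kl_diag_gauss(m, s2, m0, s02):
    """Closed-form KL( N(m, diag(s2)) || N(m0, diag(s02)) ) for diagonal Gaussians."""
    return 0.5 * np.sum(s2 / s02 + (m - m0) ** 2 / s02 - 1.0 + np.log(s02 / s2))


def vi_objective(loss_fn, m, sigma, m0, s02, num_samples=1000, seed=0):
    """Monte Carlo estimate of Eq. (3): E_q[loss(theta)] + KL(q || p).

    q = N(m, diag(sigma^2)) is the client's variational posterior and
    p = N(m0, diag(s02)) is an assumed diagonal-Gaussian prior.
    """
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((num_samples, m.size))
    thetas = m + sigma * eps  # reparameterized samples theta ~ q
    expected_loss = np.mean([loss_fn(t) for t in thetas])
    return expected_loss + kl_diag_gauss(m, sigma**2, m0, s02)


# Toy usage: quadratic loss 0.5 * ||theta - c||^2 with the prior set equal to q
c = np.array([1.0, -1.0])
m = np.array([0.5, 0.0])
sigma = np.array([0.3, 0.2])
loss = lambda t: 0.5 * np.sum((t - c) ** 2)
est = vi_objective(loss, m, sigma, m0=m, s02=sigma**2, num_samples=20000)
print(est)  # close to the exact value 0.69, since KL(q || q) = 0
```

For this quadratic loss the expectation has the closed form $\tfrac12\big(\lVert\bm{m}-\bm{c}\rVert^2+\sum_j\sigma_j^2\big)=0.69$, which the Monte Carlo estimate approaches as `num_samples` grows; its sampling noise is what makes the naïve gradient estimators discussed next high-variance.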
Computing exact gradients in the above update equations is difficult due to the expectation term in $\mathcal{L}_k(q)$. A naïve approach is to use stochastic gradient estimators; however, such approaches do not scale well because of the high variance of the gradient estimates. The IVON algorithm [9] replaces these updates with much more efficient, Adam-like update equations whose computational cost is nearly identical to that of Adam. The key differences between IVON and Adam are summarized below:

  • Unlike Adam, which solves for $\bm{\theta}$, IVON solves for both the mean vector $\bm{m}$ and the variances $\bm{\sigma}^2$, which provide an estimate of the Gaussian variational approximation at each client. Note that the mean $\bm{m}$ plays the role of $\bm{\theta}$ in Adam. In addition, the variances naturally provide uncertainty estimates for $\bm{\theta}$, which are essential for Bayesian FL (both for estimating the client models' uncertainties and for aggregating the client models at the server).

  • Unlike Adam, which uses squared minibatch gradients to adapt the learning rate in each dimension, IVON uses a reparameterization-based estimator, namely the gradient multiplied element-wise by $(\bm{\theta}-\bm{m})/\bm{\sigma}^2$, to obtain an unbiased estimate of the expected (diagonal) Hessian under $q$. This gives IVON a cheap Hessian estimate, which makes it a second-order method, unlike Adam.

  • IVON offers the significant advantage of providing an estimate $\mathbf{h}$ of second-order information with minimal computational overhead. The Hessian estimate is inversely related to the posterior variance, $\bm{\sigma}^2 = 1/(\lambda(\mathbf{h}+\delta))$ (Algorithm 2, lines 4 and 14), where $\lambda$ is the effective sample size and $\delta$ the weight decay, so an estimate of $\mathbf{h}$ is available throughout training. Moreover, $\mathbf{h}$ is never computed from second derivatives; it is estimated from gradient information alone. In comparison, standard optimization methods such as SGD, Adam, and SWAG require additional effort to estimate second-order information.
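The reparameterization-based Hessian estimator described in the bullets above can be checked numerically. For $\bm{\theta}\sim\mathcal{N}(\bm{m},\mathrm{diag}(\bm{\sigma}^2))$, Stein's lemma gives $\mathbb{E}\big[\nabla\bar{\ell}(\bm{\theta})\odot(\bm{\theta}-\bm{m})/\bm{\sigma}^2\big]=\mathbb{E}\big[\mathrm{diag}(\nabla^2\bar{\ell}(\bm{\theta}))\big]$. The sketch below is a toy check, not IVON's implementation: it averages this quantity over many samples for a quadratic loss with a known, non-diagonal Hessian $H$ (an illustrative assumption), so the average should approach the diagonal of $H$.

```python
import numpy as np


def hessian_diag_estimate(grad_fn, m, sigma, num_samples, seed=0):
    """Average of grad(theta) * (theta - m) / sigma^2 with theta ~ N(m, diag(sigma^2)).

    grad_fn maps an (S, d) array of parameter samples to an (S, d) array of gradients.
    """
    rng = np.random.default_rng(seed)
    thetas = m + sigma * rng.standard_normal((num_samples, m.size))
    grads = grad_fn(thetas)
    return np.mean(grads * (thetas - m) / sigma**2, axis=0)


# Quadratic loss 0.5 * theta^T H theta with a known (non-diagonal) Hessian H
H = np.array([[3.0, 1.0],
              [1.0, 2.0]])
grad_fn = lambda thetas: thetas @ H  # row-wise H @ theta (H is symmetric)
est = hessian_diag_estimate(grad_fn, m=np.array([0.1, -0.1]),
                            sigma=np.array([0.5, 0.5]), num_samples=200_000)
print(est)  # approximately [3.0, 2.0], the diagonal of H
```

The off-diagonal entry of $H$ adds variance but no bias. Algorithm 2 draws a single sample per step (line 8) and smooths this noise with the $\beta_2$ moving average in line 11.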

Algorithm 1 FedIvon Algorithm
1:  Input: total communication rounds $R$, total clients $K$, clients' private datasets $\{\mathcal{D}_k\}_{k=1}^{K}$, initial model weights $\tilde{\bm{m}}_0$, initial model Hessian $\tilde{\bm{h}}_0$
2:  for $r = 0$ to $R-1$ do
3:     Broadcast $\tilde{\bm{m}}_r, \tilde{\bm{h}}_r$ to all $K$ clients
4:     Randomly sample $k$ clients ▷ Update selected client models locally
5:     for $i = 1$ to $k$ do
6:        $\bm{m}_i, \bm{h}_i = \text{Client\_Update}(\mathcal{D}_i, \tilde{\bm{m}}_r, \tilde{\bm{h}}_r)$
7:     end for
8:     Initialize $\tilde{\bm{m}}_{r+1} \leftarrow 0$, $\tilde{\bm{h}}_{r+1} \leftarrow 0$ ▷ Aggregation of client models at server
9:     for $i = 1$ to $k$ do
10:       $\tilde{\bm{h}}_{r+1} \leftarrow \tilde{\bm{h}}_{r+1} + w[i]\,\bm{h}_i$
11:       $\tilde{\bm{m}}_{r+1} \leftarrow \tilde{\bm{m}}_{r+1} + w[i]\,(\bm{m}_i \odot \bm{h}_i)$
12:    end for
13:    $\tilde{\bm{m}}_{r+1} \leftarrow \tilde{\bm{m}}_{r+1} / \tilde{\bm{h}}_{r+1}$ (elementwise division) ▷ Global weights and Hessian
14: end for
15: Output: global model weights and Hessian $(\tilde{\bm{m}}_R, \tilde{\bm{h}}_R)$
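The server-side aggregation in lines 8 to 13 of Algorithm 1 amounts to a Hessian-weighted average of the client means, mirroring the precision-weighted mean of Eq. (7) in Section 3.2. A minimal NumPy sketch follows, assuming each client returns its mean and diagonal Hessian as 1-D arrays and that the weights $w[i]$ are nonnegative (e.g., proportional to client dataset sizes); `server_aggregate` is a hypothetical name.

```python
import numpy as np


def server_aggregate(client_means, client_hessians, weights):
    """Lines 8-13 of Algorithm 1: Hessian-weighted average of client means.

    client_means, client_hessians: lists of 1-D arrays (one per sampled client)
    weights: nonnegative client weights w[i], e.g. proportional to dataset sizes
    """
    m_sum = np.zeros_like(client_means[0], dtype=float)  # line 8
    h_sum = np.zeros_like(client_hessians[0], dtype=float)
    for m_i, h_i, w_i in zip(client_means, client_hessians, weights):
        h_sum += w_i * h_i        # line 10
        m_sum += w_i * m_i * h_i  # line 11 (elementwise product)
    return m_sum / h_sum, h_sum   # line 13 (elementwise division)


# Two clients that disagree on the first coordinate
m_global, h_global = server_aggregate(
    client_means=[np.array([0.0, 2.0]), np.array([4.0, 2.0])],
    client_hessians=[np.array([1.0, 3.0]), np.array([3.0, 1.0])],
    weights=[0.5, 0.5],
)
print(m_global, h_global)  # [3. 2.] [2. 2.]
```

In the example, the first coordinate of the aggregated mean lands at 3.0, closer to the second client's value of 4.0, because that client reports a larger Hessian (higher confidence) there. The update in line 11 of Algorithm 2 keeps $\mathbf{h}+\delta$ positive but not necessarily $\mathbf{h}$ itself, so a practical implementation may want to guard the division against near-zero entries; the sketch does not.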
Algorithm 2 Client_Update
1:  Input: local dataset $D$, model weights $\mathbf{m}$, Hessian $\mathbf{h}$, local epochs $E$, learning rates $\{\alpha_e\}$, weight decay $\delta$, hyperparameters $\beta_1, \beta_2$, batch size $B$
2:  Output: trained model weights $\mathbf{m}$, Hessian $\mathbf{h}$
3:  $\mathbf{g} \leftarrow 0$, $\lambda \leftarrow |D|$, $n \leftarrow E\,|D|/B$
4:  $\bm{\sigma} \leftarrow 1/\sqrt{\lambda(\mathbf{h}+\delta)}$
5:  $\alpha_e \leftarrow (\mathbf{h}+\delta)\,\alpha_e$ for all $e \in \{1, 2, \dots, n\}$
6:  for $e = 1$ to $n$ do
7:     Sample a batch of inputs of size $B$ from $D$
8:     $\widehat{\mathbf{g}} \leftarrow \widehat{\nabla}\bar{\ell}(\bm{\theta})$, where $\bm{\theta} \sim q = \mathcal{N}(\mathbf{m}, \mathrm{diag}(\bm{\sigma}^2))$
9:     $\widehat{\mathbf{h}} \leftarrow \widehat{\mathbf{g}} \cdot (\bm{\theta}-\mathbf{m})/\bm{\sigma}^2$
10:    $\mathbf{g} \leftarrow \beta_1\mathbf{g} + (1-\beta_1)\,\widehat{\mathbf{g}}$
11:    $\mathbf{h} \leftarrow \beta_2\mathbf{h} + (1-\beta_2)\,\widehat{\mathbf{h}} + \tfrac{1}{2}(1-\beta_2)^2(\mathbf{h}-\widehat{\mathbf{h}})^2/(\mathbf{h}+\delta)$
12:    $\overline{\mathbf{g}} \leftarrow \mathbf{g}/(1-\beta_1^{e})$
13:    $\mathbf{m} \leftarrow \mathbf{m} - \alpha_e(\overline{\mathbf{g}} + \delta\mathbf{m})/(\mathbf{h}+\delta)$
14:    $\bm{\sigma} \leftarrow 1/\sqrt{\lambda(\mathbf{h}+\delta)}$
15: end for
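Algorithm 2 can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the authors' or the IVON library's implementation: the loss gradient comes from a hypothetical callable `grad_fn(theta, batch)`, one Monte Carlo sample is drawn per step, the learning-rate schedule $\{\alpha_e\}$ is a constant `lr`, minibatches are drawn without replacement within each step, and the default hyperparameter values are arbitrary choices for the toy problem.

```python
import numpy as np


def client_update(data, grad_fn, m, h, epochs=1, lr=0.05, delta=1e-3,
                  beta1=0.9, beta2=0.999, batch_size=20, seed=0):
    """Sketch of Algorithm 2 (Client_Update); returns the updated mean m and Hessian h.

    data:    array of shape (N, d) with the client's local examples
    grad_fn: hypothetical callable grad_fn(theta, batch) returning the minibatch
             gradient of the averaged loss at theta
    m, h:    incoming mean and diagonal Hessian (e.g., broadcast by the server)
    """
    rng = np.random.default_rng(seed)
    m, h = np.array(m, dtype=float), np.array(h, dtype=float)
    n_data = data.shape[0]
    lam = float(n_data)                          # line 3: lambda <- |D|
    n_steps = epochs * n_data // batch_size      # line 3: n = E * |D| / B
    g = np.zeros_like(m)                         # line 3: g <- 0
    sigma = 1.0 / np.sqrt(lam * (h + delta))     # line 4
    alpha = (h + delta) * lr                     # line 5 (constant schedule assumed)
    for e in range(1, n_steps + 1):              # line 6
        idx = rng.choice(n_data, size=batch_size, replace=False)
        batch = data[idx]                                  # line 7
        theta = m + sigma * rng.standard_normal(m.shape)  # theta ~ q
        g_hat = grad_fn(theta, batch)                      # line 8
        h_hat = g_hat * (theta - m) / sigma**2             # line 9
        g = beta1 * g + (1 - beta1) * g_hat                # line 10
        h = (beta2 * h + (1 - beta2) * h_hat               # line 11
             + 0.5 * (1 - beta2) ** 2 * (h - h_hat) ** 2 / (h + delta))
        g_bar = g / (1 - beta1**e)                         # line 12
        m = m - alpha * (g_bar + delta * m) / (h + delta)  # line 13
        sigma = 1.0 / np.sqrt(lam * (h + delta))           # line 14
    return m, h


# Toy usage: per-example loss 0.5 * sum_j a_j * (theta_j - x_ij)^2, true Hessian = a
rng = np.random.default_rng(1)
a = np.array([2.0, 4.0])
data = np.array([1.0, -1.0]) + 0.1 * rng.standard_normal((200, 2))
quad_grad = lambda theta, batch: a * (theta - batch.mean(axis=0))
m, h = client_update(data, quad_grad, m=np.zeros(2), h=np.ones(2), epochs=200)
print(m, h)  # m near the data mean, h moved from 1 toward a
```

In the toy problem the true Hessian is $\bm{a}$, so after enough steps $\mathbf{m}$ approaches the data mean while $\mathbf{h}$ drifts from its initial value toward $\bm{a}$ at a rate governed by $\beta_2$. Line 11's extra term keeps $\mathbf{h}+\delta$ positive even when individual estimates $\widehat{\mathbf{h}}$ are negative, so `sigma` stays well defined.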

3.2 Posterior aggregation at server

At the server, we can aggregate the client posteriors to compute the global posterior [28]. IVON approximates the clients' posteriors as Gaussians, and a product of Gaussian densities is, up to a multiplicative constant, again a Gaussian. We therefore approximate the global distribution as a Gaussian whose optimal mean and precision (inverse covariance) matrix are given below. Moreover, since each client's variational approximation is a Gaussian with a diagonal covariance matrix, the aggregation operations are efficient. Let $q(\bm{\theta}\mid\mathcal{D}_k)=\mathcal{N}(\bm{\theta}\mid\bm{\mu}_k,\bm{\Lambda}_k^{-1})$, where $\bm{\mu}_k=\mathbf{m}_k$ and $\bm{\Lambda}_k=\mathrm{diag}(\bm{\sigma}_k^{-2})$ is the precision matrix. Using the results of product-of-Gaussians aggregation [4, 28], we have

$\log q(\bm{\theta} \mid \mathcal{D}) \approx \sum_{k=1}^{K} w_k \log q(\bm{\theta} \mid \mathcal{D}_k)$ (6)

where $w_k = N_k / \sum_{j=1}^{K} N_j$ is proportional to the size $N_k$ of client $k$'s local dataset, and

$q(\bm{\theta} \mid \mathcal{D}) \approx \mathcal{N}(\bm{\theta} \mid \bm{\mu}, \bm{\Lambda}^{-1})$ (7)

where $\bm{\Lambda} = \sum_{k=1}^{K} w_k \bm{\Lambda}_k$ and $\bm{\mu} = \bm{\Lambda}^{-1} \sum_{k=1}^{K} w_k \bm{\Lambda}_k \bm{\mu}_k$.
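Each $\bm{\Lambda}_k$ is treated as a diagonal precision matrix with entries $1/\bm{\sigma}_k^2$, so Eq. (7) reduces to per-parameter arithmetic. The minimal sketch below uses plain Python lists in place of parameter tensors. The name `aggregate_diagonal_gaussians` is hypothetical, and the weights are $w_k = N_k / \sum_j N_j$ as defined above.

```python
def aggregate_diagonal_gaussians(means, variances, num_examples):
    """Weighted product-of-Gaussians aggregation with diagonal covariances.

    means[k], variances[k]: per-parameter mean mu_k and variance sigma_k^2 of client k.
    num_examples[k]: N_k, used for the weights w_k = N_k / sum_j N_j.
    Returns the global mean mu and the global variance 1 / Lambda, per parameter.
    """
    total = sum(num_examples)
    weights = [n / total for n in num_examples]
    global_mean, global_var = [], []
    for d in range(len(means[0])):
        # Lambda = sum_k w_k Lambda_k, with Lambda_k = 1 / sigma_k^2 (precision).
        precision = sum(w / var[d] for w, var in zip(weights, variances))
        # mu = Lambda^{-1} sum_k w_k Lambda_k mu_k (precision-weighted mean).
        weighted = sum(w * mu[d] / var[d] for w, mu, var in zip(weights, means, variances))
        global_mean.append(weighted / precision)
        global_var.append(1.0 / precision)
    return global_mean, global_var


g_mean, g_var = aggregate_diagonal_gaussians(
    means=[[0.0], [4.0]], variances=[[1.0], [0.5]], num_examples=[1, 3])
print(g_mean, g_var)
```

The weights sum to one, so $\bm{\Lambda}$ is a weighted average of the client precisions rather than their sum. In the example, the global mean is pulled toward the larger and more confident second client.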

Other aggregation strategies are also possible [28]; we leave their exploration to future work. Our aggregation can also be viewed as Fisher-weighted model merging [29]. In that view, each client model is represented by its mean weights $\bm{m}_k$ and by a Fisher matrix that depends on the local posterior variances $\bm{\sigma}_k^2$. However, model merging computes only the merged mean, not a covariance, and thus does not yield a global posterior distribution at the server.

The appendix provides further details of IVON and its integration in our Bayesian FL setup.

Notably, FedIvon is appealing from two perspectives. First, it can be viewed as an efficient Bayesian FL algorithm that offers the benefits of the Bayesian approach. Second, it is a federated learning algorithm that incorporates second-order information when training the client models, without incurring the usual overheads of the second-order methods used by some FL algorithms [30].

3.3 Personalized Federated Learning

Personalized FL is straightforward to achieve in FedIvon. Similar to equation 3, the personalized loss function for each client $k$ is defined as

$\mathcal{L}_k(q) = \mathbb{E}_{q_k(\bm{\theta})}[\bar{\ell}_k(\bm{\theta})] + \beta \, \mathbb{D}_{KL}(q_k(\bm{\theta}) \,\|\, p_k(\bm{\theta})).$ (8)

where $\beta \geq 0$ controls the level of personalization and $p_k(\bm{\theta})$ is the prior distribution for client $k$. In each communication round, the client uses the posterior distribution received from the server as its prior $p_k(\bm{\theta})$. This lets each client adapt the global model to its local data characteristics while still leveraging information from the global model.

When $\beta = 0$, the model becomes fully personalized. It relies solely on the client's data, with no influence from the prior (i.e., no information from the server). Larger values of $\beta$ pull the client's posterior more strongly toward the server posterior, balancing personalization against shared global information. This provides a flexible mechanism for adapting client models to their individual data while still benefiting from collective learning through the shared server posterior. We fix $\beta = 1$ in all our pFL experiments.
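When both $q_k$ and the server prior $p_k$ are diagonal Gaussians, the KL term in Eq. (8) has a closed form. The sketch below shows this closed form and how $\beta$ scales it. It assumes that the expected loss $\mathbb{E}_{q_k}[\bar{\ell}_k]$ has already been estimated, for example by Monte Carlo sampling. Both function names are hypothetical.

```python
import math


def kl_diag_gaussians(q_mean, q_var, p_mean, p_var):
    """KL(q || p) for diagonal Gaussians, summed over parameters."""
    return sum(
        0.5 * (math.log(pv / qv) + (qv + (qm - pm) ** 2) / pv - 1.0)
        for qm, qv, pm, pv in zip(q_mean, q_var, p_mean, p_var)
    )


def personalized_objective(expected_loss, q_mean, q_var, p_mean, p_var, beta=1.0):
    """Eq. (8): E_q[loss] + beta * KL(q_k || p_k); expected_loss is assumed precomputed."""
    return expected_loss + beta * kl_diag_gaussians(q_mean, q_var, p_mean, p_var)


kl = kl_diag_gaussians([1.0], [1.0], [0.0], [1.0])
print(kl)
```

With `beta=0.0` the objective reduces to the local expected loss, which corresponds to the fully personalized case described above.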

4 Experiments: Standard FL

We experiment on three publicly available datasets: EMNIST [31], SVHN [32], and CIFAR-10 [33]. EMNIST consists of 28×28 grayscale images of handwritten letters and digits (0-9). In our experiments, we restrict it to letters only, for which the train and test splits contain 124,800 and 20,800 images, respectively. SVHN consists of 32×32 RGB images of house numbers in 10 classes, one per digit, with train and test splits of 73,252 and 26,032 images, respectively. CIFAR-10 comprises 32×32 RGB images of objects in 10 classes, with 50,000 training images and 10,000 test images.

For FedAvg and FedLaplace, we use the Adam optimizer with a learning rate of 1e-3 and weight decay of 2e-4. For FedIvon, we use the IVON optimizer [9] with the hyperparameters given in Table 1. A linearly decaying learning rate is used in all experiments.

params SVHN EMNIST CIFAR-10
initial learning rate 0.1 0.1 0.1
final learning rate 0.01 0.01 0.01
weight decay 2e-4 2e-4 2e-4
batch size 32 32 32
ESS (λ) 5000 5000 5000
initial Hessian (h₀) 2.0 5.0 1.0
MC samples (training) 1 1 1
MC samples (test) 500 500 500
Table 1: IVON hyperparameters for the standard FL experiments
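Table 1 gives only the two endpoints of the linearly decaying learning rate. The sketch below implements such a schedule over a generic step counter. The text does not state whether the decay is indexed by local optimization step or by communication round, so this indexing is an assumption. The helper `linear_lr` is hypothetical.

```python
def linear_lr(step, total_steps, lr_init=0.1, lr_final=0.01):
    """Linearly interpolate from lr_init (step 0) to lr_final (step total_steps)."""
    frac = min(max(step / total_steps, 0.0), 1.0)  # clamp progress to [0, 1]
    return lr_init + frac * (lr_final - lr_init)


print([round(linear_lr(s, 100), 4) for s in (0, 50, 100)])
```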

We evaluate FedIvon in a challenging and realistic scenario in which heterogeneous data are distributed among a large number of clients, each with very few training examples. In each experiment, we consider 200 clients, each with a small private training set of fewer than 100 examples. To simulate a non-iid data distribution, we randomly sample inputs from the training split, partition the sampled inputs into shards, and distribute the shards among clients. This creates class-imbalanced training data, similar to [15]. For a fair comparison, we use the same non-iid split across clients for all baselines and for FedIvon. We follow the experimental setup of [5] and train customized CNN models on EMNIST, SVHN, and CIFAR-10.

We compare FedIvon with two baselines. FedAvg [1] simply aggregates the client models at the server. FedLaplace [4] uses the Laplace approximation to fit a Gaussian distribution to each client's local model and then aggregates these distributions at the server. FedAvg is a non-Bayesian baseline, used to check that FedIvon provides uncertainty quantification without compromising predictive performance. FedLaplace is a competitive Bayesian baseline for evaluating FedIvon's predictive uncertainty. For all baselines and FedIvon, we run the federated algorithm for 2000 communication rounds and randomly sample 5% of the clients (i.e., 10 clients) in each round. Each client's model is trained locally for 2 epochs with a batch size of 32. Further details on hyperparameters, model architectures, and data splits are provided in the appendix.
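The shard-based split described above can be sketched as follows. The sketch samples indices, sorts them by label so that each shard is mostly single-class, shuffles the shards, and deals them to clients. This is only our reading of the procedure: the values `shards_per_client=2` and `shard_size=40` are illustrative and are not the paper's values, and `shard_partition` is a hypothetical helper.

```python
import random


def shard_partition(labels, num_clients, shards_per_client, shard_size, seed=0):
    """Class-imbalanced split: sample indices, sort by label, cut into shards, deal to clients.

    Returns a list of index lists, one per client.
    """
    rng = random.Random(seed)
    num_shards = num_clients * shards_per_client
    # Randomly sample just enough training inputs to fill all shards.
    sampled = rng.sample(range(len(labels)), num_shards * shard_size)
    # Sorting by label makes each contiguous shard cover at most a few classes.
    sampled.sort(key=lambda i: labels[i])
    shards = [sampled[s * shard_size:(s + 1) * shard_size] for s in range(num_shards)]
    rng.shuffle(shards)
    return [
        [i for shard in shards[c * shards_per_client:(c + 1) * shards_per_client] for i in shard]
        for c in range(num_clients)
    ]


labels = [i % 26 for i in range(20000)]  # toy labels standing in for EMNIST letters
parts = shard_partition(labels, num_clients=200, shards_per_client=2, shard_size=40)
print(len(parts), len(parts[0]), sorted({labels[i] for i in parts[0]}))
```

With these illustrative settings, each client holds 80 examples drawn from at most four classes, which gives the strong class imbalance the setup aims for.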

4.1 Classification Task

We train a classification model in the FL setting with each method and report the results in Table 2. We evaluate all trained models on the test split in terms of accuracy and negative log-likelihood (NLL). We use Expected Calibration Error (ECE) and Brier score to quantify the quality of the predictive uncertainty. In our results, FedIvon@mean denotes point-estimate predictions evaluated at the mean of the IVON posterior. FedIvon denotes Monte Carlo averaging of predictions over 500 posterior samples.
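The four reported metrics can be computed from predictive class probabilities as sketched below. The text does not specify every convention, so the sketch makes two assumptions. ECE uses 15 equal-width confidence bins. The Brier score is the multi-class variant, summed over classes and averaged over examples. `classification_metrics` is a hypothetical helper.

```python
import math


def classification_metrics(probs, labels, num_bins=15):
    """Accuracy, NLL, Brier score and ECE from predictive probabilities.

    probs: list of per-example probability vectors; labels: list of int class ids.
    """
    n = len(labels)
    acc = nll = brier = 0.0
    bins = [[0, 0.0, 0.0] for _ in range(num_bins)]  # count, sum of confidence, sum of correct
    for p, y in zip(probs, labels):
        pred = max(range(len(p)), key=lambda c: p[c])
        correct = float(pred == y)
        acc += correct
        nll -= math.log(max(p[y], 1e-12))  # clip to avoid log(0)
        brier += sum((pc - (1.0 if c == y else 0.0)) ** 2 for c, pc in enumerate(p))
        conf = p[pred]
        b = min(int(conf * num_bins), num_bins - 1)
        bins[b][0] += 1
        bins[b][1] += conf
        bins[b][2] += correct
    # ECE = sum_b (n_b / n) * |acc_b - conf_b| = sum_b |sum correct - sum conf| / n
    ece = sum(abs(s_corr - s_conf) for cnt, s_conf, s_corr in bins if cnt) / n
    return acc / n, nll / n, brier / n, ece


acc, nll, bs, ece = classification_metrics([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]], [0, 1, 1])
print(acc, nll, bs, ece)
```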

As shown in Table 2, FedIvon outperforms all baselines. With Monte Carlo averaging, it achieves the best ECE, NLL, and Brier score on all three datasets, and both FedIvon variants achieve higher accuracy than FedAvg and FedLaplace. FedIvon leverages the improved variational online Newton method, which continuously updates its Hessian approximation throughout training. Figures 2 and 3 show the convergence of all methods on all datasets. FedIvon improves slightly more slowly than the baselines early in training but soon outperforms them, which we attribute to its Hessian approximation improving as training progresses. Moreover, FedLaplace fits a Gaussian distribution to each client's model using a Laplace approximation evaluated at the MAP estimate. FedIvon instead estimates the Hessian over the entire course of training, resulting in much better predictive uncertainty estimates. Because FedIvon approximates the posterior at both the server and the clients, it also performs well when clients have very limited data (fewer than 50 examples). These results are presented in the supplementary material.

Models EMNIST CIFAR-10 SVHN
ACC (↑) ECE (↓) NLL (↓) BS (↓) ACC (↑) ECE (↓) NLL (↓) BS (↓) ACC (↑) ECE (↓) NLL (↓) BS (↓)
FedAvg 91.66 0.0405 0.3355 0.1303 62.25 0.0981 1.199 0.5191 82.14 0.0311 0.6857 0.2640
FedLaplace 91.33 0.0381 0.3255 0.1314 61.80 0.1072 1.233 0.5284 81.99 0.0211 0.6423 0.2627
FedIvon@mean 93.14 0.0349 0.2821 0.1075 62.92 0.0983 1.1500 0.5114 84.54 0.0241 0.5624 0.2256
FedIvon 93.09 0.0188 0.2341 0.1019 62.54 0.0312 1.0790 0.5021 84.76 0.0148 0.5303 0.2210
Table 2: Test accuracy (ACC, %), Expected Calibration Error (ECE), Negative Log-Likelihood (NLL), and Brier Score (BS) (↑: higher is better; ↓: lower is better)
Figure 2: Loss of all methods vs. communication rounds (left: EMNIST, center: SVHN, right: CIFAR-10).
Figure 3: Test accuracy vs. communication rounds (left: EMNIST, center: SVHN, right: CIFAR-10).

4.2 Out-of-Distribution Detection Task

The predictive uncertainty of a model plays a crucial role in uncertainty-driven tasks such as out-of-distribution (OOD) detection and active learning. We evaluate how well FedIvon and the baselines distinguish OOD inputs from in-distribution inputs using their predictive uncertainty. For any input $\mathbf{x}$, we measure predictive uncertainty as the Shannon entropy of the model's predictive distribution and use it to flag OOD inputs. We simulate this task by randomly sampling 5000 images from the OOD dataset. We then mix them with an equal number of inputs randomly sampled from the test split of the training dataset.

Models EMNIST CIFAR-10 SVHN
FedAvg 0.8910 0.7896 0.7975
FedLaplace 0.8297 0.7513 0.8222
FedIvon 0.9032 0.7662 0.8233
Table 3: AUROC (\uparrow) score for OOD/in-domain data detection

Specifically, we use EMNIST, CIFAR-10, and SVHN as the OOD datasets for the models trained on EMNIST, SVHN, and CIFAR-10, respectively. Table 3 reports the AUROC (area under the ROC curve) for all methods and datasets. FedIvon achieves the best AUROC on EMNIST and SVHN and is competitive on CIFAR-10.
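Entropy-based OOD scoring and AUROC can be sketched as follows. OOD inputs are treated as the positive class, and AUROC is computed as the probability that a random OOD input receives a higher entropy than a random in-distribution input, with ties counting one half. The text does not say how AUROC was computed. This pairwise version is quadratic in the number of inputs, so at the 5000 + 5000 scale a sorting-based routine such as scikit-learn's `roc_auc_score` would be the practical choice. All helper names are hypothetical.

```python
import math


def predictive_entropy(p):
    """Shannon entropy (in nats) of a predictive distribution."""
    return -sum(pc * math.log(pc) for pc in p if pc > 0)


def auroc(scores_pos, scores_neg):
    """AUROC via the Mann-Whitney statistic: P(score_pos > score_neg), ties count 1/2."""
    wins = 0.0
    for sp in scores_pos:
        for sn in scores_neg:
            wins += 1.0 if sp > sn else 0.5 if sp == sn else 0.0
    return wins / (len(scores_pos) * len(scores_neg))


def ood_auroc(ood_probs, in_probs):
    """OOD inputs are the positive class; entropy is the OOD score."""
    return auroc([predictive_entropy(p) for p in ood_probs],
                 [predictive_entropy(p) for p in in_probs])


print(ood_auroc([[0.5, 0.5], [0.4, 0.6]], [[0.99, 0.01], [0.9, 0.1]]))
```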

4.3 Ablation Studies

In our federated learning experiments, we set the number of local epochs in each client update to $E = 2$. In this section, we empirically investigate how the number of local epochs affects the convergence of the global (server) model for each method. Figure 4 shows the convergence plots for different values of $E$. With $E = 1$, FedIvon converges more slowly than FedAvg, and FedLaplace converges even more slowly than FedIvon. The slower convergence of FedIvon can be attributed to the way its gradients are computed. FedIvon estimates gradients at stochastically sampled weights, which yields less accurate gradient estimates near initialization and in turn slows convergence. FedLaplace, which requires a MAP estimate, also converges slowly, because with only one epoch of local training the MAP estimate is suboptimal. With $E = 2$, all methods converge faster than with $E = 1$, likely because more local training iterations yield better gradient and MAP estimates. For FedLaplace in particular, the MAP estimate becomes more accurate with more training, resulting in faster convergence. FedIvon still outperforms both FedAvg and FedLaplace after a few communication rounds. This can be attributed to its gradient estimates improving over successive communication rounds, which allows FedIvon to overcome its initially slower convergence.

Figure 4: Convergence of all the methods on CIFAR-10 dataset with varying local epochs

5 Experiments: Personalized FL

For the personalized FL experiments, we consider two types of data heterogeneity across clients in a classification task, following [26]. We compare FedIvon against the personalized FL baselines pFedME [24], pFedBayes [25], and pFedVEM [26].

  • Class distribution skew: In class distribution skew, clients have data from only a limited set of classes. To simulate this, we use the CIFAR-10 dataset and assign each client data from a random selection of 5 out of the 10 classes.

  • Class concept drift: To simulate class concept drift, we use the CIFAR-100 dataset, which includes 20 superclasses, each containing 5 subclasses. For each client, we randomly select one subclass from each superclass (1 out of 5). The client’s local data is then drawn exclusively from these selected subclasses, creating a shift in label concepts across clients. We define the classification task as predicting the superclass.

To model data quantity disparity, we randomly divide the training set into partitions of varying sizes by uniformly sampling slice indices, then assign each partition to a different client.
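The three heterogeneity mechanisms above can each be sketched in a few lines. The sketch covers a random choice of 5 of the 10 classes per client, a random choice of one subclass per superclass, and a split of a shuffled index range at uniformly sampled cut points. It treats the three mechanisms independently. How the paper combines them with per-class data selection is not specified here, and all helper names are hypothetical.

```python
import random


def class_skew_classes(num_clients, num_classes=10, classes_per_client=5, seed=0):
    """Class distribution skew: each client gets a random subset of the classes."""
    rng = random.Random(seed)
    return [sorted(rng.sample(range(num_classes), classes_per_client)) for _ in range(num_clients)]


def concept_drift_subclasses(num_clients, num_superclasses=20, subclasses_per_super=5, seed=0):
    """Class concept drift: each client picks one subclass index (of 5) per superclass."""
    rng = random.Random(seed)
    return [[rng.randrange(subclasses_per_super) for _ in range(num_superclasses)]
            for _ in range(num_clients)]


def quantity_partition(num_examples, num_clients, seed=0):
    """Data quantity disparity: cut a shuffled index range at uniformly sampled slice points."""
    rng = random.Random(seed)
    idx = list(range(num_examples))
    rng.shuffle(idx)
    # Distinct cut points in 1..num_examples-1 guarantee every partition is non-empty.
    cuts = sorted(rng.sample(range(1, num_examples), num_clients - 1))
    bounds = [0] + cuts + [num_examples]
    return [idx[bounds[c]:bounds[c + 1]] for c in range(num_clients)]


parts = quantity_partition(1000, 50)
print(sorted(len(p) for p in parts)[:5], sorted(len(p) for p in parts)[-5:])
```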

5.1 Setup

We evaluate our approach in three settings with $K \in \{50, 100, 200\}$ clients. We follow the model architectures of prior work [26]: a simple model with two convolutional layers for CIFAR-10 and a deeper model with six convolutional layers for CIFAR-100. We assess both a personalized model (PM) at each client and a global model (GM) at the server. Each PM is evaluated on test data matching the client's specific classes (for class distribution skew) or subclasses (for class concept drift). The GM is evaluated on the entire test set. All experiments are repeated three times, using the same set of three random seeds for data generation, parameter initialization, and client sampling. The IVON hyperparameters are listed in Table 4, and the results are presented in Table 5.

params CIFAR-10 CIFAR-100
initial learning rate 0.1 0.1
final learning rate 0.001 0.001
weight decay 1e-3 1e-3
batch size 32 32
ESS (λ) 10000 10000
initial Hessian (h₀) 1.0 1.0
MC samples (training) 1 1
MC samples (test) 64 64
Table 4: IVON hyperparameters for the personalized FL experiments
Dataset Method 50 Clients 100 Clients 200 Clients
PM GM PM GM PM GM
CIFAR-10 Local 56.9 ± 0.1 - 52.1 ± 0.1 - 46.6 ± 0.1 -
pFedME [24] 72.3 ± 0.1 56.6 ± 1.0 71.4 ± 0.2 60.1 ± 0.3 68.5 ± 0.2 58.7 ± 0.2
pFedBayes [25] 71.4 ± 0.3 52.0 ± 1.0 68.5 ± 0.3 53.2 ± 0.7 64.6 ± 0.2 51.4 ± 0.3
pFedVEM [26] 73.2 ± 0.2 56.0 ± 0.4 71.9 ± 0.1 60.1 ± 0.2 70.1 ± 0.3 59.4 ± 0.3
FedIvon@mean 74.4 ± 0.3 67.1 ± 1.0 71.7 ± 0.3 68.4 ± 0.2 69.7 ± 0.7 68.2 ± 0.3
FedIvon 75.5 ± 0.4 67.8 ± 1.6 72.6 ± 0.2 69.2 ± 0.2 70.8 ± 0.4 68.7 ± 0.3
CIFAR-100 Local 34.3 ± 0.2 - 27.6 ± 0.3 - 22.2 ± 0.2 -
pFedME [24] 52.5 ± 0.5 47.9 ± 0.5 47.6 ± 0.5 45.1 ± 0.3 41.6 ± 1.8 41.5 ± 1.6
pFedBayes [25] 49.6 ± 0.3 42.5 ± 0.5 46.5 ± 0.2 41.3 ± 0.3 40.1 ± 0.3 37.4 ± 0.3
pFedVEM [26] 61.0 ± 0.4 52.8 ± 0.4 56.2 ± 0.4 52.3 ± 0.4 51.1 ± 0.6 49.2 ± 0.5
FedIvon@mean 65.4 ± 0.7 63.2 ± 0.5 63.2 ± 0.5 62.1 ± 0.5 56.1 ± 0.6 55.5 ± 0.6
FedIvon 66.7 ± 0.8 63.8 ± 0.7 63.5 ± 0.6 62.4 ± 0.6 56.5 ± 0.5 55.7 ± 0.7
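Table 5: Comparison of personalized FL methods: average personalized-model accuracy (PM, %) and global-model accuracy (GM, %) over 3 runs. FedIvon achieves the best result in every column; on CIFAR-100, FedIvon@mean is second best.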

5.2 Results

Table 5 presents results on CIFAR-10 and CIFAR-100, which we use to simulate different types of data heterogeneity in federated learning. CIFAR-10 models class distribution skew, where each client has data from a limited set of classes. CIFAR-100 models class concept drift, where each client has data from distinct subclasses within the superclasses. For both datasets, we report the average client accuracy (personalized model) and the server accuracy (global model) for 50, 100, and 200 clients. FedIvon performs Monte Carlo averaging over 64 posterior samples. FedIvon@mean instead makes point-estimate predictions at the posterior mode, which for a Gaussian posterior coincides with its mean.
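The difference between FedIvon and FedIvon@mean can be illustrated on a toy model. FedIvon averages softmax outputs over weights sampled from the diagonal Gaussian posterior, while FedIvon@mean evaluates the model once at the posterior mean. The sketch below uses a hypothetical linear softmax classifier as a stand-in for the CNNs, and all function names are our own.

```python
import math
import random


def softmax(z):
    mx = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - mx) for v in z]
    s = sum(e)
    return [v / s for v in e]


def predict_linear(x, w):
    """Hypothetical linear classifier: w[c] is the weight vector of class c."""
    return softmax([sum(wi * xi for wi, xi in zip(wc, x)) for wc in w])


def mc_predictive(x, mean, std, num_samples=64, seed=0):
    """Monte Carlo averaging of softmax outputs over weights ~ N(mean, diag(std^2))."""
    rng = random.Random(seed)
    avg = [0.0] * len(mean)
    for _ in range(num_samples):
        w = [[rng.gauss(m, s) for m, s in zip(mc, sc)] for mc, sc in zip(mean, std)]
        avg = [a + pc / num_samples for a, pc in zip(avg, predict_linear(x, w))]
    return avg


def mean_predictive(x, mean):
    """Point prediction at the posterior mean (FedIvon@mean-style)."""
    return predict_linear(x, mean)


x, mean = [1.0], [[2.0], [-2.0]]
p_mean = mean_predictive(x, mean)
p_mc = mc_predictive(x, mean, [[3.0], [3.0]])
print(p_mean, p_mc)
```

With a broad posterior, the averaged prediction is less confident than the point prediction at the mean. This is the mechanism behind the better calibration of FedIvon over FedIvon@mean.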

On CIFAR-10, FedIvon achieves client accuracy comparable to or slightly higher than pFedVEM (e.g., 75.5 vs. 73.2 with 50 clients). This indicates that both methods handle class distribution skew well for individual clients. In server accuracy, however, FedIvon improves substantially over pFedVEM and the other methods (e.g., 67.8 vs. 56.0 with 50 clients). This highlights FedIvon's strength in aggregating the posteriors of heterogeneous clients into an accurate global model.

On CIFAR-100, which represents class concept drift, FedIvon improves substantially over all other methods in both average client accuracy and server accuracy. This advantage in both personalized and global evaluations suggests that FedIvon is well suited to handling concept drift. Overall, FedIvon consistently outperforms the other methods, particularly in server accuracy on CIFAR-10 and in both accuracy metrics on CIFAR-100. This underscores its robustness across different data heterogeneity scenarios.

6 Conclusion

We presented a new Bayesian Federated Learning (FL) method that reduces the computational and communication overhead typically associated with Bayesian approaches. Our method uses an efficient second-order optimization technique for variational inference, achieving computational efficiency similar to first-order methods like Adam while still providing the benefits of Bayesian FL, such as uncertainty estimation and model personalization. We showed that our approach improves predictive accuracy and uncertainty estimates compared to both Bayesian and non-Bayesian FL methods. Additionally, our method naturally supports personalized FL by allowing clients to use the server’s posterior as a prior for learning their own models.

References

  • McMahan et al. [2017] Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics, pages 1273–1282. PMLR, 2017.
  • Tan et al. [2022] Alysa Ziying Tan, Han Yu, Lizhen Cui, and Qiang Yang. Towards personalized federated learning. IEEE Transactions on Neural Networks and Learning Systems, 34(12):9587–9603, 2022.
  • Al-Shedivat et al. [2020] Maruan Al-Shedivat, Jennifer Gillenwater, Eric Xing, and Afshin Rostamizadeh. Federated learning via posterior averaging: A new perspective and practical algorithms. In International Conference on Learning Representations, 2020.
  • Liu et al. [2024] Liangxi Liu, Xi Jiang, Feng Zheng, Hong Chen, Guo-Jun Qi, Heng Huang, and Ling Shao. A Bayesian federated learning framework with online Laplace approximation. IEEE Transactions on Pattern Analysis and Machine Intelligence, 46(1):1–16, 2024. doi: 10.1109/TPAMI.2023.3322743.
  • Bhatt et al. [2023] Shrey Bhatt, Aishwarya Gupta, and Piyush Rai. Federated learning with uncertainty via distilled predictive distributions, 2023. URL https://arxiv.org/abs/2206.07562.
  • Guo et al. [2023] Han Guo, Philip Greengard, Hongyi Wang, Andrew Gelman, Yoon Kim, and Eric Xing. Federated learning as variational inference: A scalable expectation propagation approach. In The Eleventh International Conference on Learning Representations, 2023.
  • Linsner et al. [2021] Florian Linsner, Linara Adilova, Sina Däubener, Michael Kamp, and Asja Fischer. Approaches to uncertainty quantification in federated deep learning. In ECML PKDD Workshop on Parallel, Distributed, and Federated Learning, pages 128–145. Springer, 2021.
  • Kassab and Simeone [2022] Rahif Kassab and Osvaldo Simeone. Federated generalized Bayesian learning via distributed Stein variational gradient descent. IEEE Transactions on Signal Processing, 70:2180–2192, 2022.
  • Shen et al. [2024] Yuesong Shen, Nico Daheim, Bai Cong, Peter Nickl, Gian Maria Marconi, Clement Bazan, Rio Yokota, Iryna Gurevych, Daniel Cremers, Mohammad Emtiyaz Khan, and Thomas Möllenhoff. Variational learning is effective for large deep networks, 2024. URL https://arxiv.org/abs/2402.17641.
  • Lin et al. [2020] Tao Lin, Lingjing Kong, Sebastian U Stich, and Martin Jaggi. Ensemble distillation for robust model fusion in federated learning. Advances in Neural Information Processing Systems, 33:2351–2363, 2020.
  • Pfeiffer et al. [2023] Kilian Pfeiffer, Martin Rapp, Ramin Khalili, and Jörg Henkel. Federated learning for computationally constrained heterogeneous devices: A survey. ACM Computing Surveys, 55:1 – 27, 2023. URL https://api.semanticscholar.org/CorpusID:258590978.
  • Zhang et al. [2023] Yifei Zhang, Dun Zeng, Jinglong Luo, Zenglin Xu, and Irwin King. A survey of trustworthy federated learning with perspectives on security, robustness and privacy. Companion Proceedings of the ACM Web Conference 2023, 2023. URL https://api.semanticscholar.org/CorpusID:257050689.
  • Che et al. [2023] Liwei Che, Jiaqi Wang, Yao Zhou, and Fenglong Ma. Multimodal federated learning: A survey. Sensors (Basel, Switzerland), 23, 2023. URL https://api.semanticscholar.org/CorpusID:260693566.
  • Liu et al. [2023] Bingyan Liu, Nuoyan Lv, Yuanchun Guo, and Yawen Li. Recent advances on federated learning: A systematic survey. Neurocomputing, 597:128019, 2023. URL https://api.semanticscholar.org/CorpusID:255415857.
  • Chen and Chao [2020] Hong-You Chen and Wei-Lun Chao. FedDistill: Making Bayesian model ensemble applicable to federated learning. CoRR, abs/2009.01974, 2020. URL https://arxiv.org/abs/2009.01974.
  • Maddox et al. [2019] Wesley J Maddox, Pavel Izmailov, Timur Garipov, Dmitry P Vetrov, and Andrew Gordon Wilson. A simple baseline for Bayesian uncertainty in deep learning. Advances in Neural Information Processing Systems, 32, 2019.
  • Safaryan et al. [2021] M. H. Safaryan, Rustem Islamov, Xun Qian, and Peter Richtárik. FedNL: Making Newton-type methods applicable to federated learning. ArXiv, abs/2106.02969, 2021. URL https://api.semanticscholar.org/CorpusID:235358296.
  • Bischoff et al. [2021a] Sebastian Bischoff, Stephan Günnemann, Martin Jaggi, and Sebastian U. Stich. On second-order optimization methods for federated learning, 2021a. URL https://arxiv.org/abs/2109.02388.
  • Achituve et al. [2021] Idan Achituve, Aviv Shamsian, Aviv Navon, Gal Chechik, and Ethan Fetaya. Personalized federated learning with Gaussian processes. Advances in Neural Information Processing Systems, 34:8392–8406, 2021.
  • Collins et al. [2021] Liam Collins, Hamed Hassani, Aryan Mokhtari, and Sanjay Shakkottai. Exploiting shared representations for personalized federated learning. In International Conference on Machine Learning, pages 2089–2099. PMLR, 2021.
  • Liang et al. [2020] Paul Pu Liang, Terrance Liu, Liu Ziyin, Nicholas B Allen, Randy P Auerbach, David Brent, Ruslan Salakhutdinov, and Louis-Philippe Morency. Think locally, act globally: Federated learning with local and global representations. arXiv e-prints, pages arXiv–2001, 2020.
  • Fallah et al. [2020] Alireza Fallah, Aryan Mokhtari, and Asuman Ozdaglar. Personalized federated learning with theoretical guarantees: A model-agnostic meta-learning approach. Advances in Neural Information Processing Systems, 33:3557–3568, 2020.
  • Finn et al. [2017] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In International Conference on Machine Learning, pages 1126–1135. PMLR, 2017.
  • T Dinh et al. [2020] Canh T Dinh, Nguyen Tran, and Josh Nguyen. Personalized federated learning with Moreau envelopes. Advances in Neural Information Processing Systems, 33:21394–21405, 2020.
  • Zhang et al. [2022] Xu Zhang, Yinchuan Li, Wenpeng Li, Kaiyang Guo, and Yunfeng Shao. Personalized federated learning via variational Bayesian inference. In International Conference on Machine Learning, pages 26293–26310. PMLR, 2022.
  • Zhu et al. [2023] Junyi Zhu, Xingchen Ma, and Matthew B Blaschko. Confidence-aware personalized federated learning via variational expectation maximization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 24542–24551, 2023.
  • Angelino et al. [2016] Elaine Angelino, Matthew James Johnson, Ryan P Adams, et al. Patterns of scalable Bayesian inference. Foundations and Trends® in Machine Learning, 9(2-3):119–247, 2016.
  • Fischer et al. [2024] John Fischer, Marko Orescanin, Justin Loomis, and Patrick McClure. Federated Bayesian deep learning: The application of statistical aggregation methods to Bayesian models. arXiv preprint arXiv:2403.15263, 2024.
  • Daheim et al. [2023] Nico Daheim, Thomas Möllenhoff, Edoardo Maria Ponti, Iryna Gurevych, and Mohammad Emtiyaz Khan. Model merging by uncertainty-based gradient matching. arXiv preprint arXiv:2310.12808, 2023.
  • Bischoff et al. [2021b] Sebastian Bischoff, Stephan Günnemann, Martin Jaggi, and Sebastian U Stich. On second-order optimization methods for federated learning. arXiv preprint arXiv:2109.02388, 2021b.
  • Cohen et al. [2017] Gregory Cohen, Saeed Afshar, Jonathan C. Tapson, and André van Schaik. EMNIST: an extension of MNIST to handwritten letters. ArXiv, abs/1702.05373, 2017. URL https://api.semanticscholar.org/CorpusID:12507257.
  • Netzer et al. [2011] Yuval Netzer, Tao Wang, Adam Coates, A. Bissacco, Bo Wu, and A. Ng. Reading digits in natural images with unsupervised feature learning. 2011. URL https://api.semanticscholar.org/CorpusID:16852518.
  • Krizhevsky [2009] Alex Krizhevsky. Learning multiple layers of features from tiny images. 2009. URL https://api.semanticscholar.org/CorpusID:18268744.
  • Khan and Lin [2017] Mohammad Emtiyaz Khan and Wu Lin. Conjugate-computation variational inference: Converting variational inference in non-conjugate models to inferences in conjugate models, 2017. URL https://arxiv.org/abs/1703.04265.
  • Khan et al. [2018] Mohammad Khan, Didrik Nielsen, Voot Tangkaratt, Wu Lin, Yarin Gal, and Akash Srivastava. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. In Jennifer Dy and Andreas Krause, editors, Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pages 2611–2620. PMLR, 10–15 Jul 2018. URL https://proceedings.mlr.press/v80/khan18a.html.

Appendix A More details on IVON

Computing exact gradients in equations 4 and 5 is difficult because of the expectation term in $\mathcal{L}_k(q)$. A naïve approach is to optimize with stochastic gradient estimators, but such approaches do not scale well because of the high variance of the gradient estimates. Using natural gradients, Khan and Lin [34] derived improved gradient-based update equations for the variational parameters, an approach they call Natural Gradient VI (NGVI). The main difference between the NGVI updates and standard gradient updates is that the learning rate is adapted by the variance ${\bm{\sigma}_k^2}^{(t+1)}$. This makes the NGVI updates similar to Adam.

$$
\begin{aligned}
\textbf{NGVI: }\quad \bm{m}^{t+1}_{k} &= \bm{m}^{t}_{k} + \beta^{t}\,(\bm{\sigma}_{k}^{2})^{t+1} \odot \big[\hat{\nabla}_{\bm{m}_{k}}\mathcal{L}_{k}(q)\big] \\
(\bm{\sigma}_{k}^{-2})^{t+1} &= (\bm{\sigma}_{k}^{-2})^{t} - 2\beta^{t}\,\big[\hat{\nabla}_{\bm{\sigma}^{2}_{k}}\mathcal{L}_{k}(q)\big]
\end{aligned}
$$
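The NGVI update above can be sketched for a diagonal Gaussian in plain Python. This is a minimal illustration, not the authors' code. `ngvi_step` assumes the caller already has estimates of $\hat{\nabla}_{\bm{m}_k}\mathcal{L}_k(q)$ and $\hat{\nabla}_{\bm{\sigma}_k^2}\mathcal{L}_k(q)$; because the mean update adds the gradient, $\mathcal{L}_k(q)$ is treated as an objective to maximize. The hypothetical helper `gaussian_target_grads` supplies exact ELBO gradients for a toy Gaussian target $\log p(\theta) = -\tfrac{a}{2}(\theta-\mu)^2 + \text{const}$ so the step can be checked. The precision is updated first, since the mean step uses the new variance $(\bm{\sigma}_k^{2})^{t+1}$.

```python
def ngvi_step(m, s2, grad_m, grad_s2, beta):
    """One NGVI step for a diagonal Gaussian q = N(m, diag(s2)).

    grad_m, grad_s2: estimates of the gradients of the (maximized) objective
    L_k(q) with respect to the mean m and the variance s2 = sigma^2.
    """
    # Precision first: sigma^{-2,(t+1)} = sigma^{-2,(t)} - 2 * beta * grad_s2
    new_prec = [1.0 / v - 2.0 * beta * g for v, g in zip(s2, grad_s2)]
    if any(p <= 0.0 for p in new_prec):
        raise ValueError("precision became non-positive; try a smaller beta")
    new_s2 = [1.0 / p for p in new_prec]
    # Mean step: the *new* variance acts as a per-coordinate learning rate
    new_m = [mi + beta * v * g for mi, v, g in zip(m, new_s2, grad_m)]
    return new_m, new_s2


def gaussian_target_grads(m, s2, a, mu):
    """Exact ELBO gradients when log p(theta) = -a/2 (theta - mu)^2 + const.

    ELBO = E_q[log p] + entropy(q), so d/dm = -a (m - mu) and
    d/ds2 = -a/2 + 1 / (2 s2).
    """
    grad_m = [-ai * (mi - mui) for mi, ai, mui in zip(m, a, mu)]
    grad_s2 = [-0.5 * ai + 0.5 / v for v, ai in zip(s2, a)]
    return grad_m, grad_s2


m, s2 = [0.0], [1.0]
gm, gs = gaussian_target_grads(m, s2, a=[4.0], mu=[3.0])
m, s2 = ngvi_step(m, s2, gm, gs, beta=1.0)
print(m, s2)  # [3.0] [0.25]
```

With $\beta^t = 1$ and a Gaussian target, a single step lands on the exact posterior $\mathcal{N}(\mu, 1/a)$, which makes a convenient sanity check; smaller step sizes approach the same fixed point more gradually.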

Further, Khan et al. [35] showed that the NGVI updates can be written in terms of the stochastic gradient and Hessian of the loss with respect to $\bm{\theta}$, where $(\bm{\sigma}_{k}^{2})^{t} = [N(\bm{h}^{t}_{k} + \lambda)]^{-1}$ and the vector $\bm{h}^{t}_{k}$ contains an online estimate of the diagonal Hessian. This approach, called Variational Online Newton (VON), is similar to NGVI except that it does not require the gradients of the variational objective; it only needs the gradient and Hessian of the loss evaluated at a sample $\bm{\theta}^{t}$ drawn from $q$.

$$
\begin{aligned}
\textbf{VON: }\quad \bm{m}^{t+1}_{k} &= \bm{m}^{t}_{k} - \beta^{t}\,\frac{\hat{\bm{g}}(\bm{\theta}^{t}) + \lambda\bm{m}^{t}_{k}}{\bm{h}^{t+1}_{k} + \lambda} \\
\bm{h}_{k}^{t+1} &= (1-\beta^{t})\,\bm{h}_{k}^{t} + \beta^{t}\,\mathrm{diag}\big[\hat{\nabla}^{2}_{\bm{\theta}\bm{\theta}}\bar{\ell}_{k}(\bm{\theta}^{t})\big]
\end{aligned}
$$
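A corresponding sketch of one VON step follows, again for a diagonal Gaussian in plain Python and under the same caveats (illustrative, not the authors' code). `grad_fn` and `hess_diag_fn` are hypothetical callables returning $\hat{\bm{g}}(\bm{\theta}^t)$ and the Hessian diagonal of the average loss at the sampled $\bm{\theta}^t$; the sample is drawn using the variance $[N(\bm{h}_k^t+\lambda)]^{-1}$ given in the text. The explicit check surfaces the failure mode for non-convex losses: once $\bm{h}_k^t + \lambda$ is not positive, the variance is undefined.

```python
import math
import random


def von_step(m, h, grad_fn, hess_diag_fn, beta, lam, n_data, rng):
    """One VON step for q = N(m, diag(s2)) with s2 = 1 / (N * (h + lam)).

    grad_fn(theta) / hess_diag_fn(theta) return stochastic estimates of the
    gradient and of the Hessian diagonal of the average loss at theta.
    """
    prec = [n_data * (hi + lam) for hi in h]  # 1 / s2, elementwise
    if any(p <= 0.0 for p in prec):
        # The non-convex failure mode: the variance is no longer defined
        raise ValueError("h + lam must be positive for the variance to exist")
    # Reparameterized sample theta^t = m + sigma * eps, eps ~ N(0, 1)
    theta = [mi + rng.gauss(0.0, 1.0) / math.sqrt(p) for mi, p in zip(m, prec)]
    g, d = grad_fn(theta), hess_diag_fn(theta)
    # Update h first: the mean step divides by h^{t+1} + lam
    new_h = [(1.0 - beta) * hi + beta * di for hi, di in zip(h, d)]
    new_m = [mi - beta * (gi + lam * mi) / (hi + lam)
             for mi, gi, hi in zip(m, g, new_h)]
    return new_m, new_h
```

The ordering mirrors the equations: $\bm{h}_k^{t+1}$ is a moving average of Hessian diagonals, and the mean step is a Newton-like step preconditioned by it.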

For non-convex objectives, the Hessian in the VON update can be negative, which can make $\bm{h}_{k}^{t} + \lambda$, and hence the variance $(\bm{\sigma}_{k}^{2})^{t}$, negative and break VON. To mitigate this issue, Khan et al. [35] used a Generalized Gauss-Newton (GGN) approximation of the Hessian, which is always non-negative. This method is called VOGN.

$$
\nabla^{2}_{\theta_{j}\theta_{j}}\bar{\ell}_{k}(\bm{\theta}^{t}) \approx \frac{1}{M}\sum_{i\in\mathcal{M}}\big[\nabla_{\theta_{j}}\ell^{i}_{k}(\bm{\theta}^{t})\big]^{2} := \hat{h}_{j}(\bm{\theta}^{t})
$$

$$
\begin{aligned}
\textbf{VOGN: }\quad \bm{m}^{t+1}_{k} &= \bm{m}^{t}_{k} - \beta^{t}\,\frac{\hat{\bm{g}}(\bm{\theta}^{t}) + \lambda\bm{m}^{t}_{k}}{\bm{h}^{t+1}_{k} + \lambda} \\
\bm{h}_{k}^{t+1} &= (1-\beta^{t})\,\bm{h}_{k}^{t} + \beta^{t}\,\hat{h}_{j}(\bm{\theta}^{t})
\end{aligned}
$$

VOGN [35] thus improves on VON by using this Gauss-Newton estimate instead of the Hessian, which gives update equations similar to those of the Adam optimizer. However, it still requires squaring per-sample gradients, which is costly compared to Adam, which squares only the minibatch-averaged gradient.
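The cost difference can be seen in a few lines of plain Python. In this sketch, `per_sample_grads` is a hypothetical list of $M$ per-example gradient vectors $\nabla_{\bm{\theta}}\ell_k^i(\bm{\theta}^t)$ evaluated at one sample $\bm{\theta}^t$. `ggn_diag` averages their squares (the GGN approximation above), whereas `squared_mean_grad` squares their average, which is the Adam-style quantity and needs only one backward pass. `vogn_step` then applies the VOGN update with $\hat{\bm{g}}$ taken as the minibatch mean. All names are illustrative, not from the authors' code.

```python
def ggn_diag(per_sample_grads):
    """GGN estimate h_hat_j = (1/M) * sum_i (d l_i / d theta_j)^2, always >= 0."""
    M = len(per_sample_grads)
    return [sum(g[j] ** 2 for g in per_sample_grads) / M
            for j in range(len(per_sample_grads[0]))]


def squared_mean_grad(per_sample_grads):
    """What Adam squares instead: the minibatch-averaged gradient, squared."""
    M = len(per_sample_grads)
    return [(sum(g[j] for g in per_sample_grads) / M) ** 2
            for j in range(len(per_sample_grads[0]))]


def vogn_step(m, h, per_sample_grads, beta, lam):
    """One VOGN step; per_sample_grads are evaluated at a sample theta^t ~ q."""
    M = len(per_sample_grads)
    g_hat = [sum(g[j] for g in per_sample_grads) / M for j in range(len(m))]
    h_hat = ggn_diag(per_sample_grads)  # needs all M per-example gradients
    new_h = [(1.0 - beta) * hi + beta * hh for hi, hh in zip(h, h_hat)]
    new_m = [mi - beta * (gi + lam * mi) / (hi + lam)
             for mi, gi, hi in zip(m, g_hat, new_h)]
    return new_m, new_h


grads = [[1.0, 2.0], [-1.0, 2.0]]  # two examples, two parameters
print(ggn_diag(grads))           # [1.0, 4.0]
print(squared_mean_grad(grads))  # [0.0, 4.0]
```

When per-example gradients disagree (first coordinate), they cancel in the averaged gradient but not in the mean of squares; keeping that information is what VOGN pays for with per-example gradient computations.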

$$
\begin{aligned}
\textbf{IVON: }\quad \widehat{\bm{h}}_{k}^{t} &= \widehat{\nabla}\bar{\ell}_{k}(\bm{\theta}) \cdot \frac{\bm{\theta} - \bm{m}_{k}^{t}}{(\bm{\sigma}_{k}^{2})^{t}} \\
\bm{h}^{t+1}_{k} &= (1-\rho)\,\bm{h}_{k}^{t} + \rho\,\widehat{\bm{h}}_{k}^{t} + \frac{1}{2}\rho^{2}\,\frac{(\bm{h}_{k}^{t} - \widehat{\bm{h}}_{k}^{t})^{2}}{\bm{h}_{k}^{t} + s_{0}/\lambda}
\end{aligned}
$$

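These are the updates of the improved variational online Newton (IVON) algorithm of Shen et al. [9], who refined the preceding update equations into a much more efficient form whose cost is similar to that of the Adam optimizer. The Hessian estimate $\widehat{\bm{h}}_{k}^{t}$ needs only the minibatch gradient at a sample $\bm{\theta} \sim q$, which avoids the per-sample gradient squaring of VOGN, and the extra $\rho^{2}$ term in the update of $\bm{h}_{k}^{t+1}$ keeps $\bm{h}_{k}^{t+1} + s_{0}/\lambda$ positive whenever $\bm{h}_{k}^{t} + s_{0}/\lambda$ is positive, even when $\widehat{\bm{h}}_{k}^{t}$ is negative.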
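The two IVON quantities can likewise be sketched elementwise in plain Python. `ivon_hessian_estimate` implements $\widehat{\bm{h}}_k^t$ and needs only one gradient at the sampled $\bm{\theta}$, with no per-example gradients; `ivon_h_update` implements the $\bm{h}_k^{t+1}$ recursion. The variance $(\bm{\sigma}_k^2)^t$ is passed in rather than derived, since its exact parameterization in terms of $\bm{h}_k^t$, $s_0$ and $\lambda$ is not restated in this appendix. Function names are illustrative, not the IVON reference implementation.

```python
def ivon_hessian_estimate(grad, theta, m, s2):
    """h_hat = grad * (theta - m) / sigma^2, elementwise.

    grad is the minibatch gradient at a single sample theta ~ N(m, diag(s2)).
    """
    return [g * (t - mi) / v for g, t, mi, v in zip(grad, theta, m, s2)]


def ivon_h_update(h, h_hat, rho, s0, lam):
    """h^{t+1} = (1-rho) h + rho h_hat + 0.5 rho^2 (h - h_hat)^2 / (h + s0/lam)."""
    delta = s0 / lam
    return [(1 - rho) * hi + rho * hh + 0.5 * rho**2 * (hi - hh)**2 / (hi + delta)
            for hi, hh in zip(h, h_hat)]


print(ivon_hessian_estimate([2.0], [1.5], [1.0], [0.25]))  # [4.0]
# A negative estimate: the plain moving average (1-rho)*h + rho*h_hat = -1.0
# would make h + s0/lam = 0, while the rho^2 term keeps it at 1.0 > 0.
print(ivon_h_update([1.0], [-3.0], rho=0.5, s0=1.0, lam=1.0))  # [0.0]
```

A short calculation shows why the correction works: writing $\delta = s_0/\lambda$, the right-hand side is minimized over $\widehat{\bm{h}}_k^t$ at $(\bm{h}_k^t - \delta)/2$, so $\bm{h}_k^{t+1} + \delta \ge \tfrac{1}{2}(\bm{h}_k^t + \delta)$ for any estimate.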

Appendix B Reliability diagrams for FL experiments

Figures 5 and 6 show the reliability diagrams for the CIFAR-10 and EMNIST experiments, respectively. FedIvon yields better-calibrated predictions than FedAvg and FedLaplace, as reflected in its lower Expected Calibration Error (ECE).
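ECE groups test predictions into confidence bins and takes the size-weighted average gap between accuracy and mean confidence in each bin, $\mathrm{ECE} = \sum_{b} \frac{|B_b|}{n}\,\lvert \mathrm{acc}(B_b) - \mathrm{conf}(B_b) \rvert$; a reliability diagram plots the per-bin accuracies against confidence. The plain-Python sketch below assumes 10 equal-width bins and uses the maximum predicted probability as the confidence; the bin count used for Figures 5 and 6 is not stated in this appendix, so treat it as a parameter.

```python
def expected_calibration_error(confidences, predictions, labels, n_bins=10):
    """ECE with equal-width confidence bins.

    confidences: max predicted probability per example, in [0, 1].
    """
    n = len(confidences)
    # Each bin accumulates [count, number correct, sum of confidences]
    bins = [[0, 0, 0.0] for _ in range(n_bins)]
    for conf, pred, y in zip(confidences, predictions, labels):
        # Bin b covers [b/n_bins, (b+1)/n_bins); confidence 1.0 goes to the last bin
        b = min(int(conf * n_bins), n_bins - 1)
        bins[b][0] += 1
        bins[b][1] += int(pred == y)
        bins[b][2] += conf
    ece = 0.0
    for count, correct, conf_sum in bins:
        if count:
            ece += (count / n) * abs(correct / count - conf_sum / count)
    return ece


# Two over-confident predictions (0.9, half correct) and two under-confident
# ones (0.6, all correct): each bin contributes 0.5 * 0.4
print(expected_calibration_error([0.9, 0.9, 0.6, 0.6], [1, 1, 0, 0], [1, 0, 0, 0]))  # 0.4
```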

Figure 5: Reliability diagrams for CIFAR-10 experiments (left: FedAvg, center: FedLaplace, right: FedIvon).
Figure 6: Reliability diagrams for EMNIST experiments (left: FedAvg, center: FedLaplace, right: FedIvon).

Appendix C Client data distribution in FL experiments

Figure 7 illustrates the data distribution among clients used in the FL experiments. Each client has a highly imbalanced dataset, with the number of samples per client ranging from 5 to 32. Additionally, each client’s dataset is limited to only a subset of classes, further emphasizing the non-IID nature of the data. This experimental setup poses significant challenges for training a robust global server model, as the limited and biased data from individual clients must be aggregated effectively to learn a model capable of generalizing across all classes. This scenario highlights the complexities and practical relevance of federated learning in real-world applications.
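The per-client statistics visualized in Figure 7 (how many samples each client holds and which classes appear) can be tabulated directly from the client label lists. The helpers below are a hypothetical sketch, not part of the authors' pipeline; they assume integer class labels in `range(num_classes)` and one label list per client.

```python
from collections import Counter


def client_class_counts(client_labels, num_classes):
    """Return a clients x classes matrix of sample counts."""
    counts = []
    for labels in client_labels:
        tally = Counter(labels)  # missing classes count as 0
        counts.append([tally[c] for c in range(num_classes)])
    return counts


def imbalance_summary(counts):
    """Smallest and largest client size, and number of classes per client."""
    sizes = [sum(row) for row in counts]
    classes_present = [sum(1 for n in row if n > 0) for row in counts]
    return min(sizes), max(sizes), classes_present


counts = client_class_counts([[0, 0, 1], [2, 2, 2, 2, 2], [1]], num_classes=3)
print(counts)                     # [[2, 1, 0], [0, 0, 5], [0, 1, 0]]
print(imbalance_summary(counts))  # (1, 5, [2, 1, 1])
```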

Figure 7: Client data distribution for CIFAR-10, EMNIST, and SVHN dataset used in FL experiments.

Appendix D Client data distribution in pFL

Figure 8 illustrates the distribution of data points across classes and clients in three pFL experimental setups with 50, 100, and 200 clients. The number of data points per client varies significantly, with some clients having over 1,000 data points and others fewer than 5, indicating a high degree of imbalance. Despite this, every client retains examples from most classes, which is crucial for training personalized models that adapt to the unique data distribution of each client. This setup highlights the challenge of learning effective personalized models in pFL. Similarly, Figure 9 shows the data distribution for the CIFAR-100 dataset.
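The figures do not by themselves state how the pFL splits were generated. For readers who want to produce splits with a similar character (strongly varying client sizes and class proportions), one widely used option is Dirichlet label partitioning, sketched below in plain Python. This is an assumption for illustration, not necessarily the procedure used for these experiments; `dirichlet_partition` and its `alpha` parameter are hypothetical names.

```python
import random


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Assign each sample index to one client; per-class shares ~ Dirichlet(alpha).

    Smaller alpha (> 0) gives more skewed, non-IID clients.
    """
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    clients = [[] for _ in range(num_clients)]
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)
        # A Dirichlet draw is a vector of Gamma(alpha, 1) variates, normalized
        w = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(w) or 1.0  # guard against an all-zero draw
        # Cut this class's shuffled indices at the cumulative shares
        bounds, acc = [0], 0.0
        for wk in w[:-1]:
            acc += wk / total
            bounds.append(min(round(acc * len(idxs)), len(idxs)))
        bounds.append(len(idxs))
        for k in range(num_clients):
            clients[k].extend(idxs[bounds[k]:bounds[k + 1]])
    return clients


labels = [i % 10 for i in range(1000)]  # toy labels: 10 classes, 100 each
clients = dirichlet_partition(labels, num_clients=50, alpha=0.5)
sizes = sorted(len(c) for c in clients)
print(sizes[0], sizes[-1])  # smallest and largest client
```

Larger `alpha` spreads each class more evenly across clients; smaller `alpha` concentrates each class on a few clients, which increases both size imbalance and label skew.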

Figure 8: Client data distribution for CIFAR-10 dataset used in pFL experiments (left: 50 clients, right: 100 clients, bottom: 200 clients).
Figure 9: Client data distribution for CIFAR-100 dataset used in pFL experiments (left: 50 clients, right: 100 clients, bottom: 200 clients).