Preservation of the Global Knowledge by Not-True Self Knowledge Distillation in Federated Learning

Gihun Lee, Yongjin Shin, Minchan Jeong, and Se-Young Yun
KAIST
{opcrisis, yj.shin, mcjeong, yunseyoung}@kaist.ac.kr

arXiv:2106.03097v1 [cs.LG] 6 Jun 2021

Abstract

In Federated Learning (FL), a strong global model is collaboratively learned by aggregating the clients' locally trained models. Although this removes the need to access the clients' data directly, the global model's convergence often suffers from data heterogeneity. This paper suggests that forgetting could be the bottleneck of global convergence. We observe that fitting on a biased local distribution shifts the features on the global distribution and results in forgetting of global knowledge. We view this phenomenon as an analogy to Continual Learning, which also faces catastrophic forgetting when fitted on a new task distribution. Based on our findings, we hypothesize that tackling forgetting in local training relieves the data heterogeneity problem. To this end, we propose a simple yet effective framework, Federated Local Self-Distillation (FedLSD), which utilizes the global knowledge on locally available data. By following the global perspective on local data, FedLSD encourages the learned features to preserve global knowledge and to have consistent views across local models, thus improving convergence without compromising data privacy. Under our framework, we further extend FedLSD to FedLS-NTD, which only considers the not-true class signals to compensate for the noisy prediction of the global model. We validate that both FedLSD and FedLS-NTD significantly improve the performance on standard FL benchmarks in various setups, especially in the extreme data heterogeneity cases.

1 Introduction

Nowadays, massive data is collected from edge devices such as phones, vehicles, and facilities. By decoupling the ability to learn from the need to access private client data directly, Federated Learning (FL) proposes a distributed learning paradigm that enables learning a global model while preserving clients' data privacy [24, 25, 36, 37]. In FL, clients independently train local models using their private local data, and the global server aggregates them into a single global model. In this process, most of the computation is performed by the client devices, and the global server only updates the aggregated model parameters and distributes them to the clients [1, 56].

Most FL algorithms are based on FedAvg [36], which aggregates the locally trained model parameters by weighted averaging proportional to the amount of local data each client holds. While various FL algorithms have been proposed recently, they all conduct parameter averaging in certain manners [29, 34, 43, 51, 52, 58].
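For concreteness, the weighted parameter averaging used by FedAvg-style aggregation can be sketched as follows. This is a minimal illustration in PyTorch; the helper name and the use of state_dicts are our assumptions, not the paper's implementation.

```python
import torch
from typing import Dict, List

def fedavg_aggregate(local_states: List[Dict[str, torch.Tensor]],
                     num_samples: List[int]) -> Dict[str, torch.Tensor]:
    """Weighted-average local model parameters, proportional to local data size."""
    total = float(sum(num_samples))
    weights = [n / total for n in num_samples]
    global_state = {}
    for key in local_states[0]:
        global_state[key] = sum(w * state[key].float()
                                for w, state in zip(weights, local_states))
    return global_state
```

The aggregated state_dict would then be loaded back into the global model and broadcast to the clients sampled in the next round.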
Although this aggregation scheme empirically works well and provides a conceptually ideal framework when all client devices are active and i.i.d. distributed (a.k.a. LocalSGD), the data heterogeneity problem [30, 63] is one of the major challenges for FL applications to become widespread. Since each client generates its own data, the data is not identically distributed. More precisely, the local data across clients are drawn from heterogeneous underlying distributions; thereby, locally available data fail to represent the overall global distribution. Despite its inevitable occurrence in many real-world scenarios, the heterogeneity of local data not only makes theoretical analysis difficult [19, 21, 30, 63], but also degrades many FL algorithms' performance [16, 28, 29, 33]. By resolving the data heterogeneity problem, learning becomes more robust against partial participation of clients [30, 36], and the communication cost is also reduced by faster convergence [51, 58].

Figure 1: An overview of forgetting in learning scenarios. (a) Continual Learning, (b) Federated Learning (FedAvg), (c) Federated Learning (FedLSD).

Interestingly, Continual Learning (CL) [44, 49] faces a similar challenge. In CL, a learner model is continuously updated on a sequence of tasks, with the objective of remembering all tasks. Unfortunately, since the data distribution of each task is heterogeneous, learning on different tasks often results in catastrophic forgetting [23, 32, 35, 40], whereby fitting on a new task interferes with the representations learned on the previous task. As a consequence, the model parameters drift away from the area where the past knowledge is desirably preserved (Figure 1a).

Our first conjecture is that such forgetting also exists in FL scenarios, as illustrated in Figure 1b. The global model is updated on randomly sampled local clients at each round, but the local datasets have different distributions from the previous round. We focus on this analogy of the FL scenario in the sense of catastrophic forgetting. To empirically verify this analogy, we examine the prediction of the global model before and after the communication round. We find that the prediction consistency between the two global models degrades as the data heterogeneity increases. We then analyze the representation change after local training and find that fitting on the biased local distribution causes the trained model's features to be shifted towards the locally available region, resulting in forgetting of global knowledge.

Based on our findings, we hypothesize that tackling forgetting of global knowledge can relieve the data heterogeneity problem in FL. Inspired by prior works in CL [15, 32], which distill predictions on old tasks while learning new tasks, we propose a simple yet effective FL framework, Federated Local Self-Distillation (FedLSD). FedLSD utilizes the global model's prediction on local data to preserve global knowledge, which the local distribution cannot represent (Figure 1c). More specifically, the local model self-distills the distributed global model's prediction on locally available data. This approach resembles distillation-based knowledge preservation in Continual Learning [15, 32], where new task data is evaluated on the old task and its prediction is maintained. However, merely following the global prediction may yield sub-optimal performance in FL. Unlike CL, where each task is sufficiently learned before observing a new task, FL updates with only a few local steps due to its distributed setup. This induces the global model's prediction to be noisy (the predicted label differs from the true label), especially in the early communication rounds. To compensate for this, we further extend FedLSD to Federated Local Self Not-True Distillation (FedLS-NTD), which follows not-true class signals only. We demonstrate the preservation of global knowledge by both algorithms and their efficacy on standard FL benchmarks with various setups.
Our contributions can be summarized as follows:

• Based on the analogy between FL and CL, we emphasize the forgetting of global knowledge in FL and suggest that feature shifting, induced by fitting on the biased local distribution, causes the forgetting (Section 2).
• We propose FedLSD to preserve global knowledge. To compensate for the noisy prediction of the global model in the early phase, we further extend FedLSD to FedLS-NTD. Unlike previous distillation-based approaches in FL, our methods do not require any additional local information or auxiliary data (Section 3).
• We validate the efficacy of FedLSD and FedLS-NTD in various setups. Using Centered Kernel Alignment (CKA) and t-SNE visualization, we demonstrate that both algorithms enjoy remarkable feature consistency during local training (Section 4).
2 Forgetting global knowledge in FL

To confirm our conjecture on forgetting, we first consider how the prediction of the global model changes. If data heterogeneity induces forgetting, the global model's prediction after aggregation may be less consistent when compared to before it is distributed. To examine this, we measure the Intersection over Union (IoU) of mutually correct samples between the Distributed Global (DG) model and the Aggregated Global (AG) model as

$$\mathrm{IoU}_{\text{correct}} = \frac{|\mathrm{DG}_{\text{correct}} \cap \mathrm{AG}_{\text{correct}}|}{|\mathrm{DG}_{\text{correct}} \cup \mathrm{AG}_{\text{correct}}|}.$$

The result is summarized in Table 1. As expected, the prediction consistency degrades as data heterogeneity increases (fewer classes per client), along with growing variance. This implies that forgetting also occurs in FL, and it is closely related to data heterogeneity.¹

Table 1: Prediction consistency (IoU_correct, DG vs AG) of the CIFARCNN model under varying heterogeneity.

CIFAR10, max class / client | 2 | 3 | 5 | 8 | 10 (iid)
IoU_correct | 0.270±0.092 | 0.326±0.095 | 0.375±0.076 | 0.405±0.074 | 0.475±0.014
CIFAR100, max class / client | 5 | 10 | 20 | 40 | 100 (iid)
IoU_correct | 0.257±0.081 | 0.314±0.069 | 0.334±0.070 | 0.357±0.062 | 0.418±0.037

To examine how the knowledge in the global model is forgotten, we now analyze how the representation on the global distribution changes during local training. To this end, we design a straightforward experiment that shows the change of features on the unit hypersphere. More specifically, we modify the network architecture to map CIFAR-10 input data to 2-dimensional vectors and normalize them to lie on the unit hypersphere $S^1 = \{x \in \mathbb{R}^2 : \|x\|_2 = 1\}$. We then estimate their probability density function. The global model is learned for 100 rounds of communication, this time on homogeneous locals (i.i.d. distributed), and is then distributed to heterogeneous locals with different local distributions. The result is in Figure 2.

Figure 2: Features of CIFAR-10 test samples on $S^1$: (a) FedAvg, (b) FedLSD (ours). We plot the feature distribution with Gaussian kernel density estimation (KDE) in $\mathbb{R}^2$ over $\arctan(y, x)$ for each point $(x, y) \in S^1$. The distributed global model (first column) is trained on heterogeneous locals (middle 3 columns) and aggregated by parameter averaging (last column). The values in parentheses are test accuracy.

¹ We distribute the CIFAR-10 data to 100 clients and sample 10 clients per round, where each device has the same amount of samples but maximally two classes. The experiment is conducted for 300 communication rounds on a CIFARCNN (2-conv + 3-fc layers) model, the same architecture used in FedAvg [36]. More details are in the Appendix.
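The consistency metric above can be computed with a short evaluation pass over the test set, as sketched below. The function name, data-loader interface, and device handling are illustrative assumptions.

```python
import torch

@torch.no_grad()
def iou_correct(model_dg, model_ag, loader, device="cpu"):
    """IoU of the sets of test samples that the Distributed Global (DG)
    and Aggregated Global (AG) models classify correctly."""
    model_dg.eval()
    model_ag.eval()
    both, either = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        correct_dg = model_dg(x).argmax(dim=1).eq(y)
        correct_ag = model_ag(x).argmax(dim=1).eq(y)
        both += (correct_dg & correct_ag).sum().item()    # intersection of correct sets
        either += (correct_dg | correct_ag).sum().item()  # union of correct sets
    return both / max(either, 1)
```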
Obviously, the feature vectors from the global model learned with homogeneous locals are uniformly distributed over the hypersphere $S^1$. However, after local training, the entire feature region is shifted towards the locally available classes, as in Figure 2a. Thus, when the local models are aggregated, the global model only reflects the feature regions that were dominant in the local models, and the knowledge of unseen classes is forgotten, significantly degrading the global test accuracy. Our analysis shows that feature shifting on the global distribution, which the local set fails to represent, causes the forgetting of global knowledge. Based on our findings, we hypothesize that preserving the knowledge on the global distribution prevents forgetting and relieves the data heterogeneity problem in FL.

3 Federated Local Self Distillation

In this section, we first introduce the proposed Federated Local Self Distillation (FedLSD) and its key features. We then further extend FedLSD to Federated Local Self Not-True Distillation (FedLS-NTD), which shows further improved performance.

Algorithm 1 Federated Local Self-Distillation (FedLSD)
Input: weight $w^0$, total rounds $T$, local epochs $E$, dataset $\mathcal{X} \times \mathcal{Y}$, sampled client set $K^{(t)}$ in round $t$, learning rate $\alpha$
Initialize $w^0$ for the global server
for each communication round $t = 1, \cdots, T$ do
  Server samples clients $K^{(t)}$ and broadcasts $\tilde{w}^t \leftarrow w^t$ to all clients
  for each client $k \in K^{(t)}$ in parallel do
    Construct prediction pair $\mathcal{X}_k \times \mathcal{Y}_k$
    for local epochs $i = 1, \cdots, E$ do
      for batches $j = 1, \cdots, B$ do
        $\tilde{w}_k^t \leftarrow \tilde{w}_k^t - \alpha \nabla_w \mathcal{L}_{\text{FedLSD}}[\mathcal{X}_k \times \mathcal{Y}_k]_j(\tilde{w}_k^t)$   [Equation 1 or Equation 2]
      end for
    end for
    Upload $\tilde{w}_k^t$ to the server
  end for
  Server aggregation: $w^{t+1} \leftarrow \frac{1}{|K^{(t)}|} \sum_{k \in K^{(t)}} \tilde{w}_k^t$
end for
Server output: $w^T$

3.1 Local self distillation

The core idea of FedLSD is to preserve the global view on the local data during local training. The shifting of the learned features occurs by fitting on the biased local distribution, which cannot represent the global distribution. Therefore, preserving information about where local data should be located in the global distribution prevents the locally learned features from deviating from the feature space of the global model. To this end, FedLSD conducts local-side self-distillation: the prediction of the distributed global model on the local data is distilled during local training, using the following loss function $\mathcal{L}_{\text{FedLSD}}$, a linear combination of the cross-entropy $\mathcal{L}_{\text{CE}}$ and the distillation loss $\mathcal{L}_{\text{LSD}}$:

$$\mathcal{L}_{\text{FedLSD}} = (1 - \beta) \cdot \mathcal{L}_{\text{CE}}(q, p_y) + \beta \cdot \mathcal{L}_{\text{LSD}}(q_\tau, q_\tau^g) \quad (0 < \beta < 1), \qquad (1)$$

where $q$ and $q^g$ are the softmax probabilities of the local client model and of the global model at the last round, and $p_y$ is the one-hot label. The distillation loss $\mathcal{L}_{\text{LSD}}$ is defined as follows:

$$q_\tau(c) = \frac{\exp(z_c/\tau)}{\sum_{i=1}^{C} \exp(z_i/\tau)}, \qquad q_\tau^g(c) = \frac{\exp(z_c^g/\tau)}{\sum_{i=1}^{C} \exp(z_i^g/\tau)}, \qquad \mathcal{L}_{\text{LSD}}(q_\tau, q_\tau^g) = -\sum_{c=1}^{C} q_\tau^g(c) \log \frac{q_\tau(c)}{q_\tau^g(c)}.$$

Here, $q(c)$ stands for the prediction probability on class $c \in [C]$ as a softmax function of the logits $z$, $C$ is the number of classes, and $p_y(c)$ denotes the one-hot true label. We omit the data $x$ and model weight $w$ and write $q(x, w)(c)$ and $p_y(x)(c)$ as $q(c)$ and $p_y(c)$ for simplicity. For the distillation loss, the logits are divided by $\tau$ to soften the prediction probability. Note that $\mathcal{L}_{\text{LSD}}$ is the KL-divergence loss between the local prediction and the global prediction.
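A minimal PyTorch sketch of the local objective in Equation (1) is given below, using the β and τ values reported in Appendix A.4; the function name is our assumption, and the distributed global model's logits are treated as a frozen teacher. (Some KD implementations additionally scale the distillation term by τ²; the equation above does not, so neither does the sketch.)

```python
import torch.nn.functional as F

def fedlsd_loss(local_logits, global_logits, targets, beta=0.3, tau=3.0):
    """L_FedLSD = (1 - beta) * CE(q, p_y) + beta * KL(q_tau^g || q_tau)   (Eq. 1)."""
    ce = F.cross_entropy(local_logits, targets)
    log_q_tau = F.log_softmax(local_logits / tau, dim=1)       # local prediction (log)
    q_g_tau = F.softmax(global_logits.detach() / tau, dim=1)   # frozen global teacher
    lsd = F.kl_div(log_q_tau, q_g_tau, reduction="batchmean")  # KL(q_tau^g || q_tau)
    return (1 - beta) * ce + beta * lsd
```

During a local step, `global_logits` would come from the distributed global model kept frozen on the client, matching the inner update of Algorithm 1.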
FedLSD resembles memory-based CL algorithms [15, 32, 42], which exploit a distillation loss to avoid forgetting by preserving the knowledge on old tasks. Although CL scenarios generally assume a small amount of old task data to replay (episodic memory), such an assumption is not allowed in FL due to the data privacy issue. Instead, the global model's predictions can be utilized as a reference of the previous data distribution to induce an effect similar to that of using the episodic memory in CL.

We analyze the effect of $\mathcal{L}_{\text{LSD}}$ in terms of weight divergence. Proposition 1 implies that increasing $\beta$ in the local objective reduces the weight divergence $\|w^t - w^t_{(GD)}\|$. The corresponding experiment is in Figure 3, where the local models are trained using $\mathcal{L}_{\text{LSD}}$ ($\tau = 1$) with varying $\beta$.

Figure 3: Weight divergence of CIFARCNN on the CIFAR-10 dataset (max class/client = 2).

Proposition 1. Consider the FedAvg [36] setting with uniformly weighted $K = |K^{(t)}|$ local clients holding class distributions $P_k^{(t)}$. We train the global and local clients by $\mathcal{L}_{\text{LSD}}$ ($\tau = 1$) with convex class landscapes $w \mapsto \mathbb{E}_{D_c}[\mathcal{L}_{\text{LSD}}(\cdot, w)]$, where $D_c$ is the data distribution of the $c$-th class. Assume that $w \mapsto \frac{\partial z}{\partial w}(\cdot, w)$ is $\lambda_c^J$-Lipschitz and $B$-bounded, and $w \mapsto q(\cdot, w)$ is $\lambda_c^q$-Lipschitz with respect to $\|\cdot\|_{L^2(D_c)}$. Then, if $0 < \eta < 2 \min_c (\sqrt{2}\lambda_c^J + B\lambda_c^q)^{-1}$, the weight divergence between the averaged local $w^{t+1}$ and the global $w^{t+1}_{(GD)}$ after one round is bounded by the following inequality:

$$\|w^{t+1} - w^{t+1}_{(GD)}\| \le \frac{\eta B}{K} \sum_{k \in K^{(t)}} \chi(P \,\|\, P_k^{(t)}) \sum_{i=0}^{E-1} \Big\| q(\cdot, \tilde{w}_k^{t,i}) - \big(\beta\, q^g + (1-\beta)\, p_y\big) \Big\|_{L^2(D_k^{(t)})},$$

where $D_k^{(t)}$ is the data on the $k$-th local client and $q^g = q(x, w^t)$ is the global prediction at the last round.

3.2 Local self not-true distillation

Although FedLSD reduces forgetting and effectively stabilizes weight divergence, merely following the global prediction is sub-optimal for preventing forgetting. First, unlike CL, where each task is sufficiently learned before observing a new task, FL updates with only a few local steps due to its distributed setup. This induces the global model's prediction to be noisy (the predicted label differs from the true label), especially in the early communication rounds. Secondly, to prevent forgetting, information not included in the local distribution should have a higher priority. When the teacher prediction is noisy, however, the KD loss cannot fully enjoy the benefits of temperature softening. As discussed in [22], using a small temperature rather performs better when the teacher is noisy, as it neglects the negative noise effect.

Figure 4: Not-True Distillation.

To tackle both the noisy global prediction and the local information priority issues, we further extend FedLSD by distilling the global view on the relationship between not-true classes only, as illustrated in Figure 4. With the decoupling between signals for the true class and the not-true classes, the feature shifting problem can now be relieved without compromising learning on the local data. More specifically, the logit belonging to the true class is not used when obtaining the softmax prediction probability for both the local and global models, and the distillation loss closes the gap between them. The proposed Federated Local Self Not-True Distillation (FedLS-NTD) loss has a form similar to the FedLSD loss:

$$\mathcal{L}_{\text{FedLS-NTD}} = (1 - \beta) \cdot \mathcal{L}_{\text{CE}}(q, p_y) + \beta \cdot \mathcal{L}_{\text{NTD}}(\tilde{q}_\tau, \tilde{q}_\tau^g) \quad (0 < \beta < 1), \qquad (2)$$
where $\tilde{q}_\tau$ and $\tilde{q}_\tau^g$ are the softmax probabilities over the not-true class logits and $\mathcal{L}_{\text{NTD}}$ is the KL-divergence loss between them:

$$\tilde{q}_\tau(c) = \frac{\exp(z_c/\tau)}{\sum_{i=1, i \neq y}^{C} \exp(z_i/\tau)}, \qquad \tilde{q}_\tau^g(c) = \frac{\exp(z_c^g/\tau)}{\sum_{i=1, i \neq y}^{C} \exp(z_i^g/\tau)} \quad (\forall c \neq y), \qquad \mathcal{L}_{\text{NTD}}(\tilde{q}_\tau, \tilde{q}_\tau^g) = -\sum_{c=1, c \neq y}^{C} \tilde{q}_\tau^g(c) \log \frac{\tilde{q}_\tau(c)}{\tilde{q}_\tau^g(c)}.$$

We analyze the advantage of FedLS-NTD over FedLSD using the following loss function:

$$\mathcal{L}_{\text{LSD} \to \text{NTD}} = (1 - \beta) \cdot \mathcal{L}_{\text{CE}}(q, p_y) + \beta \cdot \mathcal{L}(\lambda), \quad \text{where } \mathcal{L}(\lambda) = (1 - \lambda) \cdot \mathcal{L}_{\text{LSD}}(q_\tau, q_\tau^g) + \lambda \cdot \mathcal{L}_{\text{NTD}}(\tilde{q}_\tau, \tilde{q}_\tau^g).$$

Here, $\lambda$ moves the loss function between FedLSD (at $\lambda = 0$) and FedLS-NTD (at $\lambda = 1$), where the other hyperparameters are fixed ($\tau = 1$, $\beta = 0.3$). To improve the performance of the global model, an aggregated global model should 1) preserve the correct predictions and 2) discard the incorrect predictions of the previous global model. We examine the prediction consistencies IoU_correct and IoU_incorrect between the Distributed Global (DG) and Aggregated Global (AG) models (Figure 5). Since FedLS-NTD does not follow the true-class signal, the ability to discard incorrect predictions is enhanced while correct predictions are preserved on par with FedLSD.

Figure 5: Prediction consistency of CIFARCNN on the CIFAR-10 dataset (max class/client = 2).
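A corresponding sketch of the not-true distillation term in Equation (2) is shown below; the true-class logit is masked out for both the local and global logits before the softmax, and the function names and defaults (β = 0.3, τ = 1, as in Appendix A.4) are illustrative.

```python
import torch.nn.functional as F

def ntd_loss(local_logits, global_logits, targets, tau=1.0):
    """KL(q~_tau^g || q~_tau) over the not-true classes only (the L_NTD term of Eq. 2)."""
    num_classes = local_logits.size(1)
    mask = F.one_hot(targets, num_classes).bool()            # true-class positions
    local_nt = local_logits[~mask].view(-1, num_classes - 1)
    global_nt = global_logits.detach()[~mask].view(-1, num_classes - 1)
    log_q_nt = F.log_softmax(local_nt / tau, dim=1)
    q_g_nt = F.softmax(global_nt / tau, dim=1)
    return F.kl_div(log_q_nt, q_g_nt, reduction="batchmean")

def fedls_ntd_loss(local_logits, global_logits, targets, beta=0.3, tau=1.0):
    ce = F.cross_entropy(local_logits, targets)
    return (1 - beta) * ce + beta * ntd_loss(local_logits, global_logits, targets, tau)
```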
4 Experiment

We evaluate FedLSD and FedLS-NTD on various FL scenarios with data heterogeneity. We compare our results with FedAvg [36] and FedProx [29] as baselines and validate how the proposed methods improve FL performance. VggNet-11 [47] and ResNet-8 [13] models are used to test the algorithms. We use three different datasets: FEMNIST [4], CIFAR-10, and CIFAR-100 [27]. For the CIFAR datasets, we introduce label distribution heterogeneity, where the clients have the same number of samples but a limited number of classes (2 for CIFAR-10 and 20 for CIFAR-100). The experimental details are in Appendix A.

4.1 Performance on data heterogeneous FL scenarios

We compare the algorithms on different datasets to validate the efficacy of the proposed algorithms. The results are in Table 2. In the experiment, FedLSD and FedLS-NTD consistently outperform the baseline algorithms. We also investigate how they perform when combined with FedProx. Although FedProx does not always improve over the FedAvg baseline in our experiment, the combination with FedLSD or FedLS-NTD brings additional performance improvement.

Figure 6: Learning curves for the Res8 model corresponding to the results in Table 2.

Table 2: Test accuracy after the target rounds and the number of rounds to reach the target accuracy (δ indicates the max class / client). Entries are "test acc / rounds to target", with targets of 80% (FEMNIST), 60% (CIFAR10), and 40% (CIFAR100).

Model | Algorithm | FEMNIST | CIFAR10 (δ = 2) | CIFAR100 (δ = 20)
Res8 | FedAvg [36] | 86.0±0.69 / 63 | 54.2±5.62 / 466 | 50.4±0.90 / 154
Res8 | FedProx [29] | 85.9±0.30 / 68 | 55.8±3.23 / 406 | 51.5±0.52 / 139
Res8 | FedLSD (ours) | 86.9±0.23 / 58 | 58.0±1.91 / 425 | 52.1±0.24 / 138
Res8 | FedLS-NTD (ours) | 86.2±0.53 / 71 | 63.5±0.95 / 299 | 52.9±0.14 / 127
Res8 | FedProx [29] + FedLSD | 85.4±0.06 / 94 | 60.5±2.52 / 397 | 52.1±0.11 / 134
Res8 | FedProx [29] + FedLS-NTD | 85.5±0.09 / 89 | 64.8±1.15 / 265 | 53.3±0.07 / 128
VGG11 | FedAvg [36] | 84.5±0.37 / 118 | 67.3±1.54 / 302 | 42.5±0.52 / 240
VGG11 | FedProx [29] | 85.1±0.16 / 132 | 61.1±0.42 / 254 | 42.3±0.04 / 220
VGG11 | FedLSD (ours) | 85.6±0.16 / 129 | 68.6±1.15 / 246 | 46.3±0.16 / 193
VGG11 | FedLS-NTD (ours) | 84.8±0.70 / 129 | 69.6±2.01 / 190 | 45.4±0.23 / 193
VGG11 | FedProx [29] + FedLSD | 84.6±0.18 / 129 | 67.8±1.29 / 227 | 46.4±0.03 / 198
VGG11 | FedProx [29] + FedLS-NTD | 84.8±0.21 / 137 | 68.1±1.17 / 180 | 45.1±0.08 / 198

While FedLS-NTD generally performs better than FedLSD, the performance gap varies depending on the datasets and model architectures. We analyze that since FedLS-NTD learns not-true class relationships only, FedLSD may perform better when the prediction noise of the global model is not significant. We suggest that FedLS-NTD is more effective when the intensity of data heterogeneity is high, so that following the global model carries a higher risk of inheriting a biased distribution. For example, in the FEMNIST dataset, almost all classes are included in the local data, but the hand-written samples that belong to the same class have different forms across the clients. In this case, FedLSD and FedLS-NTD show slow convergence in the early phase while improving the final test accuracy.

We further investigate the performance by varying the data heterogeneity (Table 3) and the sampling ratio (Table 4). The results show that the proposed algorithms significantly outperform the baselines in the data heterogeneity cases. Interestingly, while the baseline algorithms fail to learn and diverge when only one class is available locally, FedLSD and FedLS-NTD are able to learn even under such extreme heterogeneity. Although FL scenarios where only positive labels are accessible locally are not our main scope, we argue that this result verifies that our proposed algorithms effectively encourage the preservation of global knowledge during local training.

Table 3: Res8 on CIFAR10: data heterogeneity (max class / client).

Algorithm | 1 | 2 | 3 | 5 | 10 (iid)
FedAvg | failure | 54.2 | 64.6 | 76.9 | 86.8
FedProx | failure | 55.8 | 66.6 | 77.8 | 86.5
FedLSD | 22.9 | 58.0 | 70.6 | 80.3 | 86.6
FedLS-NTD | 21.3 | 63.5 | 74.9 | 81.7 | 86.0

Table 4: Res8 on CIFAR10: sampling ratio (sampled clients / N).

Algorithm | 0.05 | 0.1 | 0.3 | 0.5 | 1.0
FedAvg | 44.0 | 54.2 | 65.5 | 67.8 | 68.7
FedProx | 49.7 | 55.8 | 65.0 | 67.7 | 67.5
FedLSD | 52.9 | 58.0 | 63.3 | 65.1 | 67.0
FedLS-NTD | 58.1 | 63.5 | 69.4 | 71.0 | 72.7

4.2 Preservation of global knowledge

To empirically measure how well the global knowledge is preserved after local training, we examine the trained local models' test accuracy on the global distribution (Figure 7). If fitting on the local distribution preserved global knowledge, the trained local models would generalize well on the global distribution. However, since forgetting occurs during local training, the baseline algorithms show only around 20% test accuracy on the global distribution.
Since at most two classes are available locally for each client, this indicates that the trained model forgets the knowledge on the other classes, which have not been observed during local training. On the other hand, FedLSD and FedLS-NTD show significantly higher local test accuracy while improving the final global accuracy. This implies that FedLSD and FedLS-NTD preserve global distribution information during local training, even though the local distribution cannot explicitly represent it.
Figure 7: A comparison of the final global test accuracy and the local test accuracy corresponding to Table 2. The local test accuracy is an average over the sampled clients for the last five rounds.

4.3 Analysis on feature similarity

CKA analysis To demonstrate that FedLSD and FedLS-NTD enjoy feature consistency, we compare the CKA [26] between the global model and the locally trained models. More specifically, we measure the CKA of the Distributed Global (DG), Local (L), and Aggregated Global (AG) models in four types: (a) DG vs L, (b) L vs L, (c) L vs AG, and (d) DG vs AG, as plotted in Figure 8. As illustrated, the feature region is dramatically shifted during local training of the FedAvg algorithm. In FedProx, the drift of the local model is smaller than in FedAvg, but it still suffers from the inconsistency between the DG and AG features. On the other hand, FedLSD and FedLS-NTD keep the feature space from being shifted during local training; hence the DG and AG models share the feature region. Notably, although the feature similarity of FedLS-NTD is lower in the early phase of learning, it reaches and even exceeds the feature similarity of FedLSD at the end of learning.

Figure 8: CKA [26] of the Res8 model on the CIFAR10 dataset (max class/client = 2). The values in parentheses indicate the test accuracy of the final global model.

t-SNE visualization Contrary to FedAvg, our FedLSD effectively binds the features even when the local distribution is heavily biased. In Figure 9a, the test set features from local models are shifted when FedAvg is used (blue). However, the local features learned by FedLSD (red) share the same space with the global model's features (green). When viewed by classes in Figure 9b, these features are well aligned by their semantics. Note that the selected locals are identical for both algorithms.

Figure 9: t-SNE visualization of CIFAR-10 test set features on the ResNet-8 model: (a) colored by different algorithms, (b) colored by classes. The global model is learned for 100 communication rounds using FedAvg and distributed to 10 locals to conduct local training. The t-SNE is performed jointly on the test sample features of the global and the 10 local models.
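For reference, the linear variant of CKA [26] used in the analysis above can be computed from two feature matrices whose rows are the same test examples passed through two models; extracting the (e.g., penultimate-layer) features is assumed to be done separately.

```python
import torch

def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> float:
    """Linear CKA between feature matrices X (n x p1) and Y (n x p2)."""
    X = X - X.mean(dim=0, keepdim=True)   # center each feature dimension
    Y = Y - Y.mean(dim=0, keepdim=True)
    hsic = (Y.t() @ X).norm(p="fro") ** 2
    return (hsic / ((X.t() @ X).norm(p="fro") * (Y.t() @ Y).norm(p="fro"))).item()
```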
5 Related work

Federated Learning (FL) Federated Learning was proposed to update a global model by asynchronous SGD [9], while the local data is kept on users' devices [24, 25]. A standard algorithm in FL is FedAvg [36], which aggregates trained local models by averaging their parameters. Although its effectiveness has been largely discussed in i.i.d. settings [18, 48, 54, 55, 59], many algorithms become sub-optimal when the distributed data is heterogeneous [30, 63]. A wide range of variants of FedAvg has been proposed to overcome this problem. One line of work is local-side modification. For example, FedProx [29] adds a proximal term to the local objective, and SCAFFOLD [19] uses control variates to correct the client drift. In FedNova [52], the locally trained models are normalized to tackle local objective inconsistency. Likewise, server-side modification is also considered in various works. For instance, AFL [38] optimizes a mixture of local distributions in a centralized manner. In PFNM [60] and FedMA [51], optimal transport is used to match the layer-wise order of individual neurons, conducting post-training after aggregation. Our work can be considered a local-side modification, which focuses on the local objective.

Continual Learning (CL) Continual Learning [44, 49] is a learning paradigm that updates a model on a sequence of tasks instead of training on the whole dataset at once. In CL, the main challenge is to avoid catastrophic forgetting [11], whereby training on a new task interferes with the learned representation of the previous tasks. The prior works in CL largely focus on overcoming this catastrophic forgetting [5, 6, 15]. Existing methods can be mainly classified into two categories. First, parameter-based approaches [2, 6, 23, 39, 62] regularize the change of parameters so that task parameters do not interfere with each other. Secondly, memory-based approaches [3, 5, 15, 32, 42] maintain the knowledge of previous tasks by replaying a small episodic memory. In our work, we consider the latter to relieve forgetting in FL scenarios. Although the episodic data of previous tasks (local data from previous rounds) cannot be stored due to the data privacy issue, we regard that following the distributed global model's prediction has a similar effect to replaying memory.

Knowledge Distillation in FL Knowledge Distillation (KD) was proposed to transfer knowledge from a trained model to another [14]. KD shows impressive success through its numerous variants [10, 50, 57, 64, 66]. In FL, most approaches based on KD aim to make the global model learn the ensemble of locally trained models [7, 12, 17, 34, 65]. For instance, FD [17] collects averaged local logits per label and distributes them to regularize local training. In FedDF [34], the averaged logits of the local models are used for distillation at the server. FedAUX [45] expands the idea of FD by conducting self-supervised pre-training. FedBE [7] introduces a Bayesian view on sampling global models for aggregation. Another line of work uses KD to reduce the communication cost [12, 46, 53, 65]. In [12], one-round FL using the ensemble of local models is proposed. Instead of distilling from direct ensembles of clients, DOSFL [65] applies dataset distillation [53]. The previous works require either additional local information [17, 45, 65] or auxiliary data [7, 12, 34, 46].
However, as discussed in [29], even partially sharing local information globally violates privacy, and globally shared auxiliary data must be cautiously generated or collected. On the other hand, our approach does not require any local information or auxiliary data.

6 Conclusion

In this paper, we suggest that a forgetting phenomenon occurs in Federated Learning scenarios, as in Continual Learning. By analyzing the features after local training, we find that fitting on heterogeneous local distributions shifts the features on the global distribution, thereby causing forgetting of global knowledge. To relieve this problem, we propose FedLSD, which maintains the global view on local data. We further extend it to FedLS-NTD, which focuses on the not-true class relationships only. In the experiments, we validate the efficacy of the proposed algorithms and demonstrate their effect on global knowledge preservation and feature consistency.
7 Broader Impact

We believe that Federated Learning is an important learning paradigm that enables privacy-preserving ML. This paper highlights the forgetting issue and introduces methods to relieve it without compromising data privacy. The insight behind this work may inspire new research. However, the proposed method maintains the knowledge outside of the local distribution through the global model. This implies that if the global model is biased, the trained local model is more prone to have a similar tendency. This should be considered by ML practitioners.

References

[1] Mohammed Aledhari, Rehma Razzak, Reza M Parizi, and Fahad Saeed. Federated learning: A survey on enabling technologies, protocols, and applications. IEEE Access, 8:140699–140725, 2020. [2] Rahaf Aljundi, Francesca Babiloni, Mohamed Elhoseiny, Marcus Rohrbach, and Tinne Tuytelaars. Memory aware synapses: Learning what (not) to forget. In Proceedings of the European Conference on Computer Vision (ECCV), pages 139–154, 2018. [3] Rahaf Aljundi, Punarjay Chakravarty, and Tinne Tuytelaars. Expert gate: Lifelong learning with a network of experts. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 3366–3375, 2017. [4] Sebastian Caldas, Sai Meher Karthik Duddu, Peter Wu, Tian Li, Jakub Konečnỳ, H Brendan McMahan, Virginia Smith, and Ameet Talwalkar. Leaf: A benchmark for federated settings. arXiv preprint arXiv:1812.01097, 2018. [5] Arslan Chaudhry, Albert Gordo, Puneet K Dokania, Philip Torr, and David Lopez-Paz. Using hindsight to anchor past knowledge in continual learning. arXiv preprint arXiv:2002.08165, 2020. [6] Arslan Chaudhry, Naeemullah Khan, Puneet K Dokania, and Philip HS Torr. Continual learning in low-rank orthogonal subspaces. arXiv preprint arXiv:2010.11635, 2020. [7] Hong-You Chen and Wei-Lun Chao. Fedbe: Making bayesian model ensemble applicable to federated learning, 2021. [8] G. Cohen, S. Afshar, J. Tapson, and A. van Schaik. Emnist: an extension of mnist to handwritten letters. arXiv preprint arXiv:1702.05373, 2017. [9] Jeffrey Dean, Greg Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Mark Mao, Marc'Aurelio Ranzato, Andrew Senior, Paul Tucker, Ke Yang, Quoc Le, and Andrew Ng. Large scale distributed deep networks. In F. Pereira, C. J. C. Burges, L. Bottou, and K. Q. Weinberger, editors, Advances in Neural Information Processing Systems, volume 25. Curran Associates, Inc., 2012. [10] Tommaso Furlanello, Zachary Lipton, Michael Tschannen, Laurent Itti, and Anima Anandkumar. Born again neural networks. In International Conference on Machine Learning, pages 1607–1616. PMLR, 2018. [11] Ian J Goodfellow, Mehdi Mirza, Da Xiao, Aaron Courville, and Yoshua Bengio. An empirical investigation of catastrophic forgetting in gradient-based neural networks. arXiv preprint arXiv:1312.6211, 2013. [12] Neel Guha, Ameet Talwalkar, and Virginia Smith. One-shot federated learning. arXiv preprint arXiv:1902.11175, 2019. [13] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016. [14] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015. [15] Saihui Hou, Xinyu Pan, Chen Change Loy, Zilei Wang, and Dahua Lin. Learning a unified classifier incrementally via rebalancing.
In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 831–839, 2019. [16] Tzu-Ming Harry Hsu, Hang Qi, and Matthew Brown. Measuring the effects of non-identical data distribution for federated visual classification. arXiv preprint arXiv:1909.06335, 2019.
[17] Eunjeong Jeong, Seungeun Oh, Hyesung Kim, Jihong Park, Mehdi Bennis, and Seong-Lyun Kim. Communication-efficient on-device machine learning: Federated distillation and augmen- tation under non-iid private data. arXiv preprint arXiv:1811.11479, 2018. [18] Peng Jiang and Gagan Agrawal. A linear speedup analysis of distributed deep learning with sparse and quantized communication. In Proceedings of the 32nd International Conference on Neural Information Processing Systems, pages 2530–2541, 2018. [19] Sai Praneeth Karimireddy, Satyen Kale, Mehryar Mohri, Sashank Reddi, Sebastian Stich, and Ananda Theertha Suresh. Scaffold: Stochastic controlled averaging for federated learning. In International Conference on Machine Learning, pages 5132–5143. PMLR, 2020. [20] Hsieh Kevin, Phanishayee Amar, Mutlu Onur, and Gibbons Phillip B. The non-iid data quagmire of decentralized machine learning. In Proceedings of the 37th International Conference on Machine Learning, Online, PMLR 119, 2020. [21] Ahmed Khaled, Konstantin Mishchenko, and Peter Richtárik. Tighter theory for local sgd on identical and heterogeneous data. In International Conference on Artificial Intelligence and Statistics, pages 4519–4529. PMLR, 2020. [22] Taehyeon Kim, Jaehoon Oh, NakYil Kim, Sangwook Cho, and Se-Young Yun. Comparing kullback-leibler divergence and mean squared error loss in knowledge distillation. arXiv preprint arXiv:2105.08919, 2021. [23] James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, et al. Overcoming catastrophic forgetting in neural networks. Proceedings of the national academy of sciences, 114(13):3521–3526, 2017. [24] Jakub Konečnỳ, H Brendan McMahan, Daniel Ramage, and Peter Richtárik. Federated optimiza- tion: Distributed machine learning for on-device intelligence. arXiv preprint arXiv:1610.02527, 2016. [25] Jakub Konečnỳ, H Brendan McMahan, Felix X Yu, Peter Richtárik, Ananda Theertha Suresh, and Dave Bacon. Federated learning: Strategies for improving communication efficiency. arXiv preprint arXiv:1610.05492, 2016. [26] Simon Kornblith, Mohammad Norouzi, Honglak Lee, and Geoffrey Hinton. Similarity of neural network representations revisited. In International Conference on Machine Learning, pages 3519–3529. PMLR, 2019. [27] Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. Cifar-10 and cifar-100 datasets. URl: https://www. cs. toronto. edu/kriz/cifar. html, 6, 2009. [28] Tian Li, Anit Kumar Sahu, Ameet Talwalkar, and Virginia Smith. Federated learning: Chal- lenges, methods, and future directions. IEEE Signal Processing Magazine, 37(3):50–60, 2020. [29] Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia Smith. Federated optimization in heterogeneous networks. arXiv preprint arXiv:1812.06127, 2018. [30] Xiang Li, Kaixuan Huang, Wenhao Yang, Shusen Wang, and Zhihua Zhang. On the convergence of fedavg on non-iid data. arXiv preprint arXiv:1907.02189, 2019. [31] Yiying Li, Wei Zhou, Huaimin Wang, Haibo Mi, and Timothy M Hospedales. Fedh2l: Federated learning with model and statistical heterogeneity. arXiv preprint arXiv:2101.11296, 2021. [32] Zhizhong Li and Derek Hoiem. Learning without forgetting. IEEE transactions on pattern analysis and machine intelligence, 40(12):2935–2947, 2017. [33] Wei Yang Bryan Lim, Nguyen Cong Luong, Dinh Thai Hoang, Yutao Jiao, Ying-Chang Liang, Qiang Yang, Dusit Niyato, and Chunyan Miao. 
Federated learning in mobile edge networks: A comprehensive survey. IEEE Communications Surveys & Tutorials, 22(3):2031–2063, 2020. [34] Tao Lin, Lingjing Kong, Sebastian U Stich, and Martin Jaggi. Ensemble distillation for robust model fusion in federated learning. arXiv preprint arXiv:2006.07242, 2020. [35] Michael McCloskey and Neal J Cohen. Catastrophic interference in connectionist networks: The sequential learning problem. In Psychology of learning and motivation, volume 24, pages 109–165. Elsevier, 1989.
[36] Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics, pages 1273–1282. PMLR, 2017. [37] Brendan McMahan and Daniel Ramage. Collaborative machine learning without centralized training data. https://research.googleblog.com/2017/04/federated-learning-collaborative.html, 2017. [38] Mehryar Mohri, Gary Sivek, and Ananda Theertha Suresh. Agnostic federated learning. In International Conference on Machine Learning, pages 4615–4625. PMLR, 2019. [39] Cuong V Nguyen, Yingzhen Li, Thang D Bui, and Richard E Turner. Variational continual learning. arXiv preprint arXiv:1710.10628, 2017. [40] German I Parisi, Ronald Kemker, Jose L Part, Christopher Kanan, and Stefan Wermter. Continual lifelong learning with neural networks: A review. Neural Networks, 113:54–71, 2019. [41] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Kopf, Edward Yang, Zachary DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. Pytorch: An imperative style, high- performance deep learning library. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché- Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems 32, pages 8024–8035. Curran Associates, Inc., 2019. [42] Sylvestre-Alvise Rebuffi, Alexander Kolesnikov, Georg Sperl, and Christoph H Lampert. icarl: Incremental classifier and representation learning. In Proceedings of the IEEE conference on Computer Vision and Pattern Recognition, pages 2001–2010, 2017. [43] Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečnỳ, Sanjiv Kumar, and H Brendan McMahan. Adaptive federated optimization. arXiv preprint arXiv:2003.00295, 2020. [44] Mark B Ring. Child: A first step towards continual learning. In Learning to learn, pages 261–292. Springer, 1998. [45] Felix Sattler, Tim Korjakow, Roman Rischke, and Wojciech Samek. Fedaux: Leveraging unlabeled auxiliary data in federated learning. arXiv preprint arXiv:2102.02514, 2021. [46] Felix Sattler, Arturo Marban, Roman Rischke, and Wojciech Samek. Communication-efficient federated distillation. arXiv preprint arXiv:2012.00632, 2020. [47] Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014. [48] Sebastian U Stich. Local sgd converges fast and communicates little. arXiv preprint arXiv:1805.09767, 2018. [49] Sebastian Thrun. Lifelong learning algorithms. In Learning to learn, pages 181–209. Springer, 1998. [50] Yonglong Tian, Dilip Krishnan, and Phillip Isola. Contrastive representation distillation. arXiv preprint arXiv:1910.10699, 2019. [51] Hongyi Wang, Mikhail Yurochkin, Yuekai Sun, Dimitris Papailiopoulos, and Yasaman Khazaeni. Federated learning with matched averaging. arXiv preprint arXiv:2002.06440, 2020. [52] Jianyu Wang, Qinghua Liu, Hao Liang, Gauri Joshi, and H Vincent Poor. Tackling the objective inconsistency problem in heterogeneous federated optimization. arXiv preprint arXiv:2007.07481, 2020. [53] Tongzhou Wang, Jun-Yan Zhu, Antonio Torralba, and Alexei A Efros. Dataset distillation. arXiv preprint arXiv:1811.10959, 2018. 
[54] Blake Woodworth, Kumar Kshitij Patel, Sebastian Stich, Zhen Dai, Brian Bullins, Brendan Mcmahan, Ohad Shamir, and Nathan Srebro. Is local sgd better than minibatch sgd? In International Conference on Machine Learning, pages 10334–10343. PMLR, 2020. [55] Blake Woodworth, Jialei Wang, Adam Smith, Brendan McMahan, and Nathan Srebro. Graph oracle models, lower bounds, and gaps for parallel stochastic optimization. arXiv preprint arXiv:1805.10222, 2018.
[56] Qiang Yang, Yang Liu, Tianjian Chen, and Yongxin Tong. Federated machine learning: Concept and applications. ACM Transactions on Intelligent Systems and Technology (TIST), 10(2):1–19, 2019. [57] Junho Yim, Donggyu Joo, Jihoon Bae, and Junmo Kim. A gift from knowledge distillation: Fast optimization, network minimization and transfer learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 4133–4141, 2017. [58] Tehrim Yoon, Sumin Shin, Sung Ju Hwang, and Eunho Yang. Fedmix: Approximation of mixup under mean augmented federated learning. In International Conference on Learning Representations, 2021. [59] Hao Yu, Rong Jin, and Sen Yang. On the linear speedup analysis of communication efficient momentum sgd for distributed non-convex optimization. In International Conference on Machine Learning, pages 7184–7193. PMLR, 2019. [60] Mikhail Yurochkin, Mayank Agarwal, Soumya Ghosh, Kristjan Greenewald, Nghia Hoang, and Yasaman Khazaeni. Bayesian nonparametric federated learning of neural networks. In International Conference on Machine Learning, pages 7252–7261. PMLR, 2019. [62] Friedemann Zenke, Ben Poole, and Surya Ganguli. Continual learning through synaptic intelligence. In International Conference on Machine Learning, pages 3987–3995. PMLR, 2017. [63] Yue Zhao, Meng Li, Liangzhen Lai, Naveen Suda, Damon Civin, and Vikas Chandra. Federated learning with non-iid data. arXiv preprint arXiv:1806.00582, 2018. [64] Helong Zhou, Liangchen Song, Jiajie Chen, Ye Zhou, Guoli Wang, Junsong Yuan, and Qian Zhang. Rethinking soft labels for knowledge distillation: A bias-variance tradeoff perspective. arXiv preprint arXiv:2102.00650, 2021. [65] Yanlin Zhou, George Pu, Xiyao Ma, Xiaolin Li, and Dapeng Wu. Distilled one-shot federated learning. arXiv preprint arXiv:2009.07999, 2020. [66] Xiatian Zhu, Shaogang Gong, et al. Knowledge distillation by on-the-fly native ensemble. In Advances in neural information processing systems, pages 7517–7527, 2018.
A Experimental setup details

A.1 Models

Here we provide details on the models used in our experiments. All of the following models are implemented in PyTorch [41] with slight modifications.

CifarCNN This model architecture is from the CIFAR experiments of [36]; it is composed of two convolutional layers followed by two fully connected layers, and a linear transformation layer for the final logit outputs.

VggNet-11 VggNet was proposed by [47]. The architecture consists of stacks of convolutional layers and max-pooling, followed by three fully connected layers, and again a last linear transformation layer to produce logits. In our experiments, we use VggNet-11, which is composed of 11 cascaded convolutional layers with max-pooling layers. The original architecture implemented in PyTorch [41] has batch normalization; however, we skip batch normalization, following [20].

ResNet-8 ResNet-8 also uses a sequence of convolutional layers like VggNet. However, ResNet makes use of shortcut connections to resolve the vanishing gradient problem and to improve its performance [13]. We implement ours by removing the batch normalization layers.

A.2 Datasets

In the pre-processing of all the image datasets, we follow the standard normalization process and data augmentation, such as random cropping and horizontal random flipping.

FEMNIST Federated Extended MNIST was built by partitioning the EMNIST handwritten image dataset [8] based on the writer of the digits/characters. 3500 writers generated 62 different classes, including 10 digits, 26 lowercase letters, and 26 uppercase letters. To generate heterogeneous data partitions, we regard each writer as a local client and select 100 clients with more than 250 data samples. 10% of the data for each client is used for the test set.

CIFAR10 & CIFAR100 The CIFAR10 and CIFAR100 datasets [27] are among the most widely used image datasets and consist of 10 and 100 classes, respectively. The datasets contain 5000 and 500 training samples per class, respectively. For the label-heterogeneous partitions (max class / client), we follow the method of [36] (see the sketch at the end of this appendix). Since the number of local clients is fixed to 100, each local client has 500 data samples.

A.3 Training setups

We use an initial learning rate of 0.1 (for CIFARCNN & VGG-11) and 0.3 (for ResNet-8). The learning rate is decayed by a factor of 0.99 every 5 communication rounds. The momentum is set to 0.9 in the experiments. The momentum is only used in local training, which implies that the momentum information is not communicated between the server and the local clients. The weight decay is set to 0. The hyperparameters of the learning setups for the different datasets are in Table 5.

Table 5: Training setup details

Datasets | FEMNIST | CIFAR-10 | CIFAR-100
local epochs | 5 | 5 | 5
local batch size | 10 | 50 | 50
data samples / client | 300 (average) | 500 | 500
max class / client (δ) | 50 (average) | 2 | 20
total dataset classes | 62 | 10 | 100
µ (for FedProx) | 0.5 | 0.1 | 0.05

A.4 Hyperparameters

We use the same hyperparameters for FedLSD and FedLS-NTD across the different datasets: FedLSD (β = 0.3, τ = 3) and FedLS-NTD (β = 0.3, τ = 1). Since the difference between logit scales is much smaller when the true-class logits are excluded, FedLS-NTD requires a smaller temperature for softening the global model's prediction.

A.5 GPU resources

We use 1 Titan-RTX and 1 RTX 2080Ti GPU card. Multi-GPU training is not conducted in the paper experiments.
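As a reference for the label-heterogeneous split described in A.2, a shard-style partition in the spirit of [36] can be sketched as follows; the function name and the exact shard bookkeeping are our assumptions, and shard boundaries only approximately respect class boundaries.

```python
import numpy as np

def partition_by_class(labels, num_clients=100, shards_per_client=2, seed=0):
    """Assign sample indices so each client holds `shards_per_client` contiguous
    label-sorted shards, i.e., samples from roughly that many classes."""
    rng = np.random.default_rng(seed)
    idx_sorted = np.argsort(labels)  # group sample indices by label
    shards = np.array_split(idx_sorted, num_clients * shards_per_client)
    order = rng.permutation(len(shards))
    return [np.concatenate([shards[s] for s in
            order[c * shards_per_client:(c + 1) * shards_per_client]])
            for c in range(num_clients)]
```

With 100 clients and 2 shards each on CIFAR-10, every client receives 500 samples drawn from (at most) roughly two classes, matching the δ = 2 setting in Table 5.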
B Proof of Proposition 1

In this section, we provide the proof of the following Proposition 1.

Proposition 1. Consider the FedAvg [36] setting with uniformly weighted $K = |K^{(t)}|$ local clients holding class distributions $P_k^{(t)}$. We train the global and local clients by $\mathcal{L}_{\text{FedLSD}}$ ($\tau = 1$) with convex class landscapes $w \mapsto \mathbb{E}_{D_c}[\mathcal{L}_{\text{FedLSD}}(\cdot, w)]$, where $D_c$ is the data distribution of the $c$-th class. Assume that $w \mapsto \frac{\partial z}{\partial w}(\cdot, w)$ is $\lambda_c^J$-Lipschitz and $B$-bounded, and $w \mapsto q(\cdot, w)$ is $\lambda_c^q$-Lipschitz with respect to $\|\cdot\|_{L^2(D_c)}$. Then, if $0 < \eta < 2 \min_c (\sqrt{2}\lambda_c^J + B\lambda_c^q)^{-1}$, the weight divergence between the averaged local $w^{t+1}$ and the global $w^{t+1}_{(GD)}$ after one round is bounded by the following inequality:

$$\|w^{t+1} - w^{t+1}_{(GD)}\| \le \frac{\eta B}{K} \sum_{k \in K^{(t)}} \chi(P \,\|\, P_k^{(t)}) \sum_{i=0}^{E-1} \Big\| q(\cdot, w_k^{t,i}) - \big(\beta\, q^g + (1-\beta)\, p_y\big) \Big\|_{L^2(D_k^{(t)})},$$

where $D_k^{(t)}$ is the data on the $k$-th local client and $q^g = q(x, w^t)$ is the global prediction at the last round.

Proof. First, we arrange our notations and assumptions:

• $w \in \mathcal{W}$ (bounded) $\subset \mathbb{R}^m$, $D_k^{(t)} = \sum_{c=1}^{C} P_k^{(t)}(c)\, D_c$, and $D = \sum_{c=1}^{C} P(c)\, D_c$, where $D$ is the global data distribution and $P(c)$ is the probability of the $c$-th class on $D$.

• $w \mapsto \mathbb{E}_{D_c}[\mathcal{L}_{\text{FedLSD}}(\cdot, w)] = \int_{D_c} \mathcal{L}_{\text{FedLSD}}(x, w)\, dP(x)$ is convex for each $c$. Here, we note that convexity on $D$ (which is weaker) is in fact enough and more reasonable.

• For each class $c$, the mapping from the weight $w$ to the Jacobian $\frac{\partial z}{\partial w}(\cdot, w)$ satisfies
$$\sqrt{\int_{D_c} \Big\| \frac{\partial z}{\partial w}(x, w_2) - \frac{\partial z}{\partial w}(x, w_1) \Big\|_F^2 dP(x)} \le \lambda_c^J \|w_2 - w_1\|_2 \quad \forall w_1, w_2 \in \mathcal{W}, \qquad \sqrt{\int_{D_c} \Big\| \frac{\partial z}{\partial w}(x, w) \Big\|_F^2 dP(x)} \le B \quad \forall w \in \mathcal{W}.$$

• For each class $c$, the mapping from the weight $w$ to the prediction $q(\cdot, w)$ satisfies
$$\sqrt{\int_{D_c} \big\| q(x, w_2) - q(x, w_1) \big\|_2^2\, dP(x)} \le \lambda_c^q \|w_2 - w_1\|_2 \quad \forall w_1, w_2 \in \mathcal{W}.$$

The differential of the softmax probability is the following:
$$\frac{\partial q(c')}{\partial z_c} = \frac{\partial}{\partial z_c}\, \frac{e^{z_{c'}}}{e^{z_1} + \cdots + e^{z_C}} = \frac{e^{z_{c'}}}{e^{z_1} + \cdots + e^{z_C}}\, \mathbb{1}(c = c') - \frac{e^{z_{c'}}}{e^{z_1} + \cdots + e^{z_C}} \cdot \frac{e^{z_c}}{e^{z_1} + \cdots + e^{z_C}} = q(c')\big(\mathbb{1}(c = c') - q(c)\big), \quad \text{i.e.} \quad \frac{\partial q(c')}{\partial z} = q(c')\big(\mathbb{1}_{c'} - q\big).$$

Next, we provide the differentials of $\mathcal{L}_{\text{CE}}$, $\mathcal{L}_{\text{LSD}}$, and $\mathcal{L}_{\text{FedLSD}}$ with respect to the logits $z$:
$$\frac{\partial \mathcal{L}_{\text{CE}}}{\partial z} = \frac{\partial(-\log q(\text{true}))}{\partial z} = -(p_y - q), \qquad \frac{\partial \mathcal{L}_{\text{LSD}}}{\partial z} = \frac{\partial}{\partial z}\Big(-\sum_{i=1}^{C} q^g(i) \log \frac{q(i)}{q^g(i)}\Big) = \sum_{i=1}^{C} q^g(i)\, \frac{\partial(-\log q(i))}{\partial z} = -(q^g - q),$$
$$\frac{\partial \mathcal{L}_{\text{FedLSD}}}{\partial z} = \beta\, \frac{\partial \mathcal{L}_{\text{LSD}}}{\partial z} + (1 - \beta)\, \frac{\partial \mathcal{L}_{\text{CE}}}{\partial z} = q - \big(\beta q^g + (1 - \beta) p_y\big).$$

Now, we prove the following lemma.
Lemma 1. Given $w_1, w_2 \in \mathcal{W}$ and $\eta \in \big(0,\, 2 \min_c (\sqrt{2}\lambda_c^J + B\lambda_c^q)^{-1}\big)$, the following inequality holds for all $\beta \in [0, 1]$ and $\tau = 1$. We simply denote $\mathcal{L}_{\text{FedLSD}}(\tau = 1, \beta)(w)$ as $\mathcal{L}_{\text{FedLSD}}(w)$ if there is no conflict.

$$\big\| \big(w_2 - \eta \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(w_2)]\big) - \big(w_1 - \eta \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(w_1)]\big) \big\|_2 \le \|w_2 - w_1\|_2.$$

∵) First, we prove that $w \mapsto \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(w)]$ is Lipschitz. The gradient of the LSD loss has the following form:
$$\nabla_w \mathcal{L}_{\text{FedLSD}} = \frac{\partial \mathcal{L}_{\text{FedLSD}}}{\partial z}\, \frac{\partial z}{\partial w} = \big(q - (\beta q^g + (1 - \beta) p_y)\big)\, \frac{\partial z}{\partial w}.$$

Therefore,
$$\begin{aligned}
&\big\| \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)] - \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)] \big\|_2
= \Big\| \mathbb{E}_{D}\big[ \nabla_w \mathcal{L}_{\text{FedLSD}}(w_2) - \nabla_w \mathcal{L}_{\text{FedLSD}}(w_1) \big] \Big\|_2 \\
&= \Big\| \mathbb{E}_{D}\Big[ \big(q(w_2) - (\beta q^g + (1-\beta) p_y)\big) \Big(\frac{\partial z}{\partial w}(w_2) - \frac{\partial z}{\partial w}(w_1)\Big) - \big(q(w_1) - q(w_2)\big)\, \frac{\partial z}{\partial w}(w_1) \Big] \Big\|_2 \\
&\le \mathbb{E}_{D}\Big[ \big\| q(w_2) - (\beta q^g + (1-\beta) p_y) \big\|_2\, \Big\|\frac{\partial z}{\partial w}(w_2) - \frac{\partial z}{\partial w}(w_1)\Big\|_F \Big] + \mathbb{E}_{D}\Big[ \big\| q(w_1) - q(w_2) \big\|_2\, \Big\|\frac{\partial z}{\partial w}(w_1)\Big\|_F \Big] \\
&= \sum_{c=1}^{C} P(c)\, \mathbb{E}_{D_c}\Big[ \big\| q(w_2) - (\beta q^g + (1-\beta) p_y) \big\|_2\, \Big\|\frac{\partial z}{\partial w}(w_2) - \frac{\partial z}{\partial w}(w_1)\Big\|_F \Big] + \sum_{c=1}^{C} P(c)\, \mathbb{E}_{D_c}\Big[ \big\| q(w_1) - q(w_2) \big\|_2\, \Big\|\frac{\partial z}{\partial w}(w_1)\Big\|_F \Big] \\
&\le \sum_{c=1}^{C} P(c)\, \big\| q(\cdot, w_2) - (\beta q^g(\cdot) + (1-\beta) p_y(\cdot)) \big\|_{L^2(D_c)} \Big\| \frac{\partial z}{\partial w}(\cdot, w_2) - \frac{\partial z}{\partial w}(\cdot, w_1) \Big\|_{L^2(D_c)} + \sum_{c=1}^{C} P(c)\, \big\| q(w_1) - q(w_2) \big\|_{L^2(D_c)} \Big\| \frac{\partial z}{\partial w}(w_1) \Big\|_{L^2(D_c)} \quad \text{(Hölder's Inequality)} \\
&\le \sum_{c=1}^{C} P(c)\, \big(\sqrt{2}\lambda_c^J + B\lambda_c^q\big) \|w_2 - w_1\|_2
\le \Big( \max_{c \in [C]} \big(\sqrt{2}\lambda_c^J + B\lambda_c^q\big) \Big) \|w_2 - w_1\|_2 =: \Gamma\, \|w_2 - w_1\|_2.
\end{aligned}$$

Since $w \mapsto \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w)]$ is convex, $\Gamma$-Lipschitzness of the differential implies $\Gamma$-smoothness, i.e.,
$$\big\| \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)] - \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)] \big\|_2^2 \le \Gamma\, \big\langle \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)] - \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)],\, w_2 - w_1 \big\rangle.$$
Finally, if $0 < \eta < 2/\Gamma$,
$$\begin{aligned}
&\big\| \big(w_2 - \eta \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)]\big) - \big(w_1 - \eta \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)]\big) \big\|_2^2 \\
&= \|w_2 - w_1\|_2^2 - 2\eta \big\langle \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)] - \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)],\, w_2 - w_1 \big\rangle + \eta^2 \big\| \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_2)] - \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(\cdot, w_1)] \big\|_2^2 \\
&= \|w_2 - w_1\|_2^2 - \eta^2 \big( (2/\eta)\langle \cdots \rangle - \|\cdots\|^2 \big)
\le \|w_2 - w_1\|_2^2 - \eta^2 \big( \Gamma\langle \cdots \rangle - \|\cdots\|^2 \big) \le \|w_2 - w_1\|_2^2. \qquad ///
\end{aligned}$$

Now, we provide the main proposition. We initialize the weight of the local clients as $w_k^{t,0} = w^t$, perform $E$ epochs on the global and local clients, and compare the averaged weight of the local clients $w^{t+1}$ with the weight of the global model $w^{t+1}_{(GD)}$. The idea for the proof is developed from [cite]. First, note that
$$\|w^{t+1} - w^{t+1}_{(GD)}\|_2 = \Big\| \frac{1}{K} \sum_{k} w_k^{t,E} - w_{(GD)}^{t,E} \Big\| \le \frac{1}{K} \sum_{k} \big\| w_k^{t,E} - w_{(GD)}^{t,E} \big\|_2.$$

Next, for each $i \in \{1, \cdots, E\}$,
$$\begin{aligned}
\big\| w_k^{t,i} - w_{(GD)}^{t,i} \big\|_2
&= \Big\| \big[ w_k^{t,i-1} - \eta \nabla_w \mathbb{E}_{D_k^{(t)}}[\mathcal{L}_{\text{FedLSD}}(w_k^{t,i-1})] \big] - \big[ w_{(GD)}^{t,i-1} - \eta \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(w_{(GD)}^{t,i-1})] \big] \Big\|_2 \\
&\le \Big\| \big[ w_k^{t,i-1} - \eta \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(w_k^{t,i-1})] \big] - \big[ w_{(GD)}^{t,i-1} - \eta \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(w_{(GD)}^{t,i-1})] \big] \Big\|_2 + \Big\| \eta \nabla_w \mathbb{E}_{D}[\mathcal{L}_{\text{FedLSD}}(w_k^{t,i-1})] - \eta \nabla_w \mathbb{E}_{D_k^{(t)}}[\mathcal{L}_{\text{FedLSD}}(w_k^{t,i-1})] \Big\|_2 \\
&\le \big\| w_k^{t,i-1} - w_{(GD)}^{t,i-1} \big\|_2 + \eta \sum_{c=1}^{C} \big| P_k^{(t)}(c) - P(c) \big|\, \Big\| \mathbb{E}_{D_c}\Big[ \big(q(w_k^{t,i-1}) - (\beta q^g + (1-\beta) p_y)\big)\, \frac{\partial z}{\partial w}(w_k^{t,i-1}) \Big] \Big\|_2. \quad \text{(Lemma 1)}
\end{aligned}$$

The second term of the last inequality can be bounded as
$$\begin{aligned}
&\eta \sum_{c=1}^{C} \big| P_k^{(t)}(c) - P(c) \big|\, \Big\| \mathbb{E}_{D_c}[\cdots] \Big\|_2
= \eta \sum_{c=1}^{C} \frac{| P_k^{(t)}(c) - P(c) |}{\sqrt{P_k^{(t)}(c)}}\, \sqrt{P_k^{(t)}(c)}\, \Big\| \mathbb{E}_{D_c}[\cdots] \Big\|_2 \\
&\le \eta \sqrt{ \sum_{c=1}^{C} \frac{| P_k^{(t)}(c) - P(c) |^2}{P_k^{(t)}(c)} }\, \sqrt{ \sum_{c=1}^{C} P_k^{(t)}(c)\, \big\| \mathbb{E}_{D_c}[\cdots] \big\|_2^2 }
= \eta\, \chi(P \,\|\, P_k^{(t)}) \sqrt{ \sum_{c=1}^{C} P_k^{(t)}(c)\, \big\| \mathbb{E}_{D_c}[\cdots] \big\|_2^2 } \\
&\le \eta\, \chi(P \,\|\, P_k^{(t)}) \sqrt{ \sum_{c=1}^{C} P_k^{(t)}(c) \Big( \mathbb{E}_{D_c}\Big[ \big\| q(w_k^{t,i-1}) - (\beta q^g + (1-\beta) p_y) \big\|_2\, \Big\| \frac{\partial z}{\partial w}(w_k^{t,i-1}) \Big\|_F \Big] \Big)^2 } \quad \text{(Jensen's Inequality)} \\
&\le \eta B\, \chi(P \,\|\, P_k^{(t)})\, \Big\| q(\cdot, w_k^{t,i-1}) - \big(\beta q^g + (1-\beta) p_y\big) \Big\|_{L^2(D_k^{(t)})}. \quad \text{(Hölder's Inequality)}
\end{aligned}$$

Thus,
$$\big\| w_k^{t,i} - w_{(GD)}^{t,i} \big\|_2 \le \big\| w_k^{t,i-1} - w_{(GD)}^{t,i-1} \big\|_2 + \eta B\, \chi(P \,\|\, P_k^{(t)})\, \Big\| q(\cdot, w_k^{t,i-1}) - \big(\beta q^g + (1-\beta) p_y\big) \Big\|_{L^2(D_k^{(t)})}.$$
Consequently, we get
$$\begin{aligned}
\big\| w^{t+1} - w^{t+1}_{(GD)} \big\|_2
&\le \frac{1}{K} \sum_{k} \big\| w_k^{t,E} - w_{(GD)}^{t,E} \big\|_2
= \frac{1}{K} \sum_{k} \sum_{i=1}^{E} \Big( \big\| w_k^{t,i} - w_{(GD)}^{t,i} \big\|_2 - \big\| w_k^{t,i-1} - w_{(GD)}^{t,i-1} \big\|_2 \Big) \qquad (w_k^{t,0} = w_{(GD)}^{t,0} = w^t) \\
&\le \frac{\eta B}{K} \sum_{k \in K^{(t)}} \chi(P \,\|\, P_k^{(t)}) \sum_{i=1}^{E} \Big\| q(\cdot, w_k^{t,i-1}) - \big(\beta q^g + (1-\beta) p_y\big) \Big\|_{L^2(D_k^{(t)})}.
\end{aligned}$$

C Local features visualization

We further conduct an additional experiment on the features of the trained local models. The experimental setup is the same as in Section 2, but this time we train the global model for 100 communication rounds on heterogeneous locals and distribute it over 10 homogeneous locals and 10 heterogeneous locals. In the homogeneous local case (Figure 10a), the features are clustered by classes, regardless of which local they are learned from. On the other hand, in the heterogeneous local case (Figure 10b), the features are clustered by which local distribution is learned. The same visualization, colored by a different view, is in Figure 11.

Figure 10: t-SNE visualization of features on CIFAR-10 test samples after local training on (a) homogeneous local distributions (colored by classes) and (b) heterogeneous local distributions (colored by locals). The t-SNE is conducted jointly for the test sample features of the global and the 10 local models.

Figure 11: t-SNE visualization of feature region shifting on CIFAR-10 test samples after local training on (a) homogeneous local distributions (colored by classes) and (b) heterogeneous local distributions (colored by locals).
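For completeness, the joint t-SNE projection behind Figures 9-11 can be produced as sketched below; the feature extraction and the plotting are assumed to be handled separately, and the helper name is illustrative.

```python
import numpy as np
from sklearn.manifold import TSNE

def joint_tsne(global_feats, local_feats, seed=0):
    """Embed global-model and per-client local-model features (each (n_test, d))
    into one shared 2-D space. Returns the 2-D points and an owner id per point
    (-1 = global model, k = k-th local model)."""
    feats = np.concatenate([global_feats] + list(local_feats), axis=0)
    owner = np.concatenate([np.full(len(global_feats), -1)] +
                           [np.full(len(f), k) for k, f in enumerate(local_feats)])
    emb = TSNE(n_components=2, init="pca", random_state=seed).fit_transform(feats)
    return emb, owner
```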