宋華偉,李升起,萬方杰,衛玉萍
(鄭州大學網絡空間安全學院,河南 鄭州 450000)
在互聯網時代,全球數百億的聯網設備產生的數據呈指數級增長[1-2]。充分利用這些海量數據可以助力于建立更復雜、準確的神經網絡模型[3],提高神經網絡模型的質量。然而,現實中的數據由于數據隱私、行業競爭等限制[4],數據本身的整合存在巨大的阻礙,集中式訓練的方式變得越來越不可行。
由于上述問題的出現,聯邦學習(FL)得到了越來越多的關注[5]。聯邦學習技術由谷歌于2016 年首次提出[6],核心思想是在保護數據隱私的前提下,實現多方參與的訓練,解決數據集中化和數據孤島問題。聯邦學習采用分布式的訓練過程,客戶端利用本地數據更新局部模型,無需上傳本地數據,僅將更新后的局部模型參數上傳到服務器,不斷交互、更新執行,直到全局模型收斂或達到預定的訓練輪數。聯邦學習技術很好地平衡了大量數據、數據隱私與數據價值之間的矛盾。
但是,聯邦學習擁有優勢的同時也帶來新的問題。聯邦學習通常涉及大量客戶端,這些參與方數據的獨立同分布(IID)程度對模型訓練的最終效果有重要的影響[7-8]。但在現實情況中,每個參與方的本地數據總是非獨立同分布(Non-IID)的,例如:在醫療領域,不同的參與方可能擁有來自不同病人的醫療圖像數據,數據在大小、顏色、對比度、亮度等方面均存在差異;綜合性醫院每種類型的數據可能都比較全面,而專科醫院只有某一類的數據比較全面,且這一類數據的質量較高。因此,如何設計一個在數據Non-IID 下的聯邦學習方法,實現更好的學習效果,對聯邦學習的發展和應用具有重要的現實意義,同時也是本文研究的問題。
為了解決Non-IID 數據場景下模型準確率下降的問題,相關學者進行了一系列的研究。ZHAO等[9]的研究表明,在CIFAR-10 數據集中只需共享5%的數據子集就能夠提高30%的準確率,但是共享數據可能會泄露數據隱私。LI等[10]給模型的目標函數加上一個限制項,用于限制局部模型和全局模型的差異,以此來減小數據異構帶來的影響,然而,FedProx相較于FedAvg 提升 較為有限。WANG等[11]考慮了參與方可能在每輪需要執行不同數量的局部步驟,因此,為了確保全局更新沒有偏差,在全局聚合之前對其每一方的局部更新進行歸一化和縮放,消除目標不一致性,保持全局模型快速大幅度收斂。文獻[12]采用集群學習方式將具有相似分布的參與方聚集到固定的簇中,并為不同簇訓練不同的全局模型以適應其固有數據分布,但這種方式得到的全局模型泛化性能差。FedFa[13]將客戶端模型的準確率和選中次數作為本地數據質量的衡量標準,并為每個客戶端賦予相應的聚合權重。
另一個研究方向是將一些技術融入聯邦學習,諸如多目標學習[14-15]、元學習[16]、持續學習[17]、對抗學習[18]、區塊鏈[19-20]等技術被廣泛應用于Non-IID 聯邦學習場景中。其中一個很好的研究方向是將持續學習應用于聯邦學習領域。SHOHAM等[21]基于聯邦學習與持續學習的類比,將災難性遺忘的解決方案應用于非獨立同分布數據下聯邦學習存在的局部模型漂移問題。FedLSD[22]側重于本地更新學習到的知識,通過蒸餾獲取全局模型的知識。這些方法減少了數據異質帶來的影響,但是不同客戶端數據分布不同,學習到的知識也存在差異,全局模型聚合時仍會存在差異的干擾。
基于上述研究成果不難發現,數據非獨立同分布下的聯邦學習優化方法可以從全局模型聚合和本地客戶端更新兩個角度加以改進。但大多數方法都是在一個角度進行優化,難免不會引發另一角度帶來的影響,降低全局模型的質量。針對這種現象,本文提出了分層持續學習的聯邦學習優化方法(FedMas)。在FedMas 中,將數據非獨立同分布導致的全局模型準確率低的問題建模為持續學習任務。考慮一個極端的例子:假設有10 個參與方,每個參與方擁有MNIST 數據集的其中一類的全部數據,采用聯邦平均算法[6]進行訓練,每次全局模型聚合都會因權重發散導致準確率低,但如果用持續學習的災難性遺忘的解決方案去訓練,則可以融合不用任務的特征,提高全局模型的準確率。
FedMas 將參與方根據其數據分布特征劃分邏輯層,單個層中的參與方的數據分布相似,避免隨機抽取時由于數據分布不同導致權重發散、準確率降低的問題,推動全局模型訓練更快地收斂。由于層間數據分布不同,為了學習不同數據分布的特征,本文采用持續學習算法-記憶感知突觸算法[23]融合不同數據間的差異性。簡言之,FedMas 算法通過聚類分層減少不同數據分布客戶端聚合時的干擾,通過記憶感知突觸算法持續學習有益的全局知識,以最大化提高全局模型的收斂速度和模型質量。
在聯邦學習場景中,全局服務器和參與方通過網絡連接,模型訓練使用的訓練數據是分散在各個邊緣設備上的,通過迭代的全局聚合和更新來實現模型的訓練。聯邦學習的優化目標是最小化所有樣本的平均損失,如式(1)所示:
Fk(w)代表了局部數據的分布信息,當所有的參與方都是獨立同分布的時候,可以得到式(2):
即當客戶端上的數據與總體數據分布相同時,客戶端上的預測損失應與全局的預測損失期望相同,通過多次同步后,其聚合后的全局模型能夠逼近集中式訓練的模型;而當數據不滿足獨立同分布假設時,模型擬合自身所持有的數據集,造成參數方向的分歧,并且隨著同步的次數增多,分歧越來越大,在服務器端聚合時偏移全局最優解,如圖1 所示。

圖1 Non-IID 數據訓練時的模型偏移Fig.1 Model offset during Non-IID data training
在上述方法的基礎上,本文提出了FedMas 方法。FedMas 將整個過程分為兩個部分:按數據分布對客戶端分層,以及對不同層進行知識融合。FedMas 的整體架構如圖2 所示,算法描述見算法1。

圖2 FedMas 整體架構Fig.2 Overall architecture of FedMas
算法1FedMas
數據分布不同的客戶端在聚合時會受到其他客戶端學習知識的干擾。為了避免這個問題,本文將所有客戶端進行了分層。文獻[12]為每個層訓練一個全局模型,這樣做的后果是不能很好地利用聯邦學習維持的大規模數據集的優勢,數據量少的客戶端容易過擬合,因此,本文對聚合的層進行了知識融合。分層的具體過程如下:
在服務器端初始化一個全局分類模型,并將全局模型廣播至所有客戶端對本地數據的樣本進行e輪次的本地訓練,每個客戶端利用本地的數據集按式(3)進行參數更新:
其中:η是學習率;xi是客戶端的數據樣本;?(w0;xi)為參數w0的損失函數的梯度。
訓練結束后,客戶端上傳本地模型參數到服務器端。因為更新之前的模型參數相同,更新過程中只有數據不同,所以得到的新的模型參數僅僅包含了客戶端的數據分布信息。采用DBSCAN 聚類[24]方法對收集到的模型參數進行聚類,將客戶端劃分到不同的層Tier 中,聚類分層后,每個層中客戶端的數據分布相似(見算法1 中的第1~4 行)。
如何融合不同層的知識是本文算法分層后要解決的關鍵問題,持續學習為解決這個問題提供了很好的思路。持續學習可以在學習后一個任務時不忘記前一個任務學習的知識,經過不斷的發展已經取得不錯的成果,其中記憶感知突觸算法是一個成功的方法。同樣作為基于正則化的方法,和彈性權值合并算法相比,記憶感知突觸算法的重要性權重矩陣獲取方式能夠從無標簽數據中學習,這個屬性使得該方法能夠應用在沒有訓練數據的場景下并且其占用的內存更小。因此,記憶感知突觸算法更適合用于聯邦學習場景。記憶感知突觸算法通過計算網絡模型中每個參數對于該任務的重要性,并沿用到訓練后續的任務的方式,以保持對原數據集較好的分類性能。記憶感知突觸算法損失函數如式(4)所示:
其中:Ln(w)為新任務(第n個任務)的損失函數損失函數;Ωij表示每個參數對于該任務的重要性;是由前n-1 個任務訓練后得到的模型,同時也是用于訓練第n個任務的初始模型的參數;λ為一個正則項的可調的超參數。
此外,式(4)中Ωij為重要性權重矩陣,原作者使用L2 范數的平分的偏導代替,具體如式(5)所示:
本文將Non-IID 數據學習問題建模為持續學習任務,考慮到Non-IID 場景下聚合權重發散以及持續學習任務之間的順序性,本文沒有將每個邊緣設備視為一個單獨的學習“任務”,而是設計了FedMas,將具有相似原型的設備分組在一起,并將每組設備視為學習任務。
FedMas 算法需要執行C個通信輪次,在每個通信輪次內所有層按序參與訓練,每個層訓練時只隨機抽取層內的一部分客戶端,并采用加權聚合更新全局模型。在Tier 中第一次選取層訓練時,因為全局模型為初始化參數,所以損失函數為交叉熵損失,不加入記憶感知突觸算法項,損失函數如式(6)所示:
其中:n為訓練集的樣本大小;y為訓練數據的標簽向量;a為神經網絡模型的輸出向量。
從第2 次選擇層訓練開始直到訓練結束,客戶端接收到的模型來自上一層,模型在本層訓練時參數的變動可能會覆蓋神經網絡在舊數據上所學的知識。為了緩解這個問題,本文在本地訓練時引入記憶感知突觸算法,通過盡量減少舊任務上重要參數的改變幅度,以期同時在不同任務上取得良好的效果。損失函數為交叉熵損失和記憶感知突觸算法項的和,如式(4)所示。記憶感知突觸算法的Ωij項一般是在舊數據集上進行計算,考慮到聯邦學習的舊數據集在上層的多個客戶端上,每個層的客戶端數據分布類似,本文采用隨機抽取一個客戶端的方式更新Ωij,具體過程見算法1 中的第5~24 行。
FedMas 算法的主要思想是通過聚類分層減小層內權重分歧的影響,同時通過引入記憶感知突觸算法聚合層間的知識,減小Non-IID 數據對全局目標函數的影響,提高訓練質量。
實驗使用MNIST 和CIFAR-10 數據集,具體如下:
1)MNIST 數據集。MNIST 數據集有10 個不同類別的手寫體數字(數字0~9),其中,訓練集包含60 000 張圖片和標簽,測試集包含10 000 張圖片和標簽。在MNIST 數據集上使用由1 個卷積層、1 個最大池化層、3 個全連接層組成的神經網絡模型。
2)CIFAR-10 數據集。CIFAR-10 數據集包含60 000 張32×32 像素的彩色圖像,其中,訓練集包含50 000 張圖像,測試集包含10 000 張圖像。CIFAR-10數據集圖像共有10 個類,分別為飛機、汽車、鳥類、狗等,該數據集更復雜,學習的難度更大。在CIFAR-10 數據集上使用由2 個卷積層、2 個平均池化層和2 個全連接層組成的卷積神經網絡模型。
為了體現FedMas 算法在數據異質環境下的特點和性能,將其與目前表現較好的聯邦學習算法FedProx[10]、Scaffold[25]和FedCurv[21]進行對 比,實 驗結果將在2.4 節中討論。FedProx 算法基于FedAvg改進了局部目標,引入了一個附加的近端項,用于限制局部模型和全局模型的差異,FedProx 的超參數mu 按照原文選擇0.2。Scaffold 算法引入了控制變量糾正Non-IID 數據局部訓練時的漂移問題。FedCurv 和FedProx 很類似,只是把正則化項改為EWC 算法的正則化項,通過持續學習正則化項克服數據異構下聯邦學習的災難性遺忘,FedCurv 在原文中λ=2.0 時效果更好,因此,在本文中設定λ=2.0。對于FedMas 算法的超參數λ,如果設置得過小,則對局部更新沒有影響;如果設置得過大,模型更新很慢,參考FedFMC[26]的設置方式,設為
為了更符合真實情況,本文通過狄利克雷分布來模擬不同客戶端數據集標簽傾斜的Non-IID 分布。根據狄利克雷分布劃分而來的數據集分布情況受狄利克雷的參數α控制[27]:α越大,所得到的概率分布越逼近均勻分布,采樣所生成的數據集越趨向于獨立同分布;α越小,所得的概率分布越偏向于集中某一些點,數據集的偏斜越嚴重,所得數據集越近似Non-IID 數據集。本文分別在α=0.3和α=0.7 取值下對MNIST 和CIFAR-10 這兩個數據集進行隨機采樣,產生Non-IID 程度不一致的數據集并隨機分發給各個客戶端進行實驗,以此評估FedMas 在處理不同程度的Non-IID 數據時的表現。以MNIST 數據集為例,在不同的異構設置下,取前10 個客戶端,其本地數據分布如圖3 所示(彩色效果見《計算機工程》官網HTML 版)。

圖3 MNIST 數據集在不同異構設置下前10 個客戶端的數據分布圖Fig.3 First ten clients' data distribution in MNIST dataset under different data heterogeneities
在本地訓練中,使用的優化器中SGD 參數設置為:學習率0.01,本地訓練輪次為5 輪,對比實驗中MNIST 數據集 通信輪次為50 輪,CIFAR-10 數據集通信輪次為100 輪,客戶端總數為100 個,每次以0.2的樣本率對客戶端進行隨機抽樣,樣本輸入維度為64,batch 大小設置為10。模擬實驗在同一臺具有NVIDIA RTX A5000 24 GB 的機器上進行。
FedMas 整體可分為2 個部分:1)將數據分布相似的客戶端劃分到一個層;2)在本地客戶端局部更新時加入記憶感知突觸算法項。為了驗證這種分層持續學習的聯邦學習優化方法的有效性,對上述兩個部分的有效性分別進行消融實驗。為了更清晰地了解數據非獨立同分布性質的干擾,采用MNIST 數據集進行實驗,該數據集共有10 類,將每個類別的數據平均分給其中的10 個客戶端,通信輪次為200 輪,其余實驗相關參數設置不變,分別以FedAvg算法(FedAvg)、結合記憶感知突觸算法項的FedAvg算法(FedAvgMas)、對客戶端進行分層的FedAvg 算法(TFedAvg)和對客戶端進行分層并結合記憶感知突觸算法的FedAvg 算法(TFedAvgMas)進行實驗,結果如圖4 所示。

圖4 分層和局部更新的有效性Fig.4 Effectiveness of hierarchical and local updates
1)分層策略的有效性。分層策略考慮了數據分布不同的客戶端在聚合時會受到其他客戶端學習知識的干擾。圖4 中FedAvg 和FedAvgMas 沒有分層,采用的是FedAvg 算法的隨機挑選策略,TFedAvg 和TFedAvgMas 采用了客戶端分層策略。可以看出:在MNIST 數據集下,采用客戶端分層策略的TFedAvg和TFedAvgMas,其平均準確率比隨機挑選的方案提升了近5 個百分點;在前80 輪通信過程中,4 種策略的平均準確率較為接近,但是TFedAvg 和TFedAvgMas 的提升過程更為平穩;在其他通信輪次,分層策略的平均準確率幾乎全部優于隨機挑選策略;此外,隨機挑選策略相較于分層策略的訓練曲線始終存在較大的波動。實驗結果表明了分層策略在數據非獨立同分布下的有效性,分層策略可以避免因隨機挑選帶來的全局模型聚合時多個客戶端數據分布不同的相互干擾,且分層后每個輪次可以學習全部數據分布客戶端的豐富知識,穩步提升全局模型的性能。
2)局部更新策略的有效性。在MNIST 數據集下,加入了記憶感知突觸算法項策略的平均準確率較未加入的方案提高了1 個百分點;FedAvg 與FedAvgMas 以及TFedAvg 與TFedAvgMas 的對比結果表明,使用記憶感知突觸算法項可以通過緩解局部模型訓練時的災難性遺忘進一步提升全局模型的性能。
表1 展示了FedMas 和其他方法在不同Non-IID程度數據集上準確率比較的結果,其中加粗表示最優值。從實驗結果來看:結合了持續學習算法的FedCurv 算法和FedMas 算法在不同的數據集以及數據異構情況下,比FedProx 算法和Scafflod 算法得到一個更好的全局模型;Scafflod 算法在特征分布更為復雜的CIFAR-10 數據集比FedProx 算法優勢更明顯;Non-IID 程度越高時,FedMas 算法的效果與其他算法準確率差距越大,這說明本文提出的算法能有效避免數據非獨立同分布的干擾,充分學習不同數據分布之間的知識,獲得更好的分類效果;當數據Non-IID 程度低時,FedMas 算法和其他算法效果差距減小,但持續學習算法的優勢在于即便是在獨立同分布的數據場景下使用,神經網絡訓練時仍能因其抗遺忘特性而提高模型的質量。因此,FedMas 在聯邦學習中優勢更加明顯。

表1 不同Non-IID 程度下的準確率比較Table 1 Accuracy comparison under different Non-IID levels %
各算法在α=0.3 時準確率隨訓練輪次的變化趨勢如圖5、圖6 所示,可以看出:本文提出的算法具有更快的收斂速度,最終準確率也最高,證明了提出模型的有效性;FedMas 算法每次局部更新時利用重要性權重矩陣限制了學習到知識的參數的更新程度,相較于其他算法每次變化更穩定,準確率更高;FedProx 算法和Scafflod 算法在學習的過程中波動較大,其中Scafflod 算法在MNIST 數據集上波動比較大,在CIFAR-10 數據集相對穩定,并且性能較好。

圖5 α=0.3 時FedCurv、FedProx、Scaffold 和FedMas 在MNIST 數據集上的準確率Fig.5 The accuracy of FedCurv,FedProx,Scaffold and FedMas on the MNIST dataset when α=0.3

圖6 α=0.3 時FedCurv、FedProx、Scaffold 和FedMas 在CIFAR-10 數據集上下的準確率Fig.6 The accuracy of FedCurv,FedProx,Scaffold and FedMas on the CIFAR-10 dataset when α=0.3
本文為非獨立同分布場景下的聯邦學習提供了一種新方法,它建立在全局聚合和局部更新的解決方案之上。該方法通過關注客戶端的數據分布情況對其進行分層,將每個層建模為持續學習的任務,再對層進行抗遺忘的知識融合學習,得到最終的預測模型。在不同數據集上和其他模型的對比結果,證明了本文方法的有效性。本文方法架構考慮到真實場景中的客戶端數據異質情況,因此具有一定的普適性,可應用在多客戶端共同訓練的場景下。在未來工作中,將關注因硬件設施導致的掉隊設備給實驗帶來干擾,以及客戶端設備異構的問題,設計性能更好的聯邦學習算法。