竇勇敢,袁曉彤
(1.南京信息工程大學 自動化學院,江蘇 南京 210044;2.江蘇省大數據分析技術重點實驗室,江蘇 南京 210044)
近些年來,隨著深度學習的興起,人們看到了人工智能的巨大潛力,同時希望人工智能技術應用到更復雜和尖端的領域。而現實狀況是數據分散在各個用戶或行業中,用戶數據存在隱私上的敏感性和安全性。如何在保護數據隱私前提下進行機器學習模型訓練,讓人工智能技術發揮出更強大的作用成為一種挑戰。
為了讓這些隱私數據流動起來,同時應對非獨立同分布數據的影響,Google 科學家Mcmahan 等[1]提出聯邦學習(federated learning),通過協調大量遠程分布式設備在保護用戶數據隱私的前提下訓練一個高質量的全局模型。
目前的聯邦學習算法還存在諸多問題。首先,每個設備CPU、GPU、ISP、電池以及網絡連接(3G、4G、5G、WIFI)[2]等硬件差異導致設備間存在很大的系統異構性。傳統的聯邦學習方法FedAvg[1]在規定時間內將沒有訓練結束的設備簡單丟棄,這在現實情況中是不可取的,浪費了大量的計算資源。其次,每個設備的數據分布和類型存在很大的差異[3],跨設備的數據是非獨立同分布的(non-IID),這是數據的異構性。不同的異構環境中模型的收斂效果差別很大,甚至無法收斂。這些系統級別的異構性給聯邦學習帶來了極大的挑戰。
現有針對異構性問題的分布式優化算法中,大部分都是針對特定異構環境設定的。例如:文獻[4-6]提出讓所有設備都參與每一輪的訓練,雖然在異構數據環境中的收斂性得到了保證,但是這在現實的聯邦環境[1]中是不可行的。這不僅增加了服務器的通信負擔,而且參與聯邦訓練的設備也應隨機抽取。也有方法通過共享本地數據來解決數據異構性的問題[7-8],但這違背了聯邦學習保護用戶數據隱私的前提。在聯邦設置中,文獻[9]通過在服務器端設計基于動量優化器FEDYOGI來加快異構數據環境中全局模型收斂速度,這雖然提高了模型的收斂速度,但卻增加了服務器的計算量,在有限的計算資源下不是好的選擇。此外,也有研究者利用二階擬牛頓法優化模型[10],在相同的異構環境中,與FedAvg 相比達到相同精度下減少了通信輪數,提高了通信效率,但這潛在增加了客戶端本地的計算量。
除了數據異構性,每個參與聯邦訓練的客戶端的硬件存在差異,這導致設備間存在很大的系統異構性[11]。例如:在文獻[12-15]中,介紹了在異構環境中目前最新的聯邦學習研究進展,在全局模型聚合階段的更新方式同FedAvg[1]一樣,在指定的時間窗口內,服務器將未完成訓練的設備直接丟棄,不允許上傳本輪訓練的模型參數。各參與訓練的設備不能根據自己硬件性能在本地執行可變數量的本地工作,缺乏自主調節能力。
在解決聯邦學習異構性的問題上,近鄰優化的更新方式廣泛地用于研究,包括高效通信分布式機器學習[16]、聯邦學習中公平性和魯棒性的權衡[17]。近鄰優化在原理上與有偏正則化相同,其中文獻[18]中考慮有偏正則化的方法對FedAvg進行重新參數化,提出FedProx,通過有偏正則化約束每個設備學習的本地模型更加接近于全局模型,并允許各參與訓練的設備在本地執行可變數量的工作,在異構環境中提供了收斂的保證。由于FedProx 在優化全局模型參數w時和FedAvg 方式相同,通過簡單平均本地上傳的模型參數來更新全局模型參數,導致全局模型收斂速度慢,缺乏直接對全局模型參數的優化。
受小批量近似更新的元學習機制[19]的啟發,本文提出了基于隱式隨機梯度下降優化的聯邦學習算法,在本地模型更新階段通過近鄰優化約束本地模型更新更加接近于全局模型,在全局模型聚合階段通過求解近似全局梯度,利用梯度下降來更新全局模型參數。最終實現全局模型能夠在較少的通信輪數下達到更快更穩定的收斂結果。
本文的貢獻主要體現在以下3 個方面:
1)區別于已有的方法,不在對全局模型參數進行簡單平均。在全局模型聚合階段,通過利用本地上傳的模型參數近似求出平均全局梯度,同時也避免求解一階導數。
2)針對異構性導致的全局模型收斂慢甚至無法收斂的問題,區別于現有的聯邦學習算法,本文提出基于隱式隨機梯度下降優化的聯邦學習算法,通過隱式隨機梯度下降來更新全局模型參數,能夠使全局模型參數實現更加高效的更新,從而可以在有限的計算資源下加快模型的收斂速度。
3)和現有的工作相比,本文的算法在高度異構的合成數據集上,30 輪左右就可以達到FedAvg 的收斂效果,40 輪左右可以達到FedProx 的收斂效果。在相同收斂效果的前提下,本文的算法比FedProx 減少了近50%的通信輪數。
聯邦學習更新架構主要有客戶端-服務器和去中心化對等計算架構。其中最常用的是客戶端-服務器的聯邦學習更新架構。訓練過程主要分為兩個階段:本地模型更新階段和全局模型聚合階段。具體更新過程如圖1 所示。

圖1 客戶端-服務器聯邦學習架構Fig.1 Federated learning architecture of client and server
1)本地模型更新
在本地模型更新階段,服務器首先隨機選取K個客戶端,然后服務器發送全局模型參數[[wt]]給被選客戶端,客戶端利用本地數據并行執行E個epoch 的隨機梯度下降,然后將更新后的模型參數經過同態加密算法[20]加密,之后再上傳至服務器。
2)全局模型聚合
在本節中,主要介紹聯邦近鄰優化算法和隱式隨機梯度下降優化算法的關鍵要素。由于聯邦學習是通過大量設備與中央服務器協同學習一個最優的全局模型,因此我們的最終目標是最小化:

式中:wk是設備k在本地迭代過程中所得的近似最優解;w是需要求解全局模型的最優解;Fk(wk):=,每個設備本地數據xk服從不同的分布 Dk,損失函數是預測值與真實值之間的差。式(1)包含兩方面的優化過程:1)在本地模型訓練階段,每個設備通過全局模型參數w學習一個本地近似最優wk;2)在全局模型聚合階段,服務器通過各設備上傳的wk利用隱式隨機梯度下降來調整全局模型參數w,使w與所有wk的平均距離較小。具體的算法流程為:

在算法1 中,步驟4)~6)為本地模型訓練階段,7)~9)為Server 全局模型更新階段,然后將更新后的模型參數發送給下一輪參與訓練的設備。不斷重復以上過程,直至模型損失收斂。
在本地模型訓練階段,主要在本地模型更新時引入帶參數的近鄰算子約束本地模型更新更加接近于全局模型,這種本地優化算法被稱為Fed-Prox 算法[18],每個設備k的本地目標函數被重新定義為

式中:λ是一個約束本地模型和全局模型差異的超參數;wt表示在第t輪服務器聚合更新之后的全局模型參數。

由鏈式法則可以得到:

所以?Gk(wt)=,式(4)展現了全局模型的梯度估計可以通過求解當前任務的近似更新來計算。在第t輪,所選設備在本地數據集上利用隨機梯度下降更新E輪后,求出近似最優解。服務器通過式(4)可以計算出平均的全局梯度:


式中:St為K個設備的子集;t為當前訓練輪數;為按固定輪數衰減的學習率;ηgi為初始化學習率,在訓練模型初期用較大的學習率對全局模型進行優化,隨著通信輪數的不斷增加學習率逐步減小,有效保證了全局模型在訓練過程中能以較快的速度逐步趨于穩定。更新后的wt+1作為下一輪訓練的全局模型參數。
從式(3)~(6)推導過程很容易看出,本文提出基于隱式隨機梯度下降優化的聯邦學習算法是直接對全局模型參數進行優化,而不是簡單平均所有設備上傳的本地模型參數作為更新后的全局模型參數。因為 ?Gk(wt)=,所以在服務器端只需通過就可以得到平均全局模型梯度,因此避免了求解一階導數,然后利用隨機梯度下降對全局模型參數進行更新。相比于FedProx,本算法在信息比較冗余的情況下能更高效地利用有效信息。其次,在迭代的過程中也會很快收斂到最小值附近,加快模型的收斂速度。
為了驗證本文提出的隱式隨機梯度下降優化算法的有效性,本文在3 個真實數據集和3 個合成數據集上進行實驗,在分類和回歸任務上進行評估,并與當前具有代表性的解決異構性問題的方法FedProx[18]以及經典的FedAvg[1]算法進行比較。
在Linux 系統下,包括2 塊GeForce GTX 1 080 Ti 和1 塊GeForce GTX TITAN X 的服務器上進行仿真實驗,代碼使用Tensorflow 框架實現,基于Python3 來實現基于隱式隨機梯度下降優化的聯邦學習算法。其中,訓練輪數、每輪迭代次數、選擇設備數量、學習率等超參數設置如表1 所示。

表1 超參數設置Table 1 Setting of Hyperparameters
為了保證評估方法與結果的公平性,本文提出的方法與FedProx、FedAvg 使用了相同的本地求解器,在模擬系統異構設置時,掉隊的設備數量分別設置為0%、50%、90%。生成合成數據集本文使用了和FedProx 類似的方法,通過式(7)生成本地數據:

式中:W∈10×60;x∈60;b∈10。通過式(7)生成30 個設備的數據集,同樣每輪隨機抽取10 個參與訓練。
Sent140[21]是一個Twitter 帶有表情的文本信息情感分類數據集,該任務使用的是一個兩層LSTM,包含256 個隱藏層單元,每個Twitter 帳戶對應一個設備。該模型以25 個字符序列作為輸入,通過兩個LSTM 層和一個全連接層,每個訓練樣本輸出一個字符。
MNIST[22]是一個0~9 手寫體數字識別數據集,在這個任務上利用邏輯回歸的方法研究手寫數字圖像分類問題。為了生成非獨立同分布數據,本文將數據隨機分布在1 000 個設備中,每個設備只有2 種數字。模型的輸入是28×28 維的圖像,輸出是0~9 這10 個數字的標簽。
EMNIST[23]是MNIST 數據集的擴展,包含0~9 數字和26 個英文字母的大小寫,構成了更大難度的62 類手寫字符圖像分類任務,但在實驗中只隨機抽取10 個小寫字母,每個設備分配5 個類,在這個任務上利用邏輯回歸的方法研究圖像分類問題。模型的輸入是28×28 維的圖像,輸出是a~j 這10 個類的標簽。
對于以上所有數據集,客戶端的本地數據分配遵循冪律分布[24]。本文在本地分配80%為訓練集,20%為測試集。各設備數據集組成如表2 所示。

表2 設備數據集分布Table 2 Datasets distribution on devices
首先在第1 個實驗中,為了驗證本文的算法在異構數據集上有更快的收斂速度,本文在3 組合成數據集上進行實驗,分別是Synthetic_0_0、Synthetic_0.5_0.5、Synthetic_1_1,從左到右數據異構性逐漸增強,異構性越強,對模型收斂影響越大。本文通過損失的減小速度和梯度方差[25]的變化來衡量模型的收斂速度,結果如圖2 所示。為了證明本文方法的公平性和有效性,約束項λ統一設置成相同的值。由圖2 訓練損失和梯度方差可以看出,本文的方法在第30 輪左右達到了FedAvg 的收斂效果,在第40 輪左右達到了FedProx 的收斂效果,并且40 輪以后還在繼續收斂。梯度方差(variance of local gradient,VLG)越小表示越穩定,收斂性越好。VLG 可表示為


圖2 合成數據集實驗結果分析Fig.2 Analysis of experimental results of synthetic datasets
實驗中,通過使所有設備執行相同的工作量來模擬不存在系統異構性的情況,隨著數據異構性增強,全局模型收斂結果最終會趨于某個區間,因此本文取最后一半通信輪數的平均測試精度作為模型好壞的評判標準,在合成數據集上平均測試精度如表3 所示,可以看出本文提出的算法平均測試精度普遍高于FedProx 和FedAvg。

表3 合成數據集上平均測試精度Table 3 Average test accuracy on synthetic datasets %
在本實驗中,為了驗證本文提出的算法在高度系統異構性和數據異構性環境下的整體效果,本節在3 個聯邦學習常用真實數據集和一個合成數據集上比較不同算法的穩定性和收斂效果,其中Synthetic_1_1 客戶端本地類別設置為5,實現在數據異構性基礎上模擬不同系統異構性的聯邦設置。
本文通過約束設備的本地工作量,使每個設備訓練指定的E來模擬系統的異構性,對于不同的異構設置,隨機選擇不同的E(E<20)分配給0%、50%和90%當前參與訓練的設備。當掉隊者為0% 時,代表所有設備執行相同的工作量(E=20)。在指定的全局時間周期內,當E<20 時,FedAvg 會丟掉這些掉隊者,本文的算法和Fed-Prox 會合并這些掉隊者,不同的是本文在全局模型聚合階段會有效地使用合并掉隊者的模型參數,利用隱式隨機梯度下降對全局模型進一步優化。真實數據集上的訓練損失如圖3 所示,從上到下3 行圖片分別代表0%、50%和90%的掉隊者。隨著迭代輪數的不斷增加,平均損失逐漸趨于穩定,從圖3 中可以看出本文提出的算法的收斂速度明顯優于Fedavg 和FedProx。

圖3 真實聯邦數據集實驗結果分析Fig.3 Analysis of experimental results of realistic federated datasets
表4 給出了在高度異構環境下模型的平均測試精度,從表中可以看出掉隊者為90%時,本文提出的算法的平均測試精度最高,其次是Fed-Prox。本文算法在MNIST 數據集上比FedProx高5%。實驗中,在Sent140 數據集上通過設置相同超參數進行比較不同算法運行時間,在通信輪數為200 的情況下,FedAvg、FedProx 和本文所提算法運行時間分別為67 min、108 min、108 min。

表4 高度異構環境各算法平均測試精度Table 4 Average test accuracy of each algorithm in highlyheterogeneous environment %
本文提出了一種基于隱式隨機梯度下降優化的聯邦學習算法。全局模型聚合階段不再是簡單的平均各設備上傳的模型參數,而是利用本地上傳的模型參數近似求出全局梯度,同時避免求解一階導數。利用隨機梯度下降對全局模型參數進行更新,在信息冗余的情況下能更準確地利用有效信息,隨著通信輪數不斷增加,全局模型會很快收斂到最小值附近。在3 個合成數據集和3 個真實數據集上的實驗結果充分表明:該算法能夠在不同異構環境中均表現出更快更穩健的收斂結果,顯著提高了聯邦學習在實際應用系統中的穩定性和魯棒性。