楊曼,黃遠民,石遠豪
(1.佛山職業技術學院,廣東 佛山 528137;2.佛山賽寶信息產業技術研究院有限公司,廣東 佛山 528000)
一般而言,損失函數很復雜,參數空間龐大,而通過巧妙的使用梯度來尋找函數最小值的方法就是梯度法,梯度表示的是各點處的函數值減小最多的方向。在神經網絡的學習中,尋找最小值的梯度法稱為梯度下降法,由于神經網絡中選取的訓練數據多為隨機批量數據,因此稱為隨機梯度下降法。用數學公式來表示方法如下式所示:

在使用SGD訓練參數時,SGD每次都會在當前位置上沿著負梯度方向更新,并不考慮之前的方向梯度大小,有時候會下降的非常慢,并且可能會陷入到局部最小值中,動量的引入就是為了加快學習過程,引入一個新的變量去積累之前的梯度(通過指數衰減平均得到),實現加速學習過程的目的。用數學公式來表示方法如下式所示:

這里變量v表示物體在梯度方向上的受力,遵循在力的作用下,物體的速度增加這一法則,v初始為None,若當前的梯度方向與累積的歷史梯度方向一致,則當前的梯度會被加強,從而這一步下降的幅度更大。若當前的梯度方向與累積的梯度方向不一致,則會減弱當前下降的梯度幅度。a初始值設定為0.5、0.9或者0.99。
AdaGrad算法的思想是每一次更新參數,不同的參數使用不同的學習率。將每一個參數的每一次迭代的梯度取平方累加后再開方,用全局學習率除以這個數,作為學習率的動態更新。用數學公式來表示AdaGrad方法如下式所示:這里變量h保存了所有梯度值的平方和,在更新參數時,通過乘以來調整學習的尺度,參數元素中被大幅更新的元素的學習率將變小。從算法AdaGrad中可以看出,隨著算法不斷迭代,h會越來越大,整體的學習率會越來越小。

Adam方法計算了梯度的指數移動均值,融合了RMSProp和Momentum的方法,Adam會設置3個超參數,一個是學習率,標準值設定為0.001,另外兩個超參數beta1和beta2控制了這些移動均值的衰減率,移動均值的初始值beta1、beta2值接近于1,在深度學習庫中,標準的設定值是beta1=0.9、beta2=0.999。
MNIST數據集60000個訓練數據樣本,構建5層神經網絡,輸入層數據大小為784,隱藏層神經元個數分別為600、400、200,輸出層神經元個數為10,激活函數用relu。
其迭代后損失函數結果如下表1所示(保留4位有效數字)。

表1 損失函數值
根據以上數據可得到損失AdaGrad函數最小為.0105,Adam次之且相差不大,其余兩種方法學習效率較低。
其迭代后損失函數結果如下表2所示(保留4位有效數字)。

表2 損失函數值
根據以上數據可得到Adam損失函數最小為。0.0019,在神經網絡層數不變,神經元個數不變的情況下,增加訓練樣本的迭代次數,AdaGrad的損失函數更新速度變慢,Adam方法損失函數更新速度更快、效果更好,此時損失函數值更小。
以MNIST數據集為例,分別采用SGD、Momentum、AdaGrad、Adam方法對目標函數進行優化,構建以下深度CNN網絡結構圖,見圖1。完成對該數據集中手寫數字的分類識別。

圖1 深度網絡結構圖
其中Conv為卷積層,主要進行卷積運算,相當于圖像處理中的濾波器運算。Pool為池化層,主要是縮小高、長方向上的空間運算。Dropout是一種在學習的過程中隨機刪除神經元的抑制過擬合問題的方法,訓練時隨機選出隱藏層的神經元,然后將其刪除,被刪除的神經元不再進行信號的傳遞。softmax函數為輸出層函數,softmax函數的輸出是0.0到1.0之間的實數。采用上述深度卷積神經網絡進行測試,當epochs=12,得到準確率如下表3所示。

表3 測試結果
結果顯示,采用Adam進行目標函數優化,分類識別準確率為0.9951。設定當epochs=20,得到準確率如下表4所示。

表4 測試結果
結果顯示,增大epochs后,采用Adam 進行目標函數優化,分類識別準確率仍為最高0.9955。增加了0.04%,但計算速度會減慢。
采用SGD、Momentum、AdaGrad、Adam 4種權重參數更新方法,SGD下降過程更為曲折,因噪聲使梯度更新的準確率下降,同時SGD會在某一維度上梯度更新較大產生振蕩,可能會越過最優解并逐漸發散。加入動量Momentum梯度移動更為平滑,但效率仍偏低。AdaGrad可根據自變量在各個維度的梯度值大小來調整各個維度上的學習率,避免統一的學習率難以適應所有維度的問題。Adam融合了Momentum、AdaGrad的方法,Adam的更新過程類似Momentum,但相比之下,Adam晃動的程度有所減輕。以MNIST數據集為樣本集,構建深度卷積神經網絡結構,分別采用SGD、Momentum、AdaGrad、Adam方法對目標函數進行優化,完成對該數據集中手寫數字的分類識別,結果對比顯示采用方法,識別準確率為99.55%。