甘 宏
(廣州南方學院,510970,廣州)
近年來,將深度學習技術用于視覺識別任務取得了相當大的進展[1-5]。然而,有監督的深度學習模型需要大量的標記樣本和迭代步驟來訓練模型的參數,這嚴重限制了深度學習技術對新出現或罕見類別的適用性,同時收集并標記大量的樣本需要耗費大量的人力物力。相比之下,人類卻擅長通過少量甚至幾個樣本來識別物體,而深度學習技術難以用于每類僅有一個或幾個樣本的學習。受人類具備小樣本學習能力的啟發,使得小樣本學習問題引起了廣泛的關注。
現有的小樣本學習方法大致可以分成3類:度量學習、元學習以及基于數據增強方法。度量學習方法利用輔助數據集學習得到一個度量空間,使得在該度量空間中同一類樣本的特征向量彼此間的距離較近,而不同類樣本的特征向量距離則較遠,從而實現小樣本學習。文獻[6]將卷積孿生網絡用于單樣本圖像識別,通過有監督的方式訓練孿生網絡,然后重用網絡所提取的特征向量進行單樣本學習。文獻[7]提出了匹配網絡,該算法的核心是episode-based的訓練策略,其基本思想是訓練和測試是要在同樣條件下進行,即在訓練的時候讓網絡模型只看每一類的少量樣本,使得訓練和測試的過程保持一致。原型網絡的基本思想是每個類都存在一個原型表達,該類的原型是支撐集在利用度量空間中特征向量的均值作為類表示[8]。F Sung等[9]提出了關系網絡求解小樣本學習問題,該模型2個模塊:嵌入模塊和關系模塊。嵌入式模塊用于提取數據樣本的特征表示,而關系模塊用于估計2個特征表示之間的距離。D DAS等[10]添加了一個預訓練階段,利用所有基類的分類任務預訓練模型獲得參數的初始化。Li等[11]在分類損失函數中添加一個與任務相關的附加邊際損失,以更好地區分不同類別的樣本,從而提高分類性能。Zhou等[12]利用貪婪算法選擇與支持集樣本的相似基類,使得度量模型能對新的小樣本任務有較強的適應性。元學習方法通過對多個任務的學習,以使元模型(meta-learner)能夠對新的任務做出快速而準確的學習,該方法包含了2個關鍵問題:訓練得到最優初始化參數和學習有效的參數更新規則。FINN等[13]提出了MAML(Model-Agnostic Meta-Learning)的元學習方法,基本思想是訓練一組初始化參數,通過在初始參數的基礎上進行一或多步的梯度調整,來達到僅用少量數據就能快速適應新任務的目的。K Wang等[14]給出了結合概率推理和元學習的識別模型,以阻止元模型訓練過程中偏向某些具體任務,從而提高元模型對新任務的泛化能力。Meta-SGD算法[15]對MAML算法進一步優化,不僅對初始參數進行了學習,而且對元模型的更新方向和學習速率進行學習。文獻[16]提出了一階元學習算法,該算法采用一階導數近似表示二階導數,使得元參數更新過程中不需要像MAML算法一樣計算二階導數,從而提高元模型的訓練效率。數據增強方法通過擴充樣本來提高小樣本學習的性能。然而,數據生成模型在僅有少數幾個訓練數據時,往往表現不佳。
本文提出算法屬于元學習方法的范疇。針對現有元學習方法對部分訓練任務存在有偏的不足,本文提出基于正則化元學習算法。通過在元學習的目標函數中添加正則化項,阻止元學習的初始模型偏向現有某些訓練任務,提高元模型對新任務的泛化能力,從而提高小樣本圖像分類的性能。
小樣本分類的目標是找到參數θ,小樣本分類目標是學習得到參數θ使得分類器fθ在詢問集中的期望值最大
(1)

為了減小元訓練過程中產生有偏,提高元學習模型的泛化能力。本節提出了正則化元學習算法(Regularized Meta Learning,REML)。通過在元目標函數添加正則化約束項,使得模型對訓練任務無偏。針對小樣本圖像分類問題,MAML算法的元目標函數為:
(2)


(3)
其中LTi(fθ)采用交叉熵損失函數,表示為:

(4)
因此,MAML算法的元目標函數可以表示為:
(5)
為盡量減小參數θ對訓練任務有偏,提高元模型的泛化能力。本文引入交叉熵的約束條件,作為原目標函數的正則化項,使得參數θ對訓練任務是無偏的。交叉熵表示為:

(6)
以交叉熵作為正則化項,則元目標函數表示為

(7)

(8)
元目標函數梯度更新表示為

(9)

(10)
求導涉及到二維求導問題,大大增加了算法的計算量。針對以上不足,利用一階導數近似二階導數得到
(11)
則元參數更新模型(9)可以簡化為
(12)
本節將給出算法的詳細步驟,詳見算法1。
算法1:正則化元學習算法。

1)While not done do;
2)抽取幾個任務Ti構成任務塊Tbat;
3)for allTiinTbatdo;
4)從Ti中每類選取K個樣本記做D;
5)利用LTi(fθ)和D計算?θL(fθ);
7)從Ti抽取Dval用于元參數學習;
8)End for;
9)利用Dval和元學習目標函數L(θ)學習元模型參數θ,
10)End while。
輸出:元模型參數θ。
本節通過在miniImageNet、CUB-200和CIFAR-100這3個典型數據集上進行的小樣本分類實驗,來充分驗證本文算法性能,并與MAML、Reptile、Relation Networks和Prototypical Networks等先進算法比較。實驗1比較了不同算法在MiniImageNet數據集中的性能,并給出了參數λ對本文算法的影響;實驗2比較了不同算法在數據集CUB-200上的算法性能;實驗3給出了在數據集CIFAR-100上不同算法的性能比較。
為方便與其他算法進行比較,在后續的實驗中本文算法采用了與文獻[8-9,13,16]相同的網絡結構。網絡結構由4個模塊組成,每個模塊包含1個3×3×64的卷積層和1個2×2的池化層,每個卷積層均采用歸一化處理。
MiniImageNet數據集包含100個類,其中每個類包含600個樣本。采用與其他算法相同的拆分,其中64個類用于訓練,16個類進行驗證,20個類用于測試。分別進行了5-way 1-shot和5-way 5-shot小樣本圖像分類實驗,表1給出不同算法的分類精度比較。
由表1可以看出,本文算法由于提高了模型對新任務的泛化能力,從而使分類精度得到了一定的提升。

表1 不同算法在數據集miniImagenet中分類精度的比較
CUB-200數據集[14]包括了200種細分的類。參照文獻[15]中的劃分,隨機選取100個類用于元訓練,50個類用于驗證,50個類進行測試,并將每幅圖像的尺寸大小調整為84×84。分別進行了5-way 1-shot和5-way 5-shot小樣本圖像分類實驗,表2比較了4種算法的分類精度。
由表2可以看出,本文算法相對于MAML算法的分類精度能有將近4%的提升。
CIFAR-100數據集包括了100個類,每個類包含600張尺寸為32×32的圖形。隨機選取64個類進行元訓練,16個類用于驗證,20個類用于小樣本分類性能測試。與其他實驗類似,分別進行了5-way 1-shot和5-way 5-shot小樣本圖像分類實驗,表3比較了不同算法的分類精度。由表3可以看出,本文算法相對于MAML算法精度有3%左右的精度提高。

表2 不同算法在數據集CUB-200中分類精度的比較

表3 不同算法在數據集CIFAR-100中的分類精度比較
本小節通過對以上3個數據庫的5-way 5-shot小樣本圖像分類實驗,分析平衡參數λ對算法性能的影響。圖1給出了本文算法(REML)在不同參數值時的分類精度。由圖1可以看出,當參數λ取值接近0時,算法識別精度與MAML算法接近;當參數λ取0.2~0.3之間時能獲得較高的識別精度;當參數λ大于0.3之后,隨著參數的增加算法性能逐步下降。

圖1 平衡參數λ不同時的算法分類精度
針對小樣本學習問題,本文提出了正則化元學習算法(REML)用于求解小樣本圖像分類問題。該算法以交叉熵作為正則化項,以阻止元模型參數偏向某些具體任務,從而提高元模型的泛化能力,即提高元模型對新任務的適應能力。此外,采用一階導數近似二階導數減小元學習模型訓練所需計算量。在miniImageNet、CUB-200和CIFAR-100這3個數據集上進行的實驗表明,本文算法的分類性能優于現有的同類算法,并表明平衡參數選擇在0.2~0.3之間時能獲得較高的識別精度。