張曉峰,吳 剛
(中國科學技術大學 信息科學技術學院,合肥 230031)
隨著近些年來深度學習的發展,深度神經網絡[1]在分類任務上取得了革命性的突破.基于深度神經網絡的分類器在有充足標簽樣本為訓練數據的前提下可以達到很高的準確度.但是往往在一些場景下,有標簽的數據難以收集或者獲取這些數據成本高昂,費時費力.當數據不足時,神經網絡很難穩定訓練并且泛化能力較弱.如何在小規模數據集上有效的訓練神經網絡成為當下的一個研究熱點.常見的應對小規模數據集訓練問題的措施主要有以下3種:
(1)無監督預訓練和有監督微調相結合的方法.通過引入和訓練數據具有相同分布的大量無標簽數據的方式,神經網絡可以先收斂到一個較優的初始點,然后再在小數據上微調.但是這種方式存在一個潛在的假設:無標簽的數據容易獲得而且收集成本不高,但是在一些數據難以獲取的場景中,例如醫療圖像,這種方法將無法應用.
(2)遷移學習[2]的方法.相比于第一種方法,遷移學習的要求更加寬泛,額外的無標簽數據不需要和訓練數據具有相同的分布,只要相似或者分布有重疊即可.在視覺識別當中,一些視覺的基本模式像邊緣、紋理等在自然圖像中都是共通的,這一點構成了遷移學習的理論保證.大量的實踐表明,在源領域(source domain)上學習大量數據后的網絡再遷移到目標領域(target domain)上,網絡的性能會得到極大的提升.但是當源領域和目標領域之間差距甚大時,遷移學習是否有所幫助,目前還未有研究.
(3)數據增強[3]的方法.通過合成或者轉換的方式,從有限的數據中生成新的數據,數據增強技術一直以來都是一種重要的克服數據不足的手段.傳統的圖像領域的數據增強技術是建立在一系列已知的仿射變換——例如旋轉、縮放、位移等,以及一些簡單的圖像處理手段——例如光照色彩變換、對比度變換、添加噪聲等基礎上的.這些變化的前提是不改變圖像的標簽,并且只能局限在圖像領域.這種基于幾何變換和圖像操作的數據增強方法可以在一定程度上緩解神經網絡過擬合的問題,提高泛化能力.但是相比與原始數據而言,增加的數據點并沒有從根本上解決數據不足的難題;同時,這種數據增強方式需要人為設定轉換函數和對應的參數,一般都是憑借經驗知識,最優數據增強通常難以實現,所以模型的泛化性能只能得到有限的提升.
最近興起的一些生成模型,由于其出色的性能引起了人們的廣泛關注.例如變分自編碼網絡(Variational Auto-Encoding network,VAE)[4]和生成對抗網絡(Generative Adversarial Network,GAN)[5],其生成樣本的方法也可以用于數據增強.這種基于網絡合成的方法相比于傳統的數據增強技術雖然過程更加復雜,但是生成的樣本更加多樣,同時還可以應用于圖像編輯,圖像去噪等各種場景.本文主要介紹的是基于生成對抗網絡的數據增強技術,并將這種方法應用于小規模數據集的分類任務.
生成模型可以分成顯式密度模型和隱式密度模型兩種.生成對抗網絡是一種隱式密度模型,即網絡沒有顯式的給出數據分布的密度函數,GAN的網絡結構如圖1所示,是由生成網絡(Generator,G)和判別網絡(Discriminator,D)兩部分組成.假設在低維空間Z存在一個簡單容易采樣的分布p(z),例如標準正態分布N(0,I),生成網絡構成一個映射函數G:Z→X,判別網絡需要判別輸入是來自真實數據還是生成網絡生成的數據.生成網絡輸入噪聲z,輸出生成的圖像數據;判別網絡輸入的數據或者來自真實數據集,或者來自生成網絡合成的數據,輸出數據為真的概率.

圖1 GAN結構示意圖
G和D相互競爭:G試圖欺騙D從而以假亂真,而D則不斷提高甄別能力防止G合成的數據魚目混珠,理論上最終生成的數據分布Pg和真實的數據分布Pdata可以相等.可以用式(1)概括整個GAN網絡的優化函數:

GAN本質上屬于無監督學習的范疇,其判別網絡僅僅輸出數據真假的概率.條件生成對抗網絡(Conditional-GAN)[6]在GAN的基礎上,加入類別的信息Y,從而可以生成指定類別的數據.Conditional-GAN的優化函數可以寫成式(2):

Conditional-GAN的判別器D仍然只有一個輸出來判斷真假,而半監督學習生成對抗網絡(Semi-GAN)[7]在Conditional-GAN 的基礎上,判別器輸出增加到K+1個(K代表數據的類別個數),K個輸出表示真實數據的分類概率,第K+1個表示數據為假的概.Conditional-GAN和Semi-GAN的結構如圖2所示.

圖2 Conditional-GAN與Semi-GAN結構對比
本文從數據增強的目的出發,通過改進生成對抗網絡的結構和訓練算法,設計了一種基于生成對抗網絡的數據增強技術,并提出了一種新的網絡結構,即數據增強生成對抗網絡(Data Augmentation GAN,DAGAN).與其他的GAN結構相比,我們提出的網絡結構更加適用于數據增強任務,即生成的樣本和原始數據真假難分的同時,還可以做到類間可分,從而有利于分類器在在合成的數據點上學習到分類界限.在訓練算法上,本文將DAGAN的訓練過程和分類器的訓練過程相結合,并提出一種新的損失函數,稱之為“2K”損失函數,從而可以做到在線數據增強,即數據處理和分類器訓練可以在內存中同步處理,不需要另外的數據存儲空間.
一般的GAN網絡其判別器僅僅只有一個輸出——判斷輸入的真假,如果直接用來生成數據用來做數據增強是不可行的,因為不能做到按類別生成樣本.Conditional-GAN和Semi-Supervised GAN雖然可以利用數據的標簽信息,并且按照給定的類別生成相應的數據,但是相關的研究工作表明這樣的GAN結構其生成的樣本多樣性不足,對數據增強的貢獻十分有限.因此,需要針對我們數據增強的這一特定需求,即生成的數據有利于分類器學習更加緊湊的分類界限,提升分類性能來設計網絡結構.基于以上考慮,從生成網絡的角度來看,最優的判別網絡需要:
1)能夠正確地將真實數據和生成數據分類;
2)不能分辨數據是真實的還是合成的.
據此,在GAN的基礎上設計出適合于小規模數據增強任務的GAN網絡結構,即DAGAN.結構如圖3所示.

圖3 DAGAN網絡結構
這里,生成網絡采用Conditional-GAN的結構,隱向量z和類別信息y作為輸入,輸出對應類別的數據;判別網絡的輸入有兩個來源——真實數據或者生成的數據,輸出則變為2K個,前K個表示輸入為真實數據K類的概率,后K個表示輸入為生成數據K類的概率.
可以看出,就判別網絡而言,從GAN到Semi-Supervised GAN,再到本文提出的DAGAN,輸出的維度不斷增加,同時應用的領域也更加廣泛.就生成網絡而言,Conditional-GAN,Semi-Supervised GAN以及本文的DAGAN都利用了數據的標簽信息,可以根據指定的類別生成相應的數據.DAGAN在利用Conditional-GAN生成器結構的同時,又增強了判別網絡的判別能力,使之適用于小規模數據集的增強.表1總結了以上幾種GAN網絡的特點對比.

表1 幾種GAN網絡的對比
DAGAN的訓練分成兩個階段,第一階段為數據生成階段.生成網絡和判別網絡優化相反的目標函數,在不斷的對抗中達到平衡.與GAN不同的是,由于判別網絡有2K個輸出,因此相應的損失函數也將發生改變,稱之為“2K”損失函數.對于判別網絡,其損失函數如下:

對于生成網絡,除了對應的判別真假的損失函數之外,還包括正則化項,用來保證生成的數據和真實的數據在特征層面盡可能保持相近,損失函數如(4)式所示:

其中,Lfm為正則化項,具體形式如下:

這里f(x)函數判別網絡中間某一層的輸出,即要求在相同類別的前提下,生成數據和真實數據特征應當相近,這進一步保證了生成數據和真實數據在同一類別下具有相同的語義.
第二階段為分類訓練階段,假設第一階段訓練完成之后,生成網絡已經學習到真實數據的分布.因此在這一階段,生成網絡將不再進行訓練,僅僅作為一個數據的提供者,生成的數據和真實數據一起訓練分類網絡.值得注意的是,這里不需要單獨搭建新的分類網絡,判別網絡直接作為分類器進行訓練.由于判別網絡有2K個輸出,這里規定第i個與第k+i個輸出的概率之和表示輸入為第i(i=1,2,…,k)類數據的概率.第二階段的判別網絡的損失函數由兩部分構成,分別是真實數據和生成數據:

其中,

兩個階段均采用批量隨機梯度下降的算法進行參數更新,具體流程見算法1.

算法1.DAGAN批量隨機梯度下降訓練算法輸入:第一階段的迭代次數KG,第二階段的迭代次數KC,訓練集D,測試集T,批次數量B 1)數據生成階段訓練:分別采樣真實數據(x,y)~Pdata(x,y),以及隱向量數據z~P(z),隨機類別數據y~Pg.在KG次迭代中,采用隨機梯度下降的方法,交替更新生成網絡和判別網絡,損失函數分別為LG和LC.2)數據分類階段訓練:分別采樣真實數據(x,y)~Pdata(x,y),以及隱向量數據z~P(z),隨機類別數據y~Pg,在KC次迭代中,采用隨機梯度下降的算法,只更新判別網絡,損失函數為L’C.3)在測試集上測試判別網絡的準確率.
為了驗證DAGAN的生成能力以及生成樣本能否提升分類器的準確率,我們分別在3個數據集上做了驗證實驗,分別為CIFAR-10、SVHN以及KDEF數據集.實驗中的網絡結構都是基于DCGAN[8]這個網絡搭建,詳細的網絡結構參數如表2所示.

表2 CIFAR10實驗網絡參數與網絡結構(SVHN與KDEF數據集實驗與之類似)
這里需要說明:G,D,T-Conv,Conv,NIN,NL分別表示生成網絡,判別網絡,反卷積,卷積,Network in Network,非線性激活函數.
CIFAR-10數據集總共包含60 000張RGB圖片,其中50 000張為訓練圖片,10 000張為測試集圖片.圖片為32×32的分辨率,總共可以分成10類.為了探究各種數據增強方式對于不同程度的小規模數據集的影響,我們人為地從該數據集中抽取不同數量的子數據集,每類從50到1000不等.實驗主要對比以下幾種不同的數據增強方式:(1)不采用任何的數據增強方式(C);(2)傳統的基于仿射變換和圖像操作的數據增強方式(C_aug);(3)GAN在每一類上分別訓練,然后每一類單獨生成數據(Vanilla GAN);(4)Semi-Supervised GAN 生成數據(Semi GAN);(5)本文所提出的方法(DAGAN);(6)本文所提出的方法加上傳統的數據增強技術(DAGAN_aug).實驗對比了不同方法下訓練出來的分類器在測試集上的分類準確率(Acc),結果見表3.

表3 不同數據增強方式在CIFAR10數據集上測試集的準確率(%)
從實驗結果可以看出,DAGAN_aug是所有方法中對分類器提升最顯著的,表明DAGAN可以在傳統數據增強的基礎上進一步提升模型的性能,突破傳統數據增強的瓶頸.另外可以看出DAGAN在數據量較少的時候(每類圖片數量小于500張)要優于Vanilla GAN和Semi GAN,說明本文針對數據增強目的設計的DAGAN網絡結構和訓練算法更加有利于分類器的性能提升.
SVHN[9]是真實世界的街道門牌號碼識別數據集,每張圖片代表0-9中的一個數字,分辨率為32×32.由于每種圖片中可能包含不止一種數字,而標簽為中心的數字.傳統的數據增強方式例如翻轉、移位等在這樣的數據中將不能應用,因為這些轉換方式可能會改變圖像的標簽.同樣地,表4給出了不同種數據增強方式在SVHN數據集上的性能對比.實驗僅僅考慮了3種數據增強方式的對比,即(1)不采用任何的數據增強方式(C);(2)Semi-Supervised GAN生成數據(Semi GAN);(3)本文提出的方法(DAGAN).

表4 不同數據增強方式在SVHN數據集上測試集的準確率(%)
實驗結果和CIFAR10數據集是一致的,在數據量較少的情況下,DAGAN能夠最大程度的提升分類器的分類性能,且優于Semi GAN的方法.有一點需要注意,當數據量較多時(每類圖片數為500張),Semi GAN和DAGAN兩種方法幾乎都不起作用,這主要是因為對于相對比較簡單的SVHN數據集,當訓練數據達到一定規模后,限制網絡性能的因素不再是數據,而是分類網絡的結構還有分類算法.
KDEF[10]數據集是一種人臉表情數據集,包含35個男性和35個女性,年齡在20至30歲之間.沒有胡須,耳環或眼鏡,且沒有明顯的化妝.7種不同的表情,每個表情有5個角度.總共4900張彩色圖,尺寸為562×762像素.實驗中我們僅采用正面角度,因此只有490張圖片,根據表情進行分類.
本次實驗生成網絡的結構沒有變化,與表2類似,判別網絡采用VGG-16,由于數據量過少,因此我們采用的VGG-16是在ImageNet數據集上預訓練過的.實驗對比了以下幾種數據增強方式的性能:(1)不采用任何數據增強方式,僅僅是預訓練的分類器(C);(2)GAN在每一類上分別訓練,然后每一類單獨生成數據(Vanilla GAN);(3)Semi-Supervised GAN生成數據(Semi GAN);(4)本文所提的方法(DAGAN).實驗結果如表5所示,從結果來看,DAGAN依然是性能最好的結構,同時說明DAGAN可以和預訓練的策略相結合,進一步提升分類器的性能,突破數據增強技術的瓶頸.

表5 不同數據增強方式在KDEF數據集上測試集的準確率(%)
以上3個數據集的實驗說明了DAGAN結構的可行性和有效性,為了進一步表明DAGAN生成的圖片和原始圖片具有相同的語義,而且呈現出內容上的多樣性,這一部分將展示3個數據集上DAGAN生成的數據樣本,并和原始數據相比較,如圖4所示.
從生成圖片來看,CIFAR-10數據集每一行都是有著漸變的效果,這是通過對隱變量z差值實現的;而每一列都是一個不同的類別,這是通過控制類別信息y實現的.SVHN數據集每一行都是屬于相同的類別,而每一列圖片的z保持相同,所以每一列的圖片具有相同的風格.以上都說明DAGAN生成的圖片是可編輯的,同時也可以看出生成的圖像呈現比較豐富的多樣性,從而印證了DAGAN可以用于數據增強任務.

圖4 CIFAR-10數據集、SVHN數據集和KDEF數據集原始圖片和生成圖片對比
由于深度神經網絡在小規模數據集上難以訓練,容易出現過擬合的問題,本文提出一種基于生成對抗網絡的數據增強技術,通過在大量的實驗,以及和其他模型的對比,驗證了所提方法的可行性和有效性.DAGAN既可以有效提升分類器的分類性能,同時生成的圖像數據和真實數據相比具有語義的相似性和內容的多樣性.