王愛麗, 薛 冬, 吳海濱, 王敏慧
(哈爾濱理工大學 測控技術與通信工程學院,黑龍江 哈爾濱 150080)
圖像識別是計算機視覺領域中常見的任務之一,通過提取數字圖像中有效的特征信息,來賦予圖像各自的標簽類別,從而完成識別過程。目前,圖像識別已經廣泛應用于醫療診斷[1]、交通標志識別[2]等領域中。
近年來,深度學習的快速發展使得其在圖像識別領域取得了很多成果。相比于傳統的人工識別方法,使用深度學習方法可以提取到圖像更多的深層次特征,從而提高識別準確率。然而,深度學習方法需要大量的訓練數據才能取得好的識別效果,所以對于數據量較小的數據集,識別效果并不理想,而生成對抗網絡(GAN)的提出很大程度上緩解了這種問題[3]。
GAN的核心思想來源于博弈論的納什均衡[4],它含有一個生成器和一個判別器,生成器在訓練過程中將生成的假樣本送入判別器來擴充訓練集的數據量,這使得判別器學習到更多的圖像特征,從而提高分類準確率。因此,基于GAN的這種優點,許多學者將其應用到圖像分類、識別領域。Zhu等人[5]將GAN應用到高光譜數據的分類中,提出了利用空間特征和光譜特征的1D-GAN和3D-GAN,顯著提高了高光譜數據的分類精度。劉坤等人[6]使用半監督GAN實現了X光圖像的分類,充分利用有標簽的數據提高分類效果。楊旺功等人[7]利用深度卷積生成對抗網絡(DCGAN)實現了花朵圖像的分類,實驗結果表明所提方法的分類準確率高且穩定性好。Kuang等人[8]首次將GAN用于肺結節惡性腫瘤的無監督分類中,實驗表明此方法使用少量標記樣本就能取得較好的分類結果。
雖然GAN在圖像處理領域具有較好的表現,但是GAN存在訓練太過自由,生成的圖像不可控等問題,因此Mirza等人[9]提出了條件生成對抗網絡(CGAN),在生成器和判別器中加入圖像的類別標簽,使得生成的圖片可以人為控制。因此,本文將CGAN和條件批處理歸一化[10-11](CBN)結合,提出CBN-CGAN網絡。CBN利用類別標簽對每一類數據進行批處理歸一化,使網絡學習到更多的特征,提高模型魯棒性。
條件生成對抗網絡是在GAN的基礎上發展而來的,通過把附加信息y加入到生成器G和判別器D的輸入中,使得GAN的無監督學習變為有監督學習。訓練階段,隨機噪聲z與條件變量y同時輸入到G中,得到盡量服從真實數據分布Pdata的生成數據G(z|y);然后將真實數據x、條件變量y和G生成的數據G(z|y)同時輸入到D中,最后輸出一個標量來估計輸入數據來自真實數據的概率。其目標函數為:

.
(1)
D的目標是實現對數據來源的二分類判別,G的目標是最大化D判斷輸入數據錯誤的概率。因此,這兩個相互對抗并迭代優化的過程使得G和D的性能不斷提升。當最終D無法正確判別數據真假時,可以認為G已經學到了真實的數據分布,二者達到平衡狀態。
深度學習算法的一個缺點是網絡在訓練階段容易出現無法收斂的問題。因此,谷歌公司通過對輸入數據進行批量歸一化(BN)處理,提高了網絡的訓練速度,解決梯度消失的問題[12]。但是在條件生成模型中存在一個問題,不同類別的訓練數據放在一起做批處理歸一化不妥當。因為在一個批量的數據中含有不同類別的數據,而每個類別的數據在計算后得到的均值和方差的數值也是不同的,因此在還原數據階段,應該采用不同的標準化、平移和縮放處理的方式來還原每一類數據。
條件批量歸一化[10-11](CBN)處理數據時不使用整個批量的統計數據,而是在各類數據特征圖內部做歸一化,依賴類別標簽還原出原始的每一類數據。CBN對數據的批處理如下:
,
(2)
(3)
,
(4)
.
(5)
式(2)和(3)分別計算了數據的均值和方差;式(4)和(5)對數據進行標準化、平移和縮放的處理。
本文網絡模型將CGAN與CBN相結合,稱為CBN-CGAN網絡,網絡模型結構如圖1所示,主要包括3方面的改進。

圖1 CBN-CGAN網絡結構圖
(1)將生成器和判別器的網絡結構改為全卷積的網絡結構,借助卷積神經網絡(CNN)強大的提取特征的能力,更好地學習圖像的特征。在判別器輸出末端加入Softmax分類器,使模型適用于多分類任務。
(2)在生成器和判別器的網絡層中添加CBN,充分利用類別標簽來對每一類數據進行批量化處理,使得網絡充分學習到特征圖中每個類別的特征,提高生成圖像的質量和識別精度。同時,提升模型穩定性,加快收斂速度,緩解梯度消失的問題。
(3)提出新的目標損失函數,使模型最終的輸出既包括對圖像來源的真假判別,即將真實圖像判斷為真,將生成器生成的圖像判斷為假,又包括對圖像類別標簽的多分類結果。
本文網絡模型的生成器是一個全卷積的網絡結構,使用反卷積層代替池化層,反卷積核的大小為4×4,并且去掉全連接層來增加模型穩定性。首先,輸入端是兩個反卷積層,它們分別將100維的噪聲z和n維的類別標簽(假設數據集有n個類別)進行維度轉換,轉換成(256,4,4)的三維張量,然后將二者連接成(512,4,4)的三維張量作為下一個反卷積層的輸入。最后,3個反卷積層將這個3維張量轉換成(1,32,32)的圖像。本文在中間兩個反卷積層后添加CBN,CBN對數據進行歸一化處理得到均值和方差,利用類別標簽控制偏置因子和縮放因子的取值,從而在映射原始數據時充分還原出每一類數據,使網絡充分學習到各類數據的特征分布。
CBN還可以解決初始化差和模式崩塌的問題,同時確保梯度傳播到模型的深層。此外,為了不降低模型穩定性,在輸入端和輸出端的反卷積層后不添加CBN。其次,本文在中間兩個反卷積層后添加ReLU 激活函數層,提高模型的學習速度,在最后一個反卷積層后添加Tanh激活函數層。CBN-CGAN的生成器網絡結構如圖2所示。

圖2 CBN-CGAN生成器網絡結構圖
本文判別器網絡結構和生成器網絡結構類似,也是一個全卷積的網絡結構,含有5個卷積層,卷積核的大小為4×4,不含池化層和全連接層。在輸入端,訓練數據和類別標簽分別通過卷積層進行維度轉換,轉換成(64,16,16)的三維張量,再將二者連接成(128,16,16)的三維張量。接著依次經過兩個卷積層輸出(512,4,4)的三維張量到達輸出端。輸出端分為兩部分:一部分判斷數據的真實來源,另一部分輸出分類結果。
與生成器相同,本文在除了輸入端和輸出端的卷積層后添加CBN,提升訓練效果。本文在判別器中使用LeakyReLU激活函數,它是ReLU激活函數的改進版,在判別器中表現得更好。因此,本文在中間兩個卷積層后添加LeakyReLU激活函數層。在輸出端,本文在最后兩個卷積層后分別使用Sigmoid分類器和Softmax分類器完成不同的任務,一是Sigmoid分類器輸出數據的真假判別結果,二是Softmax分類器輸出真實數據和生成器生成的數據的分類結果。
在分類時,假設數據集有n個類別。首先,每個由生成器生成的假數據通過網絡向前傳遞,并通過獲取概率預測向量的最大值來分配一個標簽。因此,這些假數據可用于在網絡中使用這些標簽進行訓練。此外,假數據不屬于任何類型的真實數據。由于真實數據和假數據的不同,所以創建了一個新的類別標簽(n+1類)劃分它們,每個假數據都被賦予這個新的類別標簽。本文采用這種方法最后輸出n+1個類別的識別結果。CBN-CGAN的判別器網絡結構如圖3所示。

圖3 CBN-CGAN判別器網絡結構圖
本文設計的網絡模型的目標損失函數包括兩部分:一部分是判斷數據來源真假的LS,另一部分是對真實數據和生成器生成的數據進行分類的LC。該目標損失函數如下所示:
,
(6)
,
(7)
,
(8)
式中,LS將輸入的真實數據判斷為真,將生成器生成的數據判斷為假。LC分為兩部分,一部分是真實數據對應的分類結果,它應該被檢測為前n種類別中其唯一對應的類別,即檢測為真實類別的概率為1;另一部分是使生成器生成的假數據可以被分類為第n+1類,即假數據被檢測成真實類別的概率為0,也即假數據對應的分類結果為第n+1類的概率為1。對于判別器,其最終目的是最大化LS+LC;生成器的目的是最小化LS-LC。
本文實驗在Windows操縱系統下進行,基于開源深度學習框架Pytorch,使用的編程語言為Python,實驗設備包括Intel(R) Core(TM) i5-6500 CPU @ 3.2 GHz處理器,16 GB運行內存(RAM),NVIDIA GeForce GTX 1060 GPU。實驗所用的數據集為MNIST數據集,它是機器學習領域最常見的手寫字體數據集,包含0~9的10類手寫數字圖像,每類數字包含60 000個訓練樣本和10 000個測試樣本,每幅圖像是28×28像素的灰度圖像。
實驗階段,為了讓生成器盡可能學習到所有樣本的數據分布,需要對樣本進行歸一化處理,對類別標簽數據進行one-hot編碼處理。網絡輸入的圖像大小為32×32,批處理大小設置為32,由于數據集為灰度的手寫數字圖像,特征相對較少,因此訓練迭代了30個Epoch。在訓練階段,網絡采用Adam優化器,學習率設置為0.000 2,權重衰減設置為0.000 5。
為了評估本文所設計的CBN-CGAN網絡模型的性能,將本文網絡模型與傳統的深度學習網絡做了對比。實驗評價標準一方面是生成圖像的質量,另一方面是識別的準確率。圖4為本文方法與CGAN和深度卷積生成對抗網絡[13](DCGAN)在Epoch次數為30時對生成圖像所做的對比。從對比圖中可以直接看出,本文提出的CBN-CGAN網絡比其他網絡生成的圖像質量更好,更清晰,例如數字“2”和數字“8”的輪廓更明顯,更容易辨別。

圖4 不同方法生成MNIST圖像對比結果
表1為本文識別方法與其他方法的識別精度對比。其中決策樹(Decision Tree)的最大深度設置為100;支持向量機(SVM)中的核函數設置為徑向基核函數(rbf),rbf的系數默認為“auto”,懲罰參數設置為100;隨機森林(Random Forest)中樹的數目設置為200;卷積神經網絡(CNN)的網絡結構與本文判別器網絡結構相似,卷積核大小也為4×4,網絡中加入ReLU激活函數,BN層,Softmax分類器,數據輸入之前做歸一化處理,學習率設置為0.001;在CGAN和DCGAN的判別器輸出端分別添加Softmax分類器,使這兩個網絡可以完成分類任務,學習率都設為0.000 2。

表1 MNIST數據集識別準確率和識別時間
從表1可以看出,本文所提出的CBN-CGAN網絡模型達到的識別準確率最高,為99.43%,分別比DCGAN、CGAN、CNN、隨機森林、SVM和決策樹高出0.57%,0.86%,1.07%,2.28%,2.83%和11.5%的準確率。雖然本文模型在識別時所消耗的時間有一定程度的增加,但是卻提高了準確率,因此證明了本文所提方法在手寫體數字識別時可以更好地提取特征,有效提高了識別準確率。
圖5為本文所提出的CBN-CGAN網絡的判別器損失函數曲線與原始CGAN網絡的判別器損失函數曲線的對比。從圖中可以看出本文所提出的網絡判別器的收斂速度更快,且隨著迭代次數的增加越來越穩定。

圖5 判別器損失函數曲線對比
本文結合條件生成對抗網絡與條件批處理歸一化提出CBN-CGAN、改進生成器和判別器的結構,提出新的損失函數。利用CBN融合類別標簽的優勢,使網絡提取更多的圖像特征,提高了模型的魯棒性。在MNIST數據集上的實驗結果表明,相比于其他方法,本文所提出的方法雖然消耗的時間較長,但是生成的圖像質量更好,同時手寫體數字識別的準確率達到99.43%,驗證了本文所提出的CBN-CGAN模型在手寫體數字識別領域中的有效性。