朱佳麗 宋燕



摘要:膠囊網絡(CapsuleNetwork,CapsNet)通過運用膠囊取代傳統神經元,能有效解決卷積神經網絡(ConventionalNeuralNetwork,CNN)中位置信息缺失的問題,近年來在圖像分類中受到了極大的關注。由于膠囊網絡的研究尚處于起步階段,因此目前大多數膠囊網絡研究成果在復雜數據集上表現的分類性能較差。為解決這個問題,本文提出了一種新的膠囊網絡,即基于遷移學習的注意力膠囊網絡。該網絡通過使用遷移學習的方法改進傳統的特征提取網絡,并融合注意力機制模塊,進而完成圖像分類任務。首先,使用9層采用ELU激活函數的特征提取網絡提取特征;然后,將特征提取網絡在ImageNet數據集上訓練所得參數,通過遷移學習用于CIFAR10數據集上;再者,在特征提取網絡后加入注意力模塊提取關鍵特征;最后,在MNIST、FashionMNIST、SVHN和CIFAR10等公開數據集上進行了對比實驗。實驗結果表明,本文提出的膠囊網絡在簡單和復雜數據集上都取得了理想的分類效果。
關鍵詞:膠囊網絡;遷移學習;注意力機制;圖像分類
【Abstract】Inrecentyears,CapsuleNetwork(CapsNet)hasreceivedgreatattentioninimageclassificationbecauseitreplacestraditionalneuronswithcapsulesandovercomesthedefectsoflosingpositioninformationinConvolutionalNeuralNetwork(CNN).SincetheresearchofCapsNetisstillinitsinfancy,mostresearchresultsofCapsNethavepoorclassificationperformanceoncomplexdatasets.Tosolvethisproblem,anewcapsulenetworkisproposedtocompletetheimageclassificationtask,namedAttentionCapsuleNetworkbasedonTransferLearning,byimprovingthefeatureextractionnetworkthroughtransferlearningandintegratingtheattentionmodule.Firstly,a9-layerfeatureextractionnetworkwiththeELUactivationfunctionisusedtoextractfeatures;secondly,theparametersobtainedfromthefeatureextractionnetworktrainingontheImageNetdatasetareusedontheCIFAR10datasetthroughTransferLearning;thirdly,theattentionmoduleisstackedafterthefeatureextractionnetworktoextractkeyfeatures.Finally,experimentsonpublicdatasetsincludingCIFAR10,SVHN,MNIST,andFashionMNISTshowthattheproposedAttentionCapsuleNetworkbasedonTransferLearningcanachieveidealclassificationaccuracyonbothsimpleandcomplexdatasets.
【Keywords】CapsuleNetwork;TransferLearning;attentionmechanism;imageclassification
作者簡介:朱佳麗(1996-),女,碩士研究生,主要研究方向:圖像處理;宋燕(1979-),女,博士,副教授,博士生導師,主要研究方向:大數據算法、圖像處理、預測控制。
0引言
自2017年GeoffreyHinton首次提出膠囊網絡(CapsuleNetwork,CapsNet)[1]以來,CapsNet受到了廣泛的關注并被逐漸應用于計算機視覺的各項任務中[2-3]。CapsNet使用膠囊代替了傳統卷積神經網絡(ConventionalNeuralNetwork,CNN)中的神經元,利用轉換矩陣表示物體之間的位置關系,極大地克服了CNN由池化操作[4-5]帶來的信息丟失的缺陷,因此CapsNet在訓練樣本較少的情況下也能有效地提取出圖像的基本特征。與傳統的CNN相比,CapsNet能在目標重疊的情況下識別多個目標,并且對仿射變換具有一定的魯棒性。
目前大部分膠囊網絡在復雜數據集上的表現較差,究其原因即在于膠囊網絡提取特征僅使用了一層卷積,難以有效提取出目標的有效信息,因此可以利用深度神經網絡提取特征。但在深度神經網絡上,僅通過CIFAR10這樣的小數據集很難學習到好的參數,提取到合理的特征。近年來,遷移學習[6]也迎來了業界研究熱潮,這是一種運用已有知識對不同領域問題進行求解的機器學習方法[7-8]。2016年的NIPS會議上,吳恩達指出遷移學習可以在樣本不充足的情況下達到較好的分類識別效果,會在未來的人工智能領域占據著舉足輕重的地位。但針對大量的特征,如何尋找出關鍵特征也至關重要。2018年,卷積注意力機制(ConvolutionalBlockAttentionModule,CBAM)[9]的提出,用于特征優化,在計算機視覺的各個領域均取得了顯著的效果[10-11]。
因此,本文首先使用一個采用ELU激活函數[12]的深層神經網絡提取特征,通過遷移學習將該特征提取網絡在ImageNet數據集上訓練得到的參數遷移到訓練CIFAR10數據集的網絡中,可以充分提取CIFAR10數據集上的特征,然后在特征提取網絡后引入CBAM注意力模塊,提出了一種全新的基于遷移學習的注意力膠囊網絡用于圖像分類。本文提出的方法具有以下優點:
(1)在特征提取部分使用遷移學習,利用從大型數據集ImageNet訓練得到的參數,提取CIFAR10數據集的特征,可以更充分地提取特征。
(2)在遷移學習后面加入注意力機制模塊,可以從已提取的特征中提取出關鍵特征。
(3)使用ELU激活函數,針對ReLU的負數部分進行改進,避免了函數左側輸入為負時,梯度為0的情況。
最后,本文在MNIST、FashionMNIST、SVHN和CIFAR10數據集上進行實驗,結果表明無論在簡單還是復雜數據集上,本文提出的基于遷移學習的注意力膠囊網絡都取得了良好的分類精度。
1膠囊網絡
2遷移學習
遷移學習(TransferLearning)[6]是一種機器學習方法,可以在目標任務的高質量訓練數據較少的情況下,將之前的一些任務中的知識轉移到目標任務中,使得目標任務能夠取得更好的學習效果。一般原始任務數據集有大量的標注數據,而目標領域數據集較小,遷移學習方法主要有基于特征的遷移學習和基于參數的遷移學習[13]。其學習過程如圖3所示。
2.1基于特征的遷移學習
基于特征選擇的遷移學習方法是識別出原始任務與目標任務中共有的特征表示,減少原始任務和目標任務中的差別,并利用這些特征進行知識遷移。首先利用原始任務和目標任務中的共有特征訓練一個分類器,接著用目標領域中的無標簽樣本的特征優化分類器,那些與樣本類別相關度高的特征會在訓練得到的模型中被賦予更高的權重。
2.2基于參數的遷移學習
基于參數的遷移學習是找到原始數據和目標數據的空間模型之間的共同參數或者先驗分布,從而可以通過原始數據的特征進一步處理目標數據,實現知識遷移,在學習原始任務中的每個模型的參數或者先驗分布會共享給目標數據。
3卷積注意力機制
卷積注意力模塊(CBAM)[9]是一種基于前饋卷積神經網絡的注意力模塊。對于給定的特征,CBAM模塊通過通道和空間兩個獨立的維度推測注意力特征,將注意力特征與輸入特征相乘得到輸出特征,實現特征的優化。CBAM可以集成到任意的卷積神經網絡中,其結構圖如圖4所示。
空間注意力模塊如圖6所示。將通道注意力模塊輸出的特征圖作為輸入特征圖,首先基于通道進行最大池化和平均池化操作,然后將這2個結果連接生成一個特征描述符。然后經過一個卷積操作進行降維。再經過sigmoid生成空間注意力特征。最后將該特征和該模塊的輸入特征做乘法得到最終生成的特征。可由如下公式計算得出:
4激活函數
激活函數在神經網絡中引入了非線性,是神經網絡中不可或缺的一部分。如果不使用激活函數,則神經網絡的每一次的輸入都是上一層的線性輸入,這樣的網絡無論有多少層,都只能表示是線性關系,但引入了激活函數后,神經網絡可以擬合各種非線性函數,大大拓展了神經網絡的應用[14]。常用的激活函數有sigmoid激活函數、ReLU激活函數、ELU激活函數等。
對此可做分析概述如下。
(1)sigmoid函數。作為最簡單的激活函數,其數學公式可寫為:
σ(x)=11+e-x,(7)
但在神經網絡進行反向傳播時,sigmoid函數容易導致梯度消失和梯度爆炸。
(2)ReLU激活函數。是目前使用最廣泛的激活函數。當輸入值小于零時,輸出值為零。當輸入值大于等于零時,輸出值等于輸入值。其數學公式可寫為:
f(x)=max(0,x),(8)
ReLU激活函數訓練速度很快,且當輸入為正數時,不會造成梯度爆炸。但當x<0時,梯度為0,這樣導致了負的梯度置零,該神經元不會再被任何數據激活。
(3)ELU激活函數。針對ReLU的負數部分進行的改進,右側線性部分能夠緩解梯度消失,而左側能夠保證在激活函數輸入為負時,梯度不為0。其公式如下所示:
f(x)=x,x>0,α(ex-1),x<0.(9)
5基于遷移學習的注意力膠囊網絡
5.1特征提取網絡
特征提取在圖像分類中起著核心作用,由于CapsNet具有很強的細節解釋能力,因此能夠很好地處理圖像中的重疊問題。但對于復雜數據集,CapsNet反而可能會提取出一些瑣碎的、不合理的特征,從而導致分類精度的下降。例如,CapsNet在手寫數字數據庫(MNIST)上取得了良好的結果,但在CIFAR10數據集[14]上表現較差。
為解決這個問題,本文采用了遷移學習的方法。CIFAR10數據集較小,只有少量的標注數據,在進行訓練時,很難通過這類小型數據集學習到合理的參數,參數的不合理直接導致了提取的特征不合理。因此可以利用大量的高質量標注數據如ImageNet數據集,通過預訓練模型得到合適的參數用于小數據集的訓練,方便提取合適的特征用于分類。
本文采用一個9層的深層網絡用于特征提取,該網絡包括9個卷積層和3個池化層。其結構如圖7所示。由圖7可知,所有卷積層均采用相同大小的3×3的卷積核,設置步長為1,填充1個像素,使用ELU激活函數,這樣使得每一個卷積層都能與前一層保持同樣的大小;池化層利用大小為2×2的矩陣進行最大池化。
CIFAR-10數據集的圖像數據原始大小為32×32,為了方便知識遷移,將ImageNet數據集的圖像數據大小處理成32×32后作為特征提取網絡的輸入。前三次卷積操作有64個大小為3×3卷積核,檢測二維圖像的初級特征;接著進行最大池化后,再使用256個卷積核,進行3次卷積;然后繼續進行最大池化操作和3次512個卷積核的卷積操作提取出合理的特征。
5.2基于遷移學習的注意力膠囊網絡
通過特征提取網絡在ImageNet數據集上的訓練,由此得到了該網絡的參數,并用于CIFAR10數據集中,在提取出合適的特征后,本文通過增加注意力機制模塊來提取關鍵特征,實現高精度分類,稱為基于遷移學習的注意力膠囊網絡,網絡結構如圖8所示。
輸入圖像通過圖8中的特征提取模塊,得到8×8×256的張量作為圖4中的注意力模塊的輸入,輸出仍然為8×8×256的張量,用于提取關鍵特征;接著的初級膠囊層的卷積核大小為3×3,步長為1,使用ELU激活函數,輸出是32個大小為6×6×8膠囊;最后是數字膠囊層,由10個16維的數字膠囊組成,使用動態路由進行分類,每個16維膠囊代表一個特定的圖像類別。
6實驗結果與分析
6.1數據集與評價標準
本文借助CIFAR10數據集[15],驗證了所提出的膠囊網絡的有效性。CIFAR10是一個真實世界物體的小數據集,圖像大小為32×32。與MNIST數據集相比,CIFAR10由真實世界中的目標組成,不僅存在較多的噪聲,而且目標的比例和特征不同,給識別帶來了較大的困難。在實驗中,將學習率設置為0.001,批量大小設置為64。
6.2仿真實驗結果分析
不同改進的膠囊網絡在CIFAR-10數據集上的分類準確率見表1。由Hinton提出的膠囊網絡準確率僅有68.95%,Prem等人[16]提出的膠囊網絡的準確率為68.49%,由Xi等人[17]設計的膠囊網絡分類準確率達到了71.51%。在實驗中,在膠囊網絡中分別引入BAM和CBAM注意力機制,精度分別達到74.52%和75.16%,當改進網絡中的激活函數,分類精度也略有提升。結合圖10中的特征提取網絡改進膠囊網絡,精度提高至77.84%,結合遷移學習,引入IamgeNet數據集上訓練的參數,精度可以達到79.93%。最終結合遷移學習和CBAM注意力機制,本文提出的基于遷移學習的注意力膠囊網絡在CIFAR-10數據集分類精度達到了81.34%。
為了說明本文提出的網絡的泛化能力,還在其他公共數據集(MNIST數據集、FashionMNIST數據集和SVHN數據集)上進行了實驗,結果見表2。
傳統膠囊網絡以及各改進膠囊網絡在簡單的MNIST數據集上均達到了很好的效果。而傳統膠囊網絡在稍復雜的FashionMNIST數據集的分類精度沒有特別理想,為88.19%;引入不同的注意力機制后,網絡精度達到了90.54%和91.76%;通過改進特征提取網絡,分類精度可以達到91.58%,結合遷移學習,精度提高至92.53%;在基于遷移學習的注意力膠囊網絡,分類精度最高達到94.07%。傳統膠囊網絡SVHN數據集上分類精度僅有82.81%,在分別增加BAM注意力機制和CBAM注意力機制后,分類精度大幅提升達到89.69%和91.56%,改進特征提取網絡遷移學習,分類精度可以達到92.91%,在基于遷移學習的注意力膠囊網絡中,分類精度最高達到94.59%。顯然,本文提出的基于遷移學習的注意力膠囊網絡在不同數據集上具有最好的效果。
7結束語
本文提出了一種新的膠囊網絡,即基于遷移學習的注意力膠囊網絡,該網絡充分提取圖像的有效基本特征,并篩選出關鍵特征。本文提取特征時采用遷移學習的方式,利用從大型數據集ImageNet訓練得到的參數,提取CIFAR10數據集的特征,此后引入CBAM注意力機制用于提取關鍵特征。并且所用網絡中的激活函數都采用ELU激活函數,有效地避免神經元壞死。通過對比實驗證明,無論是在簡單數據集MNIST還是復雜數據集FashionMNIST、SVHN和CIFAR10上,論文提出的基于遷移學習的注意力膠囊網絡在分類精度達到了理想的結果。下一步,將針對初級特征提取不充分問題,在提取特征的網絡上進行改進,構建分類精度更高的膠囊網絡。
參考文獻
[1]SABOURS,FROSSTN,HINTONGE.Dynamicroutingbetweencapsules[C]//AdvancesinNeuralInformationProcessingSystems.LongBeachm,California,USA:NeuralInformationProcessingSystemsFoundation,Inc.(NIPS),2017:3856-3866.
[2]王弘中,劉漳輝,郭昆.一種基于混合詞向量的膠囊網絡文本分類方法[J].小型微型計算機系統,2020,41(1):218-224.
[3]王金甲,紀紹男,崔琳,等.基于注意力膠囊網絡的家庭活動識別[J].自動化學報,2019,45(11);2199-2204.
[4]BOUREAUYL,PONCEJ,LECUNY.Atheoreticalanalysisoffeaturepoolinginvisualrecognition[C]//Proceedingsofthe27thInternationalConferenceonMachineLearning(ICML-10).Haifa,Israel:ACM,2010:111-118.
[5]SCHERERD,MLLERA,BEHNKES.Evaluationofpoolingoperationsinconvolutionalarchitecturesforobjectrecognition[M]//DIAMANTARASK,DUCHW,ILIADISLS.Artificialneuralnetworks-Icann2010.LecturenotesinComputerScience.Berlin/Heidelberg:Springer,2010,6354:92-101.
[6]PANSJ,YANGQ.Asurveyontransferlearning[J].IEEETransactionsonKnowledgeandDataEngineering,2010,22(10):1345-1359.
[7]陳炳超,洪佳明,印鑒.基于遷移學習的圖分類[J].小型微型計算機系統,2011,32(12):2379-2382.
[8]洪佳明,陳炳超,印鑒.一種結合半監督Boosting方法的遷移學習算法[J].小型微型計算機系統,2011,32(11):2169-2173.
[9]WOOS,PARKJ,LEEJY,etal.CBAM:Convolutionalblockattentionmodule[C]//EuropeanConferenceonComputerVision.Munich,Germany:dblp,2018:3-19.
[10]盧玲,楊武,王遠倫,等.結合注意力機制的長文本分類方法[J].計算機應用,2018,38(5):1272-1277.
[11]苑威威,彭敦陸,吳少洪,等.自注意力機制支持下的混合推薦算法[J].小型微型計算機系統,2019,40(7):1437-1441.
[12]XUBing,WANGNaiyan,CHENTianqi,etal.EmpiricalEvaluationofRectifiedActivationsinConvolutionalNetwork[J].arXivpreprintarXiv:1505.00853,2015.
[13]莊福振,羅平,何清,等.遷移學習研究進展[J].軟件學報,2015,26(1):26-39.
[14]張濤,楊劍,宋文愛,等.關于改進的激活函數TReLU的研究[J].小型微型計算機系統,2019,40(1):58-63.
[15]ZHANGJunbo,ZHENGYu,QIDekang,etal.Predictingcitywidecrowdflowsusingdeepspatio-temporalresidualnetworks[J].ArtificialIntelligence,2017,259:147-166.
[16]PREMN,ROHAND,STENFANK.Pushingthelimitsofcapsulenetworks[J].Technicalnote,2018.
[17]XIE,BINGS,JINY.Capsulenetworkperformanceoncomplexdata[J].arXivpreprintarXiv:1712.03480,2017.