丁英姿 丁香乾 郭保琪



摘 要:針對糖尿病視網膜病變分級檢測中標定樣本少、多目標檢測的問題,提出了一種基于改進型GoogLeNet的弱監督目標檢測網絡。首先,對GoogLeNet網絡進行改進,去掉最后一個全連接層并保留檢測目標的位置信息,添加全局最大池化層,以sigmoid交叉熵作為訓練的目標函數以獲得帶有多種特征位置信息的特征圖;然后,基于弱監督方法僅使用類別標簽對網絡進行訓練;其次,設計一種連通區域算法來計算特征連通區域邊界坐標集合;最后在待測圖片中使用邊界框定位病灶。實驗結果表明,在小樣本條件下,改進模型準確率達到了94.5%,與SSD算法相比,準確率提高了10%。改進模型實現了小樣本條件下端到端的病變識別,同時該模型的高準確率保證了模型在眼底篩查中具有應用價值。
關鍵詞:糖尿病視網膜病變;弱監督;卷積神經網絡;目標檢測網絡;全局最大池化
中圖分類號:?TP391.41
文獻標志碼:A
Application of improved GoogLeNet based on weak supervision in DR detection
DING Yingzi1,2, DING Xiangqian1, GUO Baoqi2*
1.College of Information Science and Engineering, Ocean University of China, Qingdao Shandong 266100, China ;
2.Big Data Joint Laboratory, Qingdao New Star Computer Engineering Center, Qingdao Shandong 266071, China
Abstract:?To handle the issues of small sample size and multi-target detection in the hierarchical detection of diabetic retinopathy, a weakly supervised target detection network based on improved GoogLeNet was proposed. Firstly, the GoogLeNet network was improved, the last fully-connected layer of the network was removed and the position information of the detection target was retained. A global max pooling layer was added, and the sigmoid cross entropy was used as the objective function of training to obtain the feature map with multiple feature position information. Secondly, based on the weak supervision method, only the category label was used to train the network. Thirdly, a connected region algorithm was designed to calculate the boundary coordinate set of feature connected regions. Finally, the boundary box was used to locate the lesion in the image to be tested. Experimental results show that under the small sample condition, the accuracy of the improved model reaches 94%, which is improved by 10% compared with SSD (Single Shot mltibox Detector) algorithm. The improved model realizes end-to-end lesion recognition under small sample condition, and the high accuracy of the model ensures its application value in fundus screening.
Key words:?Diabetic Retinopathy (DR); weak supervision; Convolutional Neural Networks (CNN); target detection network; Global Max Pooling (GMP)
0 引言
糖尿病視網膜病變(Diabetic Retinopathy, DR)是糖尿病嚴重的眼部并發癥,已經逐步發展成為眼部疾病致盲的主要原因。根據國際糖尿病聯合會(International Diabetes Federation, IDF)的報道,2017年全球糖尿病患者已經達到4.25億,與2000年的1.51億相比,增加近2倍。根據衛健委的統計,目前我國糖尿病視網膜病變的患病率為24.7%~37.5%。據統計50%的糖尿病病程在10年左右的患者可能出現該病變,15年以上者達80%。糖尿病病情越重,病程越久,發病的幾率越高。沒有得到診斷的糖尿病患者主要分布在不發達地區,醫療資源的分布不均、重視程度不夠,導致DR診斷不及時,最終導致視力受損、失明等嚴重后果,研究DR的自動診斷系統具有重要的意義。
針對DR自動診斷,傳統做法是采用支持向量機(Support Vector Machine, SVM)進行分類[1-2],生成自動篩查系統,輔助人工檢測。深度學習方法采用卷積神經網絡(Convolutional Neural Network, CNN)進行分類,通常在眼底篩查圖片輸入網絡前需要對其進行預處理,如對圖片進行去除背景、噪聲等,然后使用CNN[3-4]或者基于AlexNet的改進型網絡DrNet[5]進行分類等。采用基于弱監督定位的方法,使用全局平均池化(Global Average Pooling, GAP)改進ResNet(Residual Neural Network)[6]對血管瘤進行檢測,并未取得較好的定位效果。
DR分為非增殖性和增殖性兩種類型,其中非增殖性未生成血管,及時治療能夠有效預防不可逆轉的增殖性病變,本文采用的非增殖性DR分級標準如表1所示。
本文采用GoogLeNet Inception V3作為基礎網絡,對網絡進行改進,在實驗過程中對比發現,全局最大池化(Global Max Pooling, GMP)效果優于GAP,因此本文采用GMP層替換原有網絡的稀疏全連接層,使用Sigmoid交叉熵作為目標函數,獲取帶有多種特征位置信息的特征圖,然后通過連通區域計算,對病灶進行標定。 采用弱監督方式,使用帶有類別標簽的樣本,進一步對網絡進行訓練。在樣本量較小的情況下,基于弱監督定位的改進模型的檢測結果超過Faster R-CNN(Faster Region-CNN)和SSD(Single Shot multibox Detector)網絡。
1 目標檢測算法介紹
1.1 常見的目標檢測算法
目前在深度學習領域流行的目標檢測算法主要分為兩類:
一類是基于候選區域提取的目標檢測算法,主要包括R-CNN(Region-CNN)、Fast R-CNN(Fast Region-CNN)、Faster R-CNN、R-FCN(Region-based Fully Convolutional Network)等。R-CNN的提出奠定了此類方法的基礎,此方法從待測圖像中提取約2000個區域候選框,由于候選區存在大量重疊,使用CNN進行特征提取時會進行大量重復計算。Fast R-CNN[7]首先使用CNN提取圖像特征,然后生成候選區域,避免重復提取特征,從而顯著減少了處理時間。Ren等[8]提出的Faster R-CNN不生成候選區域,使用RPN(Region Proposal Network)結合錨點框對位置進行回歸,提高了訓練速度。針對Fast R-CNN和Faster R-CNN使用多個全連接層成本較高的問題,提出的R-FCN[9]在Fast R-CNN基礎上采用全卷積網絡,使用Position-sensitive score maps解決位置敏感性問題,大大提高了檢測速度。
另一類是基于回歸的目標檢測算法,相對于前一類算法,該類算法的精度略低,但是不需要進行區域提取,計算速度較快。代表性算法有YOLO (You Only Look Once)和SSD。YOLO比Faster R-CNN速度快,但是由于規定圖像尺寸以及使用網格進行目標檢測,只能預測一個類別。SSD算法則是YOLO與Faster R-CNN的結合,提高了速度又保證了準確度。
這兩類算法的核心均為卷積神經網絡CNN,首先由CNN對特征進行提取,解決目標的分類問題,然后由定位網絡解決目標的定位問題,網絡結構復雜、參數量大,在訓練過程中需要使用大量的樣本。
1.2 基于弱監督定位的卷積神經網絡
卷積神經網絡在訓練過程中需要大量的標記圖像,而圖像數據標記需要大量時間與人力,因此成本較高。在醫療領域獲得數據更加困難,以眼底篩查為例,通常一張眼底照片的標記費用需要50元。
近期研究表明,弱監督卷積網絡神經網絡在定位問題上研究取得了不少的進展。Oquab等[10]使用ImageNet數據集對AlexNet網絡進行預訓練,然后將其遷移到目標任務中,在目標任務中使用500個重疊的窗口進行滑動定位,并將結果與其他網絡進行對比。在后續研究工作中,Oquab等[11]對AlexNet網絡進行改進,使用卷積層替換全連接層,通過GMP輸出分類目標的邊緣位置信息。
Zhou等[12-13]進一步證實了卷積神經網絡可以提取特征的位置信息,在研究過程中,使用GAP替換網絡中的全連接層,然后使用“類別激活映射圖” (Class Activation Map, CAM)進行弱監督定位取得了較好的效果。
在周博磊的對比實驗中, GoogLeNet-GAP網絡要優于VGGNet、AlexNet、NIN(Network in Network),如表2所示。
GoogLeNet是2014年ImageNet的ILSVRC14(Large Scale Visual Recognition Challenge 2014)競賽冠軍。GoogLeNet增加了網絡的深度和寬度,使用原有的深度學習架構,會導致計算大大增加,也容易導致過擬合。為解決這一問題,提出了使用稀疏的全連接層替換原架構的全連接層[14],引入1×1卷積進行降維,借鑒NIN[15]中“mlpconv”模塊的設計思想,設計了Inception模塊結構;由于非均勻的稀疏矩陣在現有計算架構下計算效率低下,采用Inception結構能夠將稀疏矩陣聚類為相對密集的子矩陣,能夠有效地降低參數量,節省計算資源,從而提高計算的效率。
Szegedy等[16]對Inception模塊進行了改進,衍生出了多個版本的Inception模塊,其中 Inception V3將二維卷積進行非對稱拆分,拆分成為兩個較小的卷積,即將n×n卷積拆分為1×n卷積和n×1卷積,有效地降低了參數量,可以處理更多、更豐富的空間特征,增加特征的多樣性。本文采用Inception V3網絡結構。
3 基于弱監督的GoogLeNet-GMP
DR分級診斷病理特征的分類與定位采用端到端的設計思想,即以整張圖片作為輸入、輸出。首選使用改進型GoogLeNet提取帶有位置信息的特征分類圖,然后使用連通區域算法,對目標位置進行計算,實現特征的定位。
3.1 網絡結構設計
面向糖尿病視網膜病變分級檢測的病理特征提取與定位網絡GoogLeNet-GMP網絡結構的設計參考了Oquab等[11]的GMP網絡的設計,使用Zhou等[13]的CAM設計思想來加強特征圖中的位置信息。
GoogLeNet-GMP網絡結構如圖2所示,采用GoogLeNet Inception V3作為基礎網絡,在最后一個Inception模塊之后添加GMP層,然后使用Sigmoid全連接層替換原網絡的稀疏全連接層。令 f k(x,y)代表最后一個卷積層的第k個特征圖,經過GMP后,mk如式(1)所示,類別i的得分Si如式(2)。以Sigmoid交叉熵函數為目標函數訓練,逼近多類標簽的概率分布,根據Zhou等[13]的定理生成類別激活熱圖 M i表達式如式(3)。
mk=max x,y { f k(x,y)}
(1)
S i=Sigmoid ( ∑ k wik m k-bi )
(2)
M i=∑ k wik? f k(x,y)
(3)
在實驗過程中,對全局池化GAP和GMP的定位效果進行對比,GMP效果略優于GAP,本文在構建網絡的過程中采用GMP。
GoogLeNet-GMP網絡結構參數如表3所示,其中Inception模塊按結構差異分為5類,在此不進行具體描述。特征圖通過GMP后,輸出大小為1×1×2048。
GoogLeNet網絡最后使用softmax函數進行分類,在DR檢測過程中,需要對多個目標進行檢測,因此使用Sigmoid交叉熵函數進行替換,使用所有樣本的平均Sigmoid交叉熵函數值作為目標函數,對網絡進行訓練,實現多分類預測。
3.2 定位算法
通過使用GMP層生成待測圖片的激活熱圖。由于激活熱圖的連通區域對應了檢測目標的位置,計算連通區域的包圍矩形即可實現對檢測目標的定位,具體算法如下:
假設輸入為GMP層某類的激活特征圖 M i,學習到的特征圖偏移量 b i,特征圖閾值為θ,二分閾值為δ,原圖的大小為Size,輸出為特征圖 M i的所有激活區域的邊界框坐標集合Li。連通區算法描述如下所示:
程序前
M i= M i- b i
m=max{Sigmoid( M i)}
if ?m>θ
M ′i= M i>δ?1:0
f=Size/ M i.Size
N i=labelconnectivity( M ′i)
C i=regioncrops( N i)
fo r? C ij in? C i
Pij=max{ M i(x,y)|(x,y)∈ C ij}
if ?Pij>θ
xmin,ymin,xmax,ymax= C ij.bbox
L i= L i∪{(xmin,yminx,xmax,ymax) f }
程序后
其中 M ′i為二分圖,縮放因子 f , N i為二連通區域標記, C i為連通區域剪裁坐標集合。
4 訓練方法
4.1 實驗工具與預訓練
深度學習框架采用Google的Tensorflow 1.12,Python版本為3.5。
改進模型的基礎框架為GoogLeNet Inception V3版本,使用ImageNet數據集進行預訓練,設置模型的初始參數;從訓練集中隨機抽取一份數據作為驗證集對網絡的超參數進行探索,最終將一些超參數設置為固定值,如初始學習率、學習率衰減因子、衰減周期等;然后使用本項目數據集對GoogLeNet-GMP網絡進行訓練。
4.2 弱監督樣本標定方法
本文方法不需要對每張圖片標定具體病灶的groundbox,只需標記圖片所包含的類別標簽。如原圖中同時存在出血點和硬滲出,那么label的相應位標記為1,其余位置標記為0。本文對出血點、血管、硬滲出、軟滲出、視盤進行了識別,因此label=(1,0,1,0,0)。網絡通過Sigmoid交叉熵逼近label的分布,并產生具有定位信息的激活熱圖。
4.3 數據增廣
訓練數據采用DIARETDB1數據集和Kaggle公開數據集,由于數據集中類別不均衡(Kaggle數據集“正常類”樣本為73%),需要對數據集中的數據進行篩選,提取約700張符合國內DR分級標準Ⅰ~Ⅲ級的眼底照片。數據集輸入模型之前,本文不對圖片進行任何處理,如圖片背景去除、尺寸調整、顏色或亮度調整等,降低了生成模型在實際應用中的復雜度,提升了模型效率。
與其他圖像識別項目的數據集相比,本文的數據集較小,因此在網絡訓練的過程中,為了避免過擬合,采用增廣技術進行數據增廣,主要包括:
1)翻轉:隨機地進行水平、垂直翻轉;
2)剪切:在圖像中隨機選取帶有標簽信息的區域進行剪裁,然后將圖片切片大小擴展到299×299。
5 實驗結果與分析
本文進行實驗的硬件環境:CPU Intel Xeon E5-2630 2個,顯卡GTX1080TI 4個,內存128GB。軟件環境:計算機操作系統為Ubuntu 16.04,訓練平臺為TensorFlow1.12,Python 3.5。
本文數據集使用開源數據庫DIARETDB1和Kaggle數據集約700張圖像,以及眼科專家提供的539張圖像。眼底篩查圖像的采集使用佳能眼底采集設備,圖片的分辨率較高,大小在3000×3000左右。在實驗過程中,使用生成的系統對實驗室同事進行眼底篩查,相機采用的佳能80D。與其他DR分類檢測研究不同[3-5],本文訓練的模型直接使用設備采集的圖片進行分析,不需要對圖片進行預處理,同時可以對多種特征進行提取定位,并在原圖行進行標定,因此生成的眼底篩查系統能夠快速產業化。
實驗過程中,使用Faster R-CNN和SSD算法進行了對比實驗,結果如表4所示。由于使用的樣本數較少,使用Faster R-CNN進行訓練時模型無法收斂;使用SSD算法進行處理的識別準確率為85.4%;GoogLeNet-GMP的準確率更高,達到了94.5%。
使用GoogLeNet-GMP模型對數據進行分析,圖像通過Sigmoid全連接層后,獲取的特征圖如圖3所示。GoogLeNet Inception V3默認圖片的大小為299×299,在分析過程中,由系統對輸入的原始圖像進行裁剪。由于使用Sigmoid函數進行處理,出血點、硬滲出、軟滲出的特征激活閾值分別為0.15、0.1和0.1。在實際分析的過程中,同時對血管、視盤進行了識別,有益于排除干擾因素,提高出血點、硬滲出、軟滲出識別的準確率。圖3中未添加血管、視盤的特征圖,從圖中可以看出,存在大量的出血點以及硬滲出、軟滲出,符合DR分級Ⅲ級的標準,建議進行專家診斷與治療。
最終生成的DR檢測結果如圖4所示,圖中一共發現出血點區域12個,硬滲出區域19個,軟滲出區域4個,符合糖尿病Ⅲ期特征。
6 結語
本文提出了一種基于弱監督的目標檢測網絡GoogLeNet-GMP,結合連通區域算法,實現了糖尿病視網膜病變的分級以及病灶的定位。該算法在最后一層卷積層之后添加GMP層,并替換最后一層稀疏全連接層,使用Sigmoid交叉熵函數替換Softmax函數,能實現多種特征的檢測與定位;最后使用連通區域算法,對特征邊界進行計算,在原圖生成標記框。本文方法直接使用數據集及采集設備數據進行分析,能夠同時對多種特征進行檢測,準確率高于幾種經典算法。
同時,本文方法也存在一定的局限性,樣本數據集分布不均,訓練樣本數據相對較少,在投入產業化之前需要在實際檢測環境中進行檢驗。目前已經生成Web應用,由眼科專家進行試用。后續研究工作主要分為兩個方向:一是加強與專家的合作,增加數據集,在實踐中檢驗完善算法模型;一是結合強化學習算法,增加算法的自優化能力。
參考文獻
[1]?PRIYA R,ARUNA P. Review of automated diagnosis of diabetic retinopathy using the support vector machine [J]. International Journal of Applied Engineering Research, 2011, 1(4):844-862.
[2]?PRIYA R, ARUNA P. SVM and neural network based diagnosis of diabetic retinopathy [J]. International Journal of Computer Applications,2012,41(1):6-12.
[3]?丁蓬莉.基于深度學習的糖尿病性視網膜分析算法研究[D].北京:北京交通大學,2017:22-23. (DING P L. Research of diabetic retinal image analysis algorithms based on deep learning [D]. Beijing: Beijing Jiaotong University, 2017:22-23.)
[4]?蔡石林.基于CNN的糖尿病視網膜病變識別算法研究與實現[D].長沙:湖南大學,2018:22-25. (CAI S L. Research and implementation on diabetic retinopathy recognition algorithm based on CNN [D]. Changsha: Hunan University, 2018:22-25.)
[5]?馬文俊.基于機器學習的糖尿病視網膜病變分級研究[D].哈爾濱:哈爾濱工程大學,2017:28-31. (MA W J. Study on classification of diabetic retinopathy based on machine learning [D]. Harbin: Harbin Engineering University, 2017:28-31.)
[6] ?張德彪.基于深度學習的糖尿病視網膜病變分類和病變檢測方法的研究[D].哈爾濱:哈爾濱工業大學,2017:25-29. (ZHANG D B. Research on diabetic retinopathy classification and lesion detection based on deep learning [D]. Harbin: Harbin Institute of Technology, 2017:25-29.)
[7]?GIRSHICK R. Fast R-CNN [C]// Proceedings of the 2015 IEEE International Conference on Computer Vision. Piscataway, NJ: IEEE, 2015: 1440-1448.
[8]?REN S, HE K, GIRSHICK R, et al. Faster R-CNN: towards real-time object detection with region proposal networks [C]// Proceedings of the 2015 International Conference on Neural Information Processing Systems. Cambridge, MA: MIT Press, 2015:91-99.
[9]??DAI J, LI Y, HE K, et al. R-FCN: object detection via region-based fully convolutional networks [C]// Proceedings of the 30th International Conference on Neural Information Processing Systems. North Miami Beach, FL: Curran Associates Inc., 2016: 379-387.[J]. arXiv E-print, 2016: arXiv:1605.06409.[EB\OL]. [2019-01-22]. https://arxiv.org/pdf/1605.06409v2.pdf.
[10]?OQUAB M, BOTTOUB L, LAPTEV I, et al. Learning and transferring mid-level image representations using convolutional neural networks [C]// Proceedings of the 2014 IEEE Conference on Computer Vision and Pattern Recognition. Washington, DC: IEEE Computer Society, 2014: 1717-1724.
[11]?OQUAB M, BOTTOUB L, LAPTEV I, et al. Is object localization for free? — weakly-supervised learning with convolutional neural networks [C]// Proceedings of the 2015 IEEE Conference on Computer Vision and Pattern Recognition. Washington, DC: IEEE Computer Society, 2015: 685-694.
[12]?ZHOU B, KHOSLA A, LAPEDRIZA A, et al. Object detectors emerge in deep scene CNNs [J]. arXiv E-print, 2015: arXiv:1412.6856.?[EB/OL]. [2019-01-22]. https://arxiv.org/pdf/1412.6856.pdf.
[13]?ZHOU B, KHOSLA A, LAPEDRIZA, OLIVA A, et al. Learning deep features for discriminative localization [C]// Proceedings of the 2016 the IEEE Conference on Computer Vision and Pattern Recognition, Washington, DC: IEEE Computer Society, 2016: 2921-2929.
[14]?SZEGEDY C, LIU W, JIA Y, et al. Going deeper with convolutions [C]// Proceedings of the 2015 IEEE Conference on Computer Vision and Pattern Recognition. Washington, DC: IEEE Computer Society, 2015: 1-9.
[15]?LIN M, CHEN Q, YAN S. Network in network [J]. arXiv E-print, 2014: arXiv:1312.4400.?[EB/OL]. [2019-01-22]. https://arxiv.org/pdf/1312.4400.pdf.
[16]?SZEGEDY C, VANHOUCKE V, LOFFE S, et al. Rethinking the inception architecture for computer vision [C]// Proceedings of the 2016 IEEE Conference on Computer Vision and Pattern Recognition. Washington, DC: IEEE Computer Society, 2016: 2818-2826.