






摘要:小樣本學習是圖像分類任務中的一個重要挑戰,能夠有效解決因數據量較少而產生的模型準確率降低的問題。針對小樣本學習難以準確獲取類內共有特征的問題,提出一種基于類注意力的原型網絡改進方法。利用掩膜圖像進行數據預處理和圖像增強,以提高原始數據質量;引入注意力機制,選擇性地關注特征圖中的重要信息,以增強特征提取能力;設計類注意力模塊,提取具有注意力信息的類別原型。實驗結果表明,在miniImageNet數據集上,該方法的分類準確率在基線基礎上提高了2%,驗證了其有效性。
關鍵詞:原型網絡;小樣本學習;數據增強;類注意力;圖像分類
中圖分類號:TP183""""""""""""文獻標志碼:A """""""""文章編號:1674-2605(2025)01-0009-07
DOI:10.3969/j.issn.1674-2605.2025.01.009"""nbsp;""""""""""""""""開放獲取
Improvement Method of Prototype Network Based on Class Attention
CAO Zenghui CHEN Hao CAO"Yahui
(1.Guangdong University of Technology, Guangzhou 510000, China
2.Zhengzhou Vocational College of Industrial Safety,"Zhengzhou 450000, China)
Abstract:"Small sample learning is an important challenge in image classification tasks, which can effectively solve the problem of reduced model accuracy due to limited data volume. A prototype network improvement method based on class attention is proposed to address the problem of difficulty in accurately obtaining common features within classes in small sample learning. Using mask images for data preprocessing and image enhancement to improve the quality of raw data; Introducing attention mechanism to selectively focus on important information in feature maps to enhance feature extraction capability; Design a class attention module to extract class prototypes with attention information. The experimental results show that on the miniImageNet dataset, the classification accuracy of this method has improved by 2% compared to the baseline, verifying its effectiveness.
Keywords:"prototype network; small sample learning; data enhancement; class attention; image classification
0 引言
在計算機視覺領域,圖像分類是一個重要且具有挑戰性的研究方向。傳統的圖像分類方法,如K近鄰算法、決策樹、隨機森林等,在小樣本場景下泛化能力和準確率有限。而小樣本學習在模型訓練階段僅用少量的標簽樣本即可完成分類任務,解決了因樣本數量較少而導致的模型準確率下降的問題。然而,小樣本學習存在泛化能力不足、過擬合、類別不平衡等問
題。為此,學者們提出了一系列的解決方案。其中,原型網絡[1]作為一種有效的模型框架被廣泛研究和應用。
原型網絡通過學習類別原型的特征,求取各個類別原型的表示,通過樣本與類別原型之間的距離進行分類,初步解決了類別不平衡的問題,但仍然存在因樣本數量較少而導致的難以準確獲取類內共有特征的問題。文獻[2]通過對訓練樣本的特征進行收縮和擴
展,生成額外的樣本,提高了模型的泛化能力。文獻[3]通過在特征空間進行隨機變換和插值操作,生成多樣化的樣本,幫助模型更好地學習特征。文獻[4]結合半監督學習與數據增強,通過弱增強生成偽標簽,強增強優化模型的一致性。以上文獻利用不同的圖像增強方法來增加樣本數量,但簡單的圖像變換無法有效增加樣本的多樣性。
針對上述現狀,本文提出一種基于類注意力的原型網絡改進方法。采用掩膜圖像進行數據預處理,增強圖像的質量和信息,改善小樣本數據質量;引入注意力機制區分無關特征和相關特征;設計類注意力模塊,提取具有注意力信息的類別原型表示,從而提高原型網絡在小樣本學習中的分類性能和泛化能力。
1 相關工作
1.1 原型網絡
原型網絡是一種基于距離度量的分類器[5],其先通過學習每個類別的原型向量來表示不同類別之間的關系,再通過計算查詢集樣本的特征向量與支持集每個類別原型向量之間的歐氏距離進行分類[6]。傳統的類別原型向量通常由每個類別所有樣本的特征向量進行均值計算得到。
1.2 數據增強方法
數據增強通過對訓練數據進行變換和擴充,增加數據的多樣性和數量,從而改善模型的泛化能力和魯棒性。常用的數據增強方法包括平移、旋轉、縮放、翻轉等幾何變換[8],以及亮度、對比度、色彩等顏色變換[9]。數據增強不僅可通過對原始圖像進行隨機變換來生成更多的訓練數據,還可通過剪切、填充、仿射等操作,改變原始圖像的形狀和結構。
近年來,數據增強技術在深度學習領域取得了較大進展。文獻[10]提出一種RandAugment數據增強方法,通過一系列的隨機變換來擴充訓練數據集;在ImageNet數據集上,模型的準確率在基線基礎上提升了1.3%。文獻[11]提出一種Mixup數據增強方法,通過在訓練樣本之間進行線性插值來生成新的樣本,有效地增加了樣本的多樣性。
1.3 注意力機制
注意力機制是指在神經網絡中,通過對輸入數據的不同部分進行加權處理,使網絡更加關注有用的信息,廣泛應用于自然語言處理、計算機視覺、語音識別等領域[12]。文獻[13]提出一種用于深度神經網絡的注意力機制,可自適應地調整輸入數據的通道權重,從而提高模型性能。文獻[14]提出一種高效通道注意力(efficient channel attention, ECA)模塊,利用自適應卷積核計算每個通道的權重,避免了傳統通道注意力機制因采用全局平均池化操作而導致的信息損失。文獻[15]提出一種基于空間注意力和通道注意力機制的網絡模塊,利用一組卷積核來學習每個空間位置的權重,并結合通道注意力機制來提高特征圖的表達能力。文獻[16]提出一種Non-local注意力機制,利用所有位置的特征信息計算每個位置的權重,以實現不同空間位置特征的加權,模型準確率在基線基礎上提高了2.3%。
在小樣本場景下,文獻[17]引入自適應注意力機制,根據樣本的重要性動態調整模型的注意力,提高了模型對關鍵樣本的學習能力。文獻[18]設計了元權重生成器和空間注意力生成器結構,并將分類預測得分改為對稱形式,以提高模型的泛化能力。文獻[19]通過引入多級注意力機制、特征金字塔結構、細粒度的注意力加權和端到端的訓練策略,有效改進了小樣本學習任務中的特征提取和分類性能,使模型能夠更好地適應小樣本的學習任務。
2 本文方法
2.1 訓練策略
2.2 原型網絡改進模型
在圖2的網絡模型中,將支持集圖像和查詢集圖像輸入同一特征提取模塊,獲取圖像的特征向量。支持集特征向量通過類注意力模塊獲取關注類內共同信息的類原型向量,通過計算查詢集樣本的特征向量與每個類原型向量的歐氏距離進行分類。
2.2.1 數據增強模塊
數據增強技術在小樣本學習中被廣泛采用[21]。由于數據集樣本具有主體位置不定、大小不等、背景復雜等特點,本文采用掩膜圖像對支持集圖像進行隨機區域掩膜,提升原型網絡對局部信息的補全,以及不完全信息圖像的識別能力。掩膜效果圖如圖3所示。
掩膜圖像方法獨立于參數學習過程,因此可以嵌入到任何基于卷積神經網絡(convolutional neural networks, CNN)的識別模型中。
2.2.2 特征提取模塊
將數據增強后的支持集圖像和查詢集圖像一起輸入到特征提取模塊,將所有支持集中的D維向量數據映射到新的Z維特征空間。特征提取模塊的特征提取器采用Vgg16模型作為主干網絡,并引入了注意力機制,以重點關注提取圖像中的重要信息。
2.2.3 類注意力模塊
類注意力模塊將支持集圖像進行類注意力信息的提取,得到帶有權值的類別原型表示。本文提出的類注意力模塊主要包括Extract和Interaction"2個模塊,如圖4所示。
Extract模塊用于壓縮、提取圖像數據。經過編碼后的類內K個C×H×W維度的特征向量,通過全局平均池化壓縮為K個C通道、1×1維的特征圖,即將每個樣本、每個通道內H×W維的圖像轉化為一個數字表示,得到K×C個類別內所有樣本的權值。提取圖像數據的計算公式為
2.2.4 距離度量模塊
距離度量模塊基于度量的方式來計算查詢集樣本的特征向量與支持集每個類別原型向量之間的距離,再轉化為相似性度量,從而判斷樣本類別。
3 實驗與評估
3.1 數據集
本實驗數據集采用miniImageNet,其包含60"000幅圖像,分為100個類別。采用文獻[20]的數據集劃分方式將訓練集、驗證集、測試集分別劃分為64、16、20個類別,同時將輸入圖像處理為84×84像素。
3.2 實驗環境
在Ubuntu操作系統上,采用開源深度學習框架PyTorch搭建模型,利用GPU進行實驗計算,以提高模型的迭代速度。為保證實驗的嚴謹性,設置固定的隨機順序來保證每次對比實驗抽取的樣本一致。采用Vgg16模型作為主干網絡進行訓練,并確保每次實驗僅有驗證項發生改變。實驗環境如表1所示,實驗參數如表2所示。
3.3 評價指標
本實驗采用5-way"1-shot和5-way"5-shot的驗證模式,即在支持集中每次隨機選擇5個支持集類別,每個支持集類別分別有1個樣本和5個樣本進行實驗。利用查詢集中樣本的準確率來評估模型性能。準確率的計算公式為
3.4 實驗結果
3.4.1 "數據增強方法驗證實驗
選取翻轉、旋轉、隨機裁剪等不同的數據增強方法進行驗證實驗。其中,RandomCrop方法根據設置的參數隨機裁剪原始圖像;RandomHorizontalFlip、RandomVerticalFlip方法水平、垂直翻轉原始圖像;ColorJitter方法隨機修改原始圖像的亮度、對比度和飽和度;RandomRotation方法隨機角度旋轉原始圖像。實驗結果如表3所示。
由表3可以看出:RandomCrop方法的準確率在基線基礎上下降約8%,而RandomHorizontalFlip、RandomVerticalFlip、ColorJitter、RandomRotation、本文方法的準確率在基線基礎上分別提高了0.97%、0.51%、0.82%、0.05%、1.58%,本文方法的準確率提高最為顯著,表明本文數據增強方法有效。
3.4.2 "小樣本學習方法對比實驗
將匹配網絡(matching networks, MN)、關系網絡(relation networks,"RN)、記憶匹配網絡(memory matching networks,"MMN)、注意力吸引網絡(attention attractor networks,"AAN)、模型無關的元學習(model--agnostic meta-learning,"MAML)、Reptile、文獻[27]、文獻[28]、Prototypical network等9種經典的小樣本學習方法與本文方法進行對比實驗,結果如表4所示。
由表4可以看出:本文方法在5-way 1-shot任務上取得了53.42%的準確率,與其他方法相比處于較高水平;在5-way 5-shot任務上則取得了70.33%的準確率,優于表中所有對比方法,說明本文方法在少量樣本的場景下具有更出色的泛化能力。
4 結論
本文受原型網絡和注意力機制的啟發,利用數據增強方法增加樣本的多樣性,引入注意力機制提升網絡特的征提取能力,利用類注意力模塊改進原型網絡,解決小樣本學習因樣本多樣性不足導致的類內共有特征難以準確獲取的問題。實驗結果表明,數據增強方法能夠較好地增加數據樣本,提升模型對不同樣本的辨識性;類注意力機制能較好提取類內信息,更好地表示類別原型。
?The author(s) 2024. This is an open access article under the CC BY-NC-ND 4.0 License (https://creativecommons.org/licenses/ by-nc-nd/4.0/)
參考文獻
[1] 趙凱琳,靳小龍,王元卓.小樣本學習研究綜述[J].軟件學報,nbsp;2021,32(2):349-369.
[2] HARIHARAN B, GIRSHICK"R."Low-shot visual recognition by shrinking and hallucinating features[J]."IEEE Transactions on Pattern Analysis and Machine Intelligence, 2017,39(8): 1653-1667.
[3] DEVRIES T, TAYLOR"G W. Dataset augmentation in feature space"[J]. arXiv preprint arXiv:1702.05538, 2017.
[4] CUBUK E D, ZOPH B, MANE"D, et al. Autoaugment: Learning augmentation policies from data[J]. arXiv preprint arXiv:1805."09501, 2018.
[5] 王圣杰,王鐸,梁秋金,等.小樣本學習綜述[J].空間控制技術與應用,2023,49(5):1-10.
[6] 陳良臣,傅德印.面向小樣本數據的機器學習方法研究綜述[J].計算機工程,2022,48(11):1-13.
[7] SNELL J, SWERSKY K, ZEMEL R S."Prototypical networks for few-shot learning[J]. Advances in Neural Information pro-cessing Systems, 2017:30.
[8] SIMARD P Y, STEINKRAUS D, PLATT J C. Best practices for convolutional neural networks applied to visual document analysis[C]//7th International Conference on Document Anal-ysis and Recognition (ICDAR)."Edinburgh, UK: IEEE, 2003.
[9] KRIZHEVSKY A, SUTSKEVER I, HINTON"G."ImageNet classification with deep convolutional neural networks[J]."Communications of the ACM, 2017,60(6):84-90.
[10] CUBUK E D, ZOPH B, SHLENS"J, et al. Randaugment: Practical automated data augmentation with a reduced search space[C]//Proceedings of the IEEE/CVF Conference on Com-puter Vision and Pattern Recognition Workshops,"2020:702-703.
[11] ZHANG H, CISSE M, DAUPHIN Y N,"et al."Mixup: Beyond Empirical Risk Minimization[J]. arXiv preprint arXiv:1710. 09412, 2017.
[12] 彭云聰,秦小林,張力戈,等.面向圖像分類的小樣本學習算法綜述[J].計算機科學,2022,49(5):1-9.
[13] HU J, SHEN L, SUN G."Squeeze-and-Excitation networks[C]//"Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition,"2018:7132-7141.
[14] WANG Q, WU B, ZHU"P,"et al."ECA-Net: Efficient channel attention for deep convolutional neural networks[C]// Proceed-ings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition,"2020:11534-11542.
[15] WOO"S, PARK"J, LEE"J Y, et al. Cbam: Convolutional block attention module[C]//Proceedings of the European Conference on Computer Vision (ECCV),"2018:3-19.
[16]"WANG"X", GIRSHICK R, GUPTA A, et al. Non-local neural networks[C]//Proceedings of the IEEE Conference on Com-puter Vision and Pattern Recognition,"2018:7794-7803.
[17] XING C, ROSTAMZADEH N, ORESHKIN B N,"et al."Adap-tive cross-modal few-shot learning[C]. Advances in Neural In-formation Processing Systems, 2019.
[18] JIANG Z, KANG B, ZHOU K, et al. Few-shot classification via adaptive attention[J]. arXiv preprint arXiv:2008.02465, 2020.
[19] 汪榮貴,韓夢雅,楊娟,等.多級注意力特征網絡的小樣本學習[J].電子與信息學報,2020,42(3):772-778.
[20] VINYALS O, BLUNDELL C, LILLICRAP T,"et al."Matching networks for one shot learning[J]."Advances in Neural Infor-mation Processing Systems, 2016:29.
[21] LI B, HOU Y, CHE"W."Data augmentation approaches in natural language processing: A survey[J]."AI Open, 2022,3:71-90.
[22] SUNG F, YANG Y, ZHANG L,"et al."Learning to compare: relation network for few-shot learning[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern"Recogni-tion,"2018:1199-1208.
[23] CAI Q, PAN Y W, YAO T,"et al. Memory matching net-works"for"one-shot image recognition[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recogni-tion,"2018:4080-4088.
[24] REN M, LIAO R, FETAYA"E,"et al."Incremental Few-Shot Learning with Attention Attractor Networks[C]. Advances in Neural Information Processing Systems, 2019.
[25] FINN C, ABBEEL P, LEVINE S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//International Conference on Machine Learning. PMLR, 2017:1126-1135.
[26] NICHOL A, SCHULMAN J."Reptile: A"scalable metalearning algorithm[J]. arXiv preprint arXiv:1803.02999, 2018,2(3):4.
[27] RAVI S, LAROCHELLE H. Optimization as a model for few--shot learning[C]//International Conference on Learning Repre--sentations,"2017.
[28] YE H J, CHAO W L."How to train your"MAML to excel in few-shot classification[J]. arXiv preprint arXiv:2106.16245, 2021.
作者簡介:
曹增輝,男,1997年生,碩士研究生,主要研究方向:圖像處理和小樣本圖像分類。E-mail:"czh258biu@163.com
陳浩,男,2000年生,碩士研究生,主要研究方向:人工智能和原型網絡。E-mail:"chenhao_gd@163.com
曹雅慧,女,2003年生,專科,主要研究方向:人工智能。E-mail:"15103814269@163.com