摘 要:領域自適應將源域上學習到的知識遷移到目標域上,使得在帶標簽數據少的情況下也可以有效地訓練模型。采用偽標簽的領域自適應模型未考慮錯誤偽標簽的影響,并且在決策邊界處樣本的分類準確率較低,針對上述問題提出了基于加權分類損失和核范數的領域自適應模型。該模型使用帶有偽標簽的可信樣本特征與帶有真實標簽的源域樣本特征構建輔助域,在輔助域上設計加權分類損失函數,降低錯誤偽標簽在訓練過程中產生的影響;加入批量核范數最大化損失,提高決策邊界處樣本的分類準確率。在Office31、Office-Home、Image-CLEFDA基準數據集上與之前模型的對比實驗表明,該模型有更高的精確度。
關鍵詞:領域自適應; 加權分類損失; 核范數; 偽標簽
中圖分類號:TP183 文獻標志碼:A
文章編號:1001-3695(2023)06-021-1734-05
doi:10.19734/j.issn.1001-3695.2022.10.0514
Domain adaptation based on weighted classification loss and nuclear-norm
Du Shelin, Huang Binghe, Li Rongpeng, Song Xueli, Xiao Yuzhu
(School of Science, Chang’an University, Xi’an 710064, China)
Abstract:Domain adaptation transfers the knowledge learned from the source domain to the target domain, so that the model can be effectively trained in the case of less labeled data. The domain adaptation models using pseudo-labels do not consider the influence of 1 pseudo-labels, and the classification accuracy of samples at the decision boundary is low. For the above problems, this paper proposed domain adaptation model based on weighted classification loss and nuclear-norm. The model used confident sample features and their pseudo-labels, and constructed an auxiliary domain with the source domain sample features with real labels. It designed a weighted classification loss function on the auxiliary domain reduced the influence of 1 labels in the training process. Batch nuclear-norm maximization loss improved the accuracy of sample pseudo-labels at the decision boundary. The comparison experiments with previous models on the benchmark datasets of Office31, Office-Home and Image-CLEFDA illustrate that this method has higher accuracy.
Key words:domain adaptation; weighted classification loss; nuclear-norm; pseudo-label
0 引言
在大數據時代,隨著數據規模和計算資源的快速增長,數據變得容易獲取,但是大多數據無標簽,無法直接用于模型訓練,進行人工標注耗時且費力。然而,深度學習和機器學習的發展需要大量帶標簽的數據,帶標簽數據的短缺導致機器學習和深度學習在某些方面不能充分發展。遷移學習可以有效解決上述問題,遷移學習將在源域的帶標簽數據上學習到的知識遷移到目標域中[1]。按照特征的屬性對遷移學習進行分類,可以分為異構遷移學習和同構遷移學習,領域自適應屬于同構遷移學習,側重于解決特征空間相同、類別空間也相同,但是特征分布不相同的問題[2]。領域自適應的目標為減小源域和目標域的分布差異,從而使用源域上帶標簽數據和目標域上無標簽數據學習目標域上的預測函數來確定目標域上的標簽[3]。深度域自適應方法使用深度學習解決領域自適應問題,與淺層域自適應方法相比,深度域自適應方法通過將神經網絡和領域自適應相結合,使得學習到的特征更具有代表性和可遷移性。深度域自適應方法又可分為基于距離度量和基于對抗的深度域自適應。基于距離度量的深度域自適應將神經網絡提取到的特征映射到一個公共空間中,最小化源域和目標域在此空間中的分布差異[4]。常見的距離度量指標有最大均值差異(maximum mean discrepancy,MMD)[5]、Kullback-Leibler散度[6]、Wasserstein距離[7]等。基于對抗的深度域自適應方法把共享的特征提取器作為生成器,使域分類器無法判別提取的特征來自源域還是目標域,以此來對齊源域和目標域的特征分布,再訓練源域分類器,對特征進行分類。與基于對抗的深度域自適應方法相比,基于距離度量的深度域自適應方法無須對抗性訓練,并且此類方法收斂速度快。
文獻[8]首次將神經網絡應用到領域自適應中,提出了域自適應神經網絡(domain adaptive neural network,DaNN),在特征提取層之后加入自適應層,使用最大均值差異來測量兩個域之間的特征分布差異,并將其加入網絡的損失中進行訓練。由于DaNN太淺,表達能力有限,效果較差,無法很有效地解決領域自適應問題,文獻[9]使用在ImageNet數據集上訓練好的AlexNet[10],固定AlexNet的前7層,在第8層前加入自適應層,取得的效果好于DaNN;Long等人[11]提出了深度自適應網絡(deep adaptation network,DAN),在分類器前加入三個自適應層,采用了表達能力更好的多核最大均值差異(multiple kernel maximum mean discrepancy,MK-MMD)對DDC進行擴充;Long等人[12]提出聯合自適應網絡(joint adaptation network,JAN)對聯合分布進行適配,采用聯合最大均值差異(joint maximum mean discrepancy,JMMD)度量準則,使用在源域上訓練的分類器為目標域上的樣本生成偽標簽,將對數據進行自適應推廣為對類別的自適應。之后許多方法開始使用偽標簽技術提升遷移效果,Xie等人[13]對目標域中的無標簽數據生成偽標簽,通過偽標簽實現跨域語義對齊;Zhang等人[14]通過網絡輸出的置信度選擇目標域的樣本,擴大訓練集并且重新訓練模型;Zhu等人[15]提出深度子領域適應網絡(deep subdomain adaptation network,DSAN),使用帶標簽的源域數據訓練分類器,得到目標域樣本的偽標簽,通過源域的真實標簽和目標域的偽標簽將源域和目標域劃分為若干子領域,使用局部最大均值差異(local maximum mean discrepancy,LMMD)度量準則在不同子域上進行分布對齊;吳蘭等人[16]在DSAN的基礎上設置自適應權重調節機制,減輕異常源類導致的負遷移,提出深度加權子域自適應網絡(deep-weighted "subdomain adaptive network,DWSAN);Deng等人[17]對目標數據的潛在特征進行探索,在目標域上使用聚類學習鑒別特征,同時生成偽標簽,并且把學習鑒別特征和對齊類鑒別特征統一到一個框架中,但是引入聚類學習增加了計算量;Fu等人[18]將源域和目標域映射到子空間中,在子空間中進行分布對齊,然后使用最近鄰原型(nearest class prototype,NCP)和結構化預測(structured prediction,SP)得到目標域的偽標簽,選擇部分具有高置信度偽標簽的樣本加入到下一次的迭代學習中;文獻[19]使用源域數據訓練網絡,預測目標域的偽標簽,將有偽標簽的目標域數據與源域數據混合,通過偽標簽學習和域自適應的正則化項對網絡的參數施加約束,減少可遷移特征的分布差異,用于機械診斷問題;Wang等人[20]使用監督局部保持投影學習域不變子空間,使用聚類分析得到目標域樣本的偽標簽,所得偽標簽的準確率較高;Zhang等人[21]提出的PRPL(pre-trained features and recurrent pseudo-labeling)模型使用EfficientNet[22]對源域和目標域提取特征,選擇具有高置信度的偽標簽與帶標簽的源域數據共同訓練模型,對目標域數據反復生成高置信度的偽標簽;文獻[23]提出循環偽標簽分類模型,加入偽標簽選擇損失函數,選擇高置信的目標域偽標簽,循環迭代生成高置信度的偽標簽,提高了分類的準確率。
以上模型均采用交叉熵損失函數,交叉熵函數屬于熵最小化的方法,熵最小化的方法會傾向于將決策邊界附近的樣本預測為樣本個數比較多的類別,導致錯誤分類;另一方面,以上模型未考慮錯誤的偽標簽在循環迭代訓練過程中產生的影響。對于上述問題,在PRPL模型的基礎上提出基于加權分類損失和核范數的領域自適應模型(domain adaptation model based on weighted classification loss and nuclear-norm, WCLN)。該模型選擇目標域中的可信樣本,使用帶有高置信度偽標簽的可信樣本特征與帶有真實標簽的源域樣本特征構建輔助域,用輔助域代替源域進行訓練,循環上述步驟直至模型收斂,設計輔助域上的加權分類損失函數,降低錯誤偽標簽在訓練中的影響;加入批量核范數最大化損失函數,提高決策邊界處樣本的分類準確率;在各個基準數據集上實現了比現有方法更高的精度。
1 基于加權分類損失和核范數的領域自適應模型
1.1 問題定義
1.2 WCLN模型結構
1.3 最大均值差異
1.4 加權分類損失
1.5 批量核范數最大化損失
1.6 損失函數
2 實驗
2.1 數據集介紹
在Office31、Office-Home、Image-CLEFDA三個領域自適應基準數據集上進行大量實驗。
Office31是領域自適應中使用最多的基準數據集,如圖2(a)所示,由4 110張圖片組成,共有31個類別,這些圖片來自亞馬遜網站上下載的Amazon(A)、網絡攝像頭拍攝的WebCam(W)和Dslr(D)。在此數據集上構建6個遷移任務A→W、A→D、W→A、W→D、D→A、D→W。
Image-CLEFDA如圖2(b)所示,由來自Caltech-256(C)、ImageNet ILSVRC 2012(I)和Pascal VOC 2012(P)的1 800張圖片組成,每個域包含12個類別,每個類別有50張圖片。在此數據集上構建6個遷移任務C→I、C→P、I→C、I→P、P→C、P→I。
Office-Home更加復雜,如圖2(c)所示,由15 500張圖片組成,共有65個類別,這些圖片來自Art(Ar)、Clipart(Cl)、Product(Pr)和Real World(Rw)四個領域。在此數據集上構建12個遷移任務Ar→Cl、Ar→Pr、Ar→Rw、Cl→Ar、Cl→Pr、Cl→Rw、Pr→Ar、Pr→Cl、Pr→Rw、Rw→Ar、Rw→Cl、Rw→Pr。
2.2 實驗設置
對于以上數據集,將WCLN與其他模型進行對比,采用的對比模型為joint distribution adaptation(JDA)[25]、deep adaptation network(DAN)[11]、deep correlation alignment(DCORAL)[26]、pre-trained features and recurrent pseudo-labeling(PRPL)[21]、deep subdomain adaptation network(DSAN)[15],其中JDA屬于傳統模型,DAN、DCORAL、PRPL、DSAN是本文使用PyTorch框架重新編寫的模型。對于WCLN,設置循環的次數m=3,置信度閾值pt=[0.5,0.8,0.9],用WCLN(t=0)表示不設置偽標簽置信度閾值時的實驗,用WCLN(t=1)、WCLN(t=2)、WCLN(t=3)分別表示設置偽標簽置信度閾值為0.5、0.8、0.9時的實驗,加權分類損失中的權重α=[0.29,0.41,0.47],模型總的損失函數中的β=0.1,對于所有的任務,先對源域樣本和目標域樣本訓練9次后再選擇可信樣本,同時每個循環訓練9次,所以設置epoch為36。優化器設置動量為0.9的小批量梯度下降,學習率lr為0.001。對于所有模型的訓練,batch_size設置為64,不同模型設置不同的epoch。實驗使用的電腦顯卡為NVIDIA GeForce RTX 2060,CPU為Intel CoreTM i7-10870H CPU@2.20 GHz。
2.3 實驗結果及分析
在Office31、Image-CLEFDA、Office-Home數據集上基于EfficientNetB7的所有遷移任務的分類結果如表1~3所示,在該遷移任務上最高分類準確率由粗體表示。與傳統方法JDA進行比較,可以看出深度域自適應算法的分類準確率有大幅度提升。WCLN在三個數據集上的所有遷移任務都實現了最高的分類準確率。本文重新實現的DAN、DCORAL、DSAN方法比原文有更高的分類準確率。對于Office31數據集,在D→W遷移任務上達到了100%的分類準確率,循環迭代訓練對準確率較高的任務的提升并不明顯,對于D→A和W→A兩個任務的提升較明顯,平均提高2.2%。第二和第三次循環整體上比第一次循環提升的少。對于Image-CLEFDA數據集,在C→P遷移任務上的提升效果最明顯,提高了3.8%。同樣,第二和三次循環整體上比第一次循環提升的少。對于Office-Home數據集,WCLN模型提升效果最明顯,在12個遷移任務上平均提高了2.4%。Office31和Image-CLEFDA數據集較小,并且大部分任務都已經有較高的分類準確率,所以提升效果不明顯。而在Office-Home數據集上的結果表明,設計加權分類損失函數和加入批量核范數最大化損失可以顯著提升模型在大型數據集上的遷移效果。
分別使用AlexNet、ResNet50、VGG16、EfficientNetB7對Office31數據集中的Dslr提取特征,并繪制t-SNE圖,結果如圖3所示,不同顏色表示不同的類別。可以看出,使用EfficientNetB7提取出的相同類別樣本的特征更加緊湊,不同類別的樣本特征之間的距離更大,效果更好。因此,本文選擇使用EfficientNetB7作為提取主干特征提取網絡。
在A→W遷移任務上,如圖4(a)所示,還未進行選擇可信樣本特征時,分別使用加權分類損失函數和交叉熵損失函數得到的分類準確率相等;選擇了可信樣本特征之后,隨著訓練次數的增加,使用加權分類損失得到的準確率明顯高于使用交叉熵損失函數的準確率,并且逐漸收斂。如圖4(b)所示,在該遷移任務上,使用批量核范數最大化損失以后的準確率明顯高于未使用批量核范數最大化損失,隨著訓練次數的增加,準確率逐漸收斂且收斂的速度更快。所以,加權分類損失和批量核范數最大化損失都提高了分類準確率,提升了遷移的效果。
除了JDA以外的所有模型在訓練時epoch均設置為36,在Office31數據集上6個遷移任務的訓練時間對比如圖5所示,DSAN(改)和DAN(改)是本文在PyTorch框架中重新實現的,DSAN(原)和DAN(原)是原文中開源的模型代碼。無論是DSAN還是DAN,在Office31數據集中每個遷移任務上改寫后的訓練時間遠小于原文中的訓練時間,DSAN在6個遷移任務上平均降低89%,DAN在6個遷移任務上平均降低82%。WCLN在每個遷移任務上的訓練時間雖然都不是最少的,但與DSAN(改)、DAN(改)、JDA的訓練時間相差不大,同時WCLN在各遷移任務上達到了最高的分類準確率,表明WCLN同時含有較高的時間效率和分類準確率。
3 結束語
針對領域自適應中未充分考慮錯誤偽標簽在訓練過程中產生的影響和決策邊界處預測準確率較低的問題,提出WCLN模型。該模型首先選擇出目標域中具有高置信度偽標簽的可信樣本特征,與帶有真實標簽的源域樣本特征構建輔助域,再用輔助域代替源域進行訓練,循環上述步驟直至模型收斂。設計輔助域上的加權分類損失函數,降低錯誤偽標簽產生的影響。加入批量核范數最大化損失,提高決策邊界處樣本分類的準確率。在Office31、Image-CLEFDA、Office-Home基準數據集上,WCLN模型有較高的時間效率,比JDA、DAN、DCORAL、DSAN、PRPL有更高的分類準確率。另外注意到只將置信度作為偽標簽選擇的唯一依據不太充分,后續工作準備結合不確定性等相關知識改進偽標簽的選擇依據,以此獲得更加準確的偽標簽,提升遷移效果,再嘗試能否將WCLN運用到圖像分割或者目標檢測領域中以提高模型的實用價值。
參考文獻:
[1]Pan S J, Yang Qiang. A survey on transfer learning[J].IEEE Trans on Knowledge and Data Engineering,2009,22(10):1345-1359.
[2]李晶晶,孟利超,張可,等.領域自適應研究綜述[J].計算機工程,2021,47(6):1-13.(Li Jingjing, Meng Lichao, Zhang Ke, et al. A survey on domain adaptation[J].Computer Engineering,2021,47(6):1-13.)
[3]劉建偉,孫正康,羅雄麟.域自適應學習研究進展[J].自動化學報,2014,40(8):1576-1600.(Liu Jianwei, Sun Zhengkang, Luo Xionglin. Review and research development on domain adaptation learning[J].Acta Automatica Sinica,2014,40(8):1576-1600.)
[4]田青,朱雅喃,馬闖.基于深度學習的域適應方法綜述[J].數據采集與處理,2022,37(3):512-541.(Tian Qing, Zhu Yanan, Ma Chuang. Review on domain adaptation methods based on deep lear-ning[J].Journal of Data Acquisition and Processing,2022,37(3):512-541.)
[5]Pan S J, Tsang I W, Kwok J T, et al. Domain adaptation via transfer component analysis[J].IEEE Trans on Neural Networks,2010,22(2):199-210.
[6]Zhuang Fuzhen, Cheng Xiaohu, Luo Ping, et al. Supervised representation learning: transfer learning with deep autoencoders[C]//Proc of the 24th International Joint Conference on Artificial Intelligence.2015:4119-4125.
[7]Xu Pengcheng, Gurram P, Whipps G, et al. Wasserstein distance based domain adaptation for object detection[EB/OJ].(2019-09-18).https://arxiv.org/abs/1909/08675.
[8]Ghifary M, Kleijn W B, Zhang Mengjie. Domain adaptive neural networks for object recognition[C]//Proc of Pacific Rim International Conference on Artificial Intelligence.Berlin.Springer,2014:898-904.
[9]Tzeng E, Hoffman J, Zhang Ning, et al. Deep domain confusion: maximizing for domain invariance[EB/OL].(2014-12-10).https://arxiv.org/abs/1412.3474.
[10]Krizhevsky A, Sutskever I, Hinton G E. ImageNet classification with deep convolutional neural networks[J].Communications of the ACM,2017,60(6):84-90.
[11]Long Mingsheng, Cao Yue, Cao Zhangjie, et al. Transferable representation learning with deep adaptation networks[J].IEEE Trans on Pattern Analysis and Machine Intelligence,2019,41(12):3071-3085.
[12]Long Mingsheng, Zhu Han, Wang Jianming, et al. Deep transfer lear-ning with joint adaptation networks[C]//Proc of the 34th International Conference on Machine Learning.[S.l.]:JMLR.org,2017:2208-2217.
[13]Xie Shaoan, Zheng Zibin, Chen Liang, et al. Learning semantic representations for unsupervised domain adaptation[C]//Proc of the 35th International Conference on Machine Learning.2018:5423-5432.
[14]Zhang Weichen, Ouyang Wanli, Li Wen, et al. Collaborative and adversarial network for unsupervised domain adaptation[C]//Proc of IEEE/CVF Conference on Computer Vision and Pattern Recognition.Piscataway,NJ:IEEE Press,2018:3801-3809.
[15]Zhu Yongchun, Zhuang Fuzhen, Wang Jindong, et al. Deep subdomain adaptation network for image classification[J].IEEE Trans on Neural Networks and Learning Systems,2020,32(4):1713-1722.
[16]吳蘭,李崇陽.深度加權子域自適應網絡[J].鄭州大學學報:理學版,2022,54(1):54-61.(Wu Lan, Li Chongyang. Deep-weight subdomain adaptive network[J].Journal of Zhengzhou University:Natural Science Edition,2022,54(1):54-61.)
[17]Deng Wanxia, Liao Qing, Zhao Lingjun, et al. Joint clustering and discriminative feature alignment for unsupervised domain adaptation[J].IEEE Trans on Image Processing,2021,30:7842-7855.
[18]Fu Tingting, Li Ying. Unsupervised domain adaptation based on pseudo-label confidence[J].IEEE Access,2021,9:87049-87057.
[19]侯鑫燁,董增壽,劉鑫.基于偽標簽的弱監督遷移學習模型[J].機床與液壓,2021,49(24):185-189.(Hou Xinye, Dong Zengshou, Liu Xin. Weak supervised transfer learning model based on pseudo-label[J].Machine Tool amp; Hydraulics,2021,49(24):185-189.)
[20]Wang Qian, Breckon T. Unsupervised domain adaptation via structured prediction based selective pseudo-labeling[C]//Proc of AAAI Conference on Artificial Intelligence.2020:6243-6250.
[21]Zhang Youshan, Davison B D. Efficient pre-trained features and recurrent pseudo-labeling in unsupervised domain adaptation[C]//Proc of IEEE/CVF Conference on Computer Vision and Pattern Recognition.2021:2719-2728.
[22]Tan Mingxing, Le Q. EfficientNet: rethinking model scaling for con-volutional neural networks[C]//Proc of the 36th International Confe-rence on Machine Learning.2019:6105-6114.
[23]楊國慶,郭本華,錢淑渠,等.基于偽標簽的無監督領域自適應分類方法[J].計算機應用研究,2022,39(5):1357-1361.(Yang Guoqing, Guo Benhua, Qian Shuqu, et al. Pseudo label based unsupervised domain adaptation classification method[J].Application Research of Computers,2022,39(5):1357-1361.)
[24]Cui Shuhao, Wang Shuhui, Zhuo Junbao, et al. Towards discri-minability and diversity: batch nuclear-norm maximization under label insufficient situations[C]//Proc of IEEE/CVF Conference on Computer Vision and Pattern Recognition.Piscataway,NJ:IEEE Press,2020:3940-3949.
[25]Long Mingsheng, Wang Jianmin, Ding Guiguang, et al. Transfer feature learning with joint distribution adaptation[C]//Proc of IEEE International Conference on Computer Vision.Piscataway,NJ:IEEE Press,2013:2200-2207.
[26]Sun Bochen, Saenko K. Deep CORAL: correlation alignment for deep domain adaptation[C]//Proc of European Conference on Computer Vision.Cham:Springer,2016:443-450.