收稿日期:2021-11-21;修回日期:2022-01-21
基金項目:國家自然科學基金資助項目(61562078,71563048);新疆天山青年計劃資助項目(2018Q073);新疆高校研自科項目(XJEDU2021Y037);新疆“天山雪松計劃”青年拔尖人才計劃資助項目
作者簡介:陳小昆(1963-),女,教授,碩導,碩士,主要研究方向為統計機器學習;左航旭(1998-),男(通信作者),碩士研究生,主要研究方向為機器學習(zuohangxu@163.com);廖彬(1986-),男,副教授,博導,博士,主要研究方向為深度學習、數據挖掘及大數據計算模型等;孫瑞娜(1982-),女,博士研究生,主要研究方向為數據挖掘、網絡安全等.
摘 要:為了解決冠心病診斷模型中性能無法滿足臨床應用要求、缺乏可解釋性的問題,提出一種融合XGBoost與SHAP的冠心病預測及其特征分析模型。在對數據集進行特征工程的基礎上,將處理好的數據集輸入XGBoost模型進行訓練,并且對模型進行優化,進一步提高了模型的性能表現;其次,與基于SVM、樸素貝葉斯等六種機器學習模型以及八種主流機器學習模型進行實驗對比,參數優化后的XGBoost模型在準確率、特異度、F1值和AUC值四個指標上分別達到0.994 2、0.997 0、0.994 1和0.999 8,均優于已有模型;最后引入SHAP框架增強模型可解釋性,綜合四種模型特征重要性排序結果,識別出影響冠心病的重要因素,為醫生作出正確的診斷提供決策參考。
關鍵詞:冠心病預測; XGBoost模型; SHAP模型; 特征分析
中圖分類號:TP391"" 文獻標志碼:A
文章編號:1001-3695(2022)06-033-1796-09
doi:10.19734/j.issn.1001-3695.2021.11.0639
Coronary artery disease prediction and feature analysis model based on XGBoost and SHAP
Chen Xiaokun1, Zuo Hangxu1, Liao Bin1, Sun Ruina1,2,3
(1.College of Statistics amp; Data Science, Xinjiang University of Finance amp; Economics, Urumchi 830012, China; 2.Institute of Information Engineering, Chinese Academy of Sciences, Beijing 100093, China; 3.School of Networks Security, University of Chinese Academy of Sciences, Beijing 100093, China)
Abstract:To address the lack of practical application and interpretability of coronary artery disease(CAD) diagnostic models, this paper proposed a novel model based on XGBoost and SHAP for the diagnosis of CAD. Firstly, it put the processed dataset into the XGBoost model for training, and optimized the model to boost performance. Then, compared to six machine learning models such as SVM and naive Bayes and eight mainstream machine learning models, the parameter-optimized XGBoost model obtains 0.994 2, 0.997 0, 0.994 1 and 0.999 8 in accuracy, specificity, F1 and AUC, which are higher than the existing models. Lastly, it used the SHAP framework to improve model interpretability and identified important factors affecting CAD. The proposed model has the potential to be a useful diagnostic tool in hospitals for the diagnosis of CAD.
Key words:coronary artery disease(CAD) prediction; XGBoost model; SHAP model; feature analysis
0 引言
隨著人工智能的蓬勃發展,醫療領域逐漸成為人工智能新的發力點。在研究人員致力于從醫學角度改進診療方案的同時,人工智能的融入極大地促進了傳統診療技術的進步。其中,機器學習作為當下最炙手可熱的人工智能技術,依托于豐富的醫療數據,各類預測算法被廣泛應用于疾病的診斷[1~3]與預測[4,5]。
冠心?。–AD)作為世界上最常見的心血管疾病,是導致當今人類死亡的主要病因之一。據世界衛生組織公布的統計數據顯示,2015年約有1 770萬人死于冠心病,占當年世界因病死亡人數的31%,且超過四分之三發生在中低收入國家[6]。目前,中國冠心病患病人數和死亡率仍處于上升階段,2019年中國心血管報告指出,中國冠心病患病人數約為1 100萬[7],且近三十年因冠心病死亡增加的人數位列全球第一。冠心病的診斷和治療較為復雜,尤其在發展中國家,由于診斷儀器的稀缺性和醫療人員的不足,影響了冠心病患者的早期檢測和進一步治療,從而導致了嚴重的后果。在發展中國家,用于診斷冠心病的技術一般基于對患者病史、檢驗報告以及醫療人員對相關癥狀的分析,診斷結果的準確率嚴重依賴于醫療人員的專業水平[8],并且昂貴的醫療診斷費用也導致了只有少數人能夠接受專業的檢測和診斷。冠心病檢測和診療中存在的以上問題大大降低了冠心病患者的生存率,嚴重威脅病人的生命安全。
得益于人工智能領域的發展,為了降低冠心病診斷中的復雜性,基于機器學習的各類算法被廣泛應用于冠心病預測,這對于輔助醫生降低診斷難度、提高診斷效率及預測準確率具有重要的現實意義。在現有的研究工作中,已有不少基于邏輯回歸(Logistic)[9]、支持向量機(SVM) [10,11]、樸素貝葉斯[12]、XGBoost[13]、神經網絡[14~20]等機器學習模型在相關心臟病數據集上進行冠心病預測,但是上述模型在準確率(accuracy)、靈敏度(sensitivy)等性能指標上仍有提升空間,并未達到投入一線臨床應用的性能要求。基于以上背景,為了進一步提高冠心病預測模型的準確率、靈敏度、特異度、F1值和AUC值等性能指標,本文融合XGBoost和SHAP模型建立了一種性能更為出色的冠心病預測及其特征分析模型,并且通過特征選擇對比、超參數優化、泛化分析、算法耗時分析等技術找到模型的最佳表現,為模型的一線臨床應用打下基礎。本文工作主要集中在以下三個方面:
a)基于XGBoost算法建立了冠心病預測模型,并通過與已有研究提出的模型以及主流機器學習模型在準確率、靈敏度、特" 異度等指標上進行對比實驗,驗證了本文模型的性能優越性。
b)通過特征選擇對比、超參數調優、泛化分析和算法耗時分析等技術,找到模型在訓練數據集上的最佳表現,為模型的臨床應用打下基礎。
c)在保證模型預測性能的基礎上,為增強機器學習模型的可解釋性,引入SHAP模型對影響冠心病的各種因素進行分析,為醫療人員正確診斷冠心病提供決策參考。
1 國內外相關研究
近年來,國內外學者不斷嘗試將機器學習和神經網絡算法應用于冠心病輔助診斷領域,這對于降低診斷費用及難度、提高診斷效率及預測準確率具有重要的現實意義,其已有的相關研究及其性能表現匯總如表1所示。
Liu等人[9]使用臨床數據,基于邏輯回歸算法對冠心病進行預測。文獻[10]提出一種基于支持向量機(SVM)的模型在Z-Alizadeh Sani數據集上進行冠心病預測。文獻[12]在克利夫蘭數據集選擇15個特征,使用樸素貝葉斯進行冠心病的預測,并且討論了影響冠心病的重要因素。文獻[13]使用臨床數據,基于XGBoost算法對老年冠心病患病風險進行評估。文獻[14]使用私有數據,基于數坤Coronary Doc工作站進行建模分析。在美國NHANES冠心病調查數據集上,存在著樣本不平衡現象,基于該不平衡樣本,文獻[15]使用隨機二次采樣技術,選取了38個特征,利用卷積神經網絡(CNN)對冠心病作出預測。在ECG信號數據上,基于神經網絡的算法表現較好,文獻[16]利用IPSO-BP神經網絡對臨床心電圖信號進行建模預測。文獻[20]使用ECG信號數據,基于CWT-CNN神經網絡算法,診斷準確率為98.7%。
一方面,上述已有研究大多使用ECG數據和圖像數據,數據較難獲得,提高了冠心病診斷的診斷門檻,并且提出的各種模型在準確率、靈敏度、特異度(specificity)等指標上仍有提升空間,還未達到投入臨床應用的性能要求;另一方面,已有工作大多采用XGBoost、隨機森林和神經網絡等黑箱機器學習模型,在具備較高性能的同時,模型缺乏可解釋性,不利于輔助醫生作出最后的醫療決策。針對以上存在的問題,本文基于較易獲得的一般性醫療統計數據,在使用XGBoost模型進一步提高預測性能的基礎上,引入SHAP增強模型可解釋性。針對以上兩個問題,本文與已有研究工作的最大不同之處在于:a)基于XGBoost算法建立了冠心病預測模型,在準確率、特異度、F1值和AUC值四項指標上分別達到0.994 2、0.997 0、0.994 1、0.999 8,與已有模型和主流模型相比性能更優;b)通過引入SHAP增強了模型的可解釋性,能夠對影響冠心病的各種因素進行分析,為輔助醫療人員作出正確決策提供參考信息。
2 模型構建方法
本文所用到的符號及其解釋如表2所示。
融合XGBoost與SHAP的冠心病預測及其特征分析模型構建流程如圖1所示,其過程主要包括數據預處理、預測模型的建立、SHAP解釋模型的建立、超參數調優、泛化能力分析和算法耗時分析等。首先,通過異常值處理、特征選擇和平衡數據集等步驟對原始數據進行預處理,隨后利用處理后的數據建立XGBoost冠心病預測模型和SHAP解釋模型,并且對模型進行優化,進一步提升模型的性能表現,最終得到基于XGBoost和SHAP的冠心病預測及其特征分析模型。
2.1 模型原理及其構建流程
本節選取XGBoost算法對美國NHANES臨床數據集中是否患有冠心病進行分類建模。設Xs為輸入空間(包括年齡、體重、白細胞計數等特征),特征維度為s,Y為輸出空間,給定訓練數據集為D={(x1,y1),(x2,y2),…,(xn,yn)},其中xj=(xj(1),xj(2),…,xj(s))為輸入實例, j=1,2,…,n,n為樣本個數?;赬GBoost算法訓練出最優P-XGBoost模型的訓練過程算法如下:
算法1 基于XGBoost的冠心病診斷模型訓練算法
輸入:parameter 1,訓練數據集D={(x1,y1),(x2,y2),…,(xn,yn)};parameter 2,learning_rate、n_estimators、max_depth、subsample 、min_child_weight 、colsample_bytree、gamma、reg_alpha、reg_lambda。
輸出:P-XGBoost。
1 models←{}
2 D←{(x1,y1),(x2,y2),…,(xn,yn)}//輸入訓練數據集
3 D′←preprocess(D)//數據預處理
4Euclid Math TwoRAp←SMOTE(D′)//引入SMOTE算法平衡數據集
5 input parameter2 //輸入參數2
6 for parameter2 in parameter2List do
7"" tempModel←XGBoost(parameter2).fit(瘙綆)
8"" if tempModel.accuracygt;models do
9""" models←tempModel
10"" best_parameter2←parameter2
11"" end if
12 end for
13 P-XGBoost←XGBoost(best_parameter2)
14 return P-XGBoost
在算法的第2行輸入非平衡訓練集D后,第3行對數據集進行異常值處理和特征選擇;第4行針對數據集中出現的樣本不平衡現象引入少數類過采樣方法(synthetic minority oversampling technique,SMOTE)對數據集進行樣本平衡操作;算法的5~12行完成對九個參數配置項的搜索工作,其余參數均為默認參數;算法的第13行選擇最優參數訓練出P-XGBoost模型,在第14行返回得到的最優模型。
算法第7行是XGBoost算法的訓練過程,其核心是由文獻[21]提出的基于Boosting樹模型的學習框架。傳統的Boosting樹模型只使用一階導數,在訓練第n棵樹時,由于使用了前n-1棵樹的殘差,所以很難實現分布式訓練,XGBoost對損失函數進行二階泰勒展開,且加入正則化項,可以自動使用CPU進行并行計算,并且避免了過擬合。
XGBoost作為一種提升樹模型,即用一棵樹去預測一個值,得到該值與實際值的偏差,再添加一棵樹去學習該偏差[22]。假設總共有t棵樹,F表示樹模型,則預測值yEuclid ExtrazB@l可表示為
yEuclid ExtrazB@l=∑Kk=1ft(xi) ft∈F(1)
目標函數為
L=∑il(yEuclid ExtrazB@l,yi)+∑kΩ(ft)(2)
其中:l為損失函數,表示預測值與真實值之間的誤差;Ω為正則化函數,防止模型過擬合。XGboost中的正則化函數表示為
Ω(f)=γT+12λ‖w‖(3)
其中:T表示每棵樹的葉子節點數;w表示每棵樹葉子的權重;為了抑制樹的生長和防止模型過擬合,加入了λ和γ,λ為L2正則化系數,γ為分裂閾值。根據目標函數,解得最優評分函數,該函數輸出的值越小,表明樹模型越好。
Obj(t)=-12∑Tt=1(G2iHi+λ)+γT(4)
其中:Gi=∑i∈Igi,Hi=∑i∈Ihi,且gi=yi(t-1)l(yi,yi(t-1)),hi=2yi(t-1)l(yi,yi(t-1))。
根據評分函數可以對一個樹模型進行評價,但候選樹是無窮的,不可能得到所有候選樹的評分。XGBoost算法采用了貪心算法來解決這個問題,從樹的根節點開始,計算分裂后與分裂前目標函數值是否減少,假設分裂前的節點為j,其對目標函數的貢獻為
Obj(j)=-12G2iHi+λ+γ(5)
該節點分裂后,兩個子節點的目標函數貢獻為
Obj(j)=-12G2jLHjL+λ+G2jRHjR+λ+2γ(6)
此時,目標函數變化為
Obj(j)split=Obj(j)-Obj(s)=-12(G2jLHjL+λ+G2jRHjR+λ-G2jHj+λ)-γ(7)
最終得到目標函數在每次分裂后信息增益:
Gain=12=G2LHL+λ+G2RHR+λ-(GL+GR)2H+λ-γ(8)
其中:GL、GR分別為分裂時左右葉子一階梯度統計和;HL、HR為左右葉子節點二階梯度統計和的信息增益。
XGBoost算法采用了一種近似直方圖算法,在每次分裂構建CART樹時,選取使得Gain值最大的節點進行分裂,使得分割點的選取更加高效和直接,并且在一定程度上降低了模型的過擬合。除此之外,XGBoost還使用了多種方法來對模型進行優化,由于其優越的性能,近年來在機器學習領域得到了廣泛應用。
2.2 SHAP解釋模型
基于集成算法的機器學習模型雖然有著較為優秀的性能,但是隨著模型復雜度的提高,降低了模型的可解釋性,這使得XGBoost模型幾乎是一個黑箱模型。為了解決該模型可解釋性較差的問題,引入SHAP框架來對模型結果進行解釋,以便為模型結果的可靠性提供支撐。
SHAP(Shapley additive explanations)是由文獻[23]提出的用于解釋黑箱模型的一種解釋框架。該解釋模型與人類直覺一致,并在解釋醫療和社會現象相關模型方面日益流行[24]。SHAP是基于Shapley value的計算,這是一種來自聯盟博弈論的重要方法,用于衡量特征如何影響因變量。該方法將每個特征都視為貢獻者,計算每個特征的貢獻值,將每個特征貢獻值相加得到模型的最終預測[25]。
對于集成樹模型來說,當做分類任務時,模型輸出的是一個概率值。SHAP可計算每一個特征的Shapley值,以此來衡量出每個特征對于最終預測的貢獻度。假設g代表解釋模型,M代表特征的數目,z代表該特征是否存在(取值0或1),為每個特征的Shapley值,則可給出公式為
g(z)=0+∑Mi=1izi(9)
每個特征的SHAP值表示以該特征為條件時預期模型預測的變化。對于每個功能,SHAP值說明了該特征對于總體預測結果的貢獻,以說明實例的平均模型預測與實際預測之間的差異。當igt;0,說明該特征對于預測值有提升作用,反之,說明該特征使得貢獻降低。XGBoost模型給出的feature importance只說明了哪個特征重要,但并未給出該特征是怎樣影響預測結果的。SHAP模型最大的優勢是能反映出每一個樣本中特征的影響力和該影響對于最終預測結果的正負性。
3 模型對比實驗
3.1 實驗環境及評價指標
本文實驗Python版本為3.8.3,scikit-learn版本為0.24.2,其詳細實驗配置環境如表3所示。
在醫療領域,常用的機器學習分類指標包括準確率(accuracy)、靈敏度(sensitivity)、特異度(specificity)、F1值及AUC值。本文使用上述指標進行模型的評估,其中準確率、靈敏度、特異度和F1值等指標來自于混淆矩陣,如表4所示。
準確率表示所有正確分類的樣本占總樣本的百分比。靈敏度即召回率,表示正確識別的正元組的百分比。特異度表示正確識別的負元組的百分比。F1-score為精確率和召回率的調和均值。ROC曲線(receiver operating characteristic curve)即受試者工作特征曲線,該曲線顯示了真正例率(TPR)和假正例率(FPR)之間的權衡。AUC值是ROC曲線下的面積,該值越接近1,說明模型精度越高。上述指標公式分別如式(10)~(13)所示。
accuracy=TP+TNTP+FP+FN+TN×100%(10)
sensitivity=recall=TPTP+FN×100%(11)
specificity=TNTN+FP×100%(12)
F1-score=2×precision×recallprecision+recall×100%(13)
其中:precision=TP/(TP+FP),即正確預測正類的樣本占預測為正類所占的百分比;TPR=TP/(TP+FN),即正確預測正類的概率;FPR=FP/(FP+TN),即錯誤預測正類的概率。
3.2 數據概況及其特征工程
本文與文獻[15]使用的數據集一致,數據集來自美國國家衛生統計局NHANES臨床數據。該數據集由1999—2000年至2015—2016年間美國37 079個人的基本統計數據、醫療數據和問卷數據編制而成,包括1 508例冠心病和35 571例非冠心病,數據無缺失值且該樣本為非平衡樣本。該數據集包括49個特征,其中是否患冠心病為一個二分類變量。特征的詳細列表如表5所示。
對原始數據進行的特征工程包括異常值處理、特征選擇及不平衡數據處理三個步驟。
3.2.1 異常值處理
通過數字異常值(numeric outlier)方法篩選數據中的異常值,IQR(inter-quartile range)可計算得到四分位間距,將四分位間距上下限以外的值視為異常值。如圖2所示,本數據集存在大量異常值,為了提高模型精度,去掉所有四分位間距上下限之外的值。除了去除異常值以外,不對數據進行標準化操作,本文試圖通過該種方法盡量保存數據的原始信息。
3.2.2 特征選擇
在機器學習領域,特征選擇主要有過濾器(filter)[26]、封裝法(wrapper)[27]和嵌入法(embedded)[28]三種模式。機器學習模型性能表現很大程度上依賴于特征,恰當的特征組合可以最大限度地發揮算法的預測準確率。
為了使用最優方法進行選取特征,提高模型預測能力。本文選取卡方檢驗、基于Logistic Regression的遞歸特征消除法、基于懲
罰項的Lasso算法、基于隨機森林的embedded特征選擇算法進行實驗對比。其中,卡方檢驗為過濾器特征選擇法中的代表性算法。一般來說,過濾器特征選擇法較簡單,選擇速度較快,但對于特征之間的組合效應難以挖掘,效果往往比不上封裝法和嵌入法;遞歸特征消除為封裝法的代表方法,Logistic Regression在數據量較大時也具備較快的收斂和歸一能力,因此本文選取Logistic Regression作為遞歸特征消除的基處理器;嵌入法主要包括使用帶懲罰項的基模型進行特征選擇和基于樹模型的特征選擇,Lasso算法和基于隨機森林的特征選擇可以較好地代表嵌入法的兩種類型。將以上四種特征選擇方法結合XGBoost模型在冠心病數據集上進行預測對比,相關對比結果見3.3節。
3.2.3 不平衡數據處理
本數據集中存在類別不平衡現象,一般來說,分類器對于非平衡數據中少數類的分類效果不太理想。在處理非平衡數據集時,一般有欠采樣和過采樣兩種方法。欠采樣技術拋棄了部分大類樣本,造成了數據的浪費;而簡單過采樣對于最后分類結果幫助很小,原因是在大多數情況下,簡單過采樣不會改變分類規則,并且會導致模型的過擬合[29]。合成少數類過采樣技術(synthetic minority oversampling technique)[30]解決類間不平衡的過采樣技術,它顯著改善了一般隨機過采樣方法造成的過擬合情況,且減少了過采樣算法在采樣過程中對不平衡數據的局限性和盲目性,因此近年來在處理類不平衡領域得到了廣泛應用。針對樣本中出現的不平衡現象,本文使用SMOTE對整個樣本的少數類進行過采樣。過采樣后,冠心病和非冠心病人數比為1∶1,其數據分布如圖3所示。
3.3 特征選擇實驗結果對比
針對3.2.2節描述的四種特征選擇算法,本文通過實驗讓其在冠心病數據集上進行訓練,并結合XGBoost在測試集進行預測,對比四種特征選擇算法的性能表現。其中,卡方檢驗屬于過濾器特征選擇法;基于Logistic Regression的遞歸特征消除屬于包裝模式特征選擇法;嵌入模式特征選擇有Lasso算法和基于隨機森林的embedded算法兩種。相關指標對比如表6所示。
對比表6實驗結果可以看出,各種特征選擇算法表現較為接近?;谇度胧降腖asso特征選擇算法在accuracy指標上相對于其余三種算法提升約0.2%~0.4%;在sensitivity指標上,四種算法的值相差不大,Lasso算法略優;在specificity指標上,Lasso算法相比于最低的隨機森林提升約1%,表現較好。綜合比較上述各指標,本文選取Lasso算法進行特征選擇,并且采用五折交叉驗證的方法選取Lasso算法中的最佳α,最終選取了38個特征(圖4)。
3.4 與已有工作的實驗結果對比
本節將本文使用的XGBoost模型與支持向量機[10]、樸素貝葉斯[12]、XGBoost [13]和卷積神經網絡(CNN)[15,17,18]進行性能對比,其對比結果如表7和圖5所示。
對比表7和圖5的數據可知,本文建立的冠心病預測模型在準確率、靈敏度、特異度三項指標上分別為0.991、0.991、0.995,準確率和特異度是上述文獻[10,12,13,15,17,18]中表現最好的。其中,文獻[10]提出的SVM模型的靈敏度達到了1,但是其余兩項指標均不及本文模型。本文與文獻[15]使用的數據集一致,文獻[15]針對數據集中出現的樣本不平衡現象采用隨機二次抽樣的方法平衡樣本,隨后建立卷積神經網絡模型,其模型的準確率、靈敏度和特異度分別為0.818、0.773和0.818;顯然,本文建立的模型在準確率、靈敏度、特異度三個指標上表現更好,相對于文獻[15]性能提升約21.2%~27.7%。這也說明,進行恰當的特征工程和選取合適的預測模型是提高預測性能的關鍵。
3.5 與其他主流機器學習模型對比分析
除了與已有工作模型進行對比,為了更好地驗證本文模型的優越性,將本文模型與八種主流機器學習算法進行對比,并使用準確率、靈敏度、特異度、F1值、AUC值五個指標進行模型評估。將處理后的數據集的80%劃分為訓練集,20%劃分為測試集,且各模型均使用默認參數。
由表8中對比的實驗結果可知,本文模型在準確率、靈敏度、特異度、F1值和AUC指標上的值分別為0.991 0、0.991 1、0.995 4、0.990 9和0.999 4,其中準確率、特異度、F1-score和AUC值四個指標為上述模型指標中最高。圖6給出了以準確率(accuracy)為評價指標的十折交叉驗證箱線圖,XGBoost模型的精度最高;圖7為ROC曲線對比,XGBoost模型的AUC值為0.999 4,也為八種主流模型中最高。可以看出:a)對比基于線性分類的邏輯回歸,非線性的模型表現更好,這表明冠心病臨床數據往往表現出復雜的非線性關系,所以基于非線性關系的模型往往可以獲得較好的分類效果;b)由表8可知,基于集成算法的模型在數據集上均表現較好,其原因在于,不管Bagging還是Boosting算法,其本質仍然是基于若干棵樹的集成,集成模型對于噪聲數據和離群點有著更好的魯棒性,并且對特征的數量不敏感;c)XGBoost模型在訓練之前,對輸入特征數據排序,存儲為block結構,在之后的預測過程中重復使用這個結構,很大程度上減少了計算量,可以在預測中實現并行計算,因此具有更快的預測速度。除此之外,XGBoost還支持自定義損失函數以及自動處理缺失值,且對損失函數進行二階泰勒展開,加入正則化項,大大提高了算法的預測精度和抗擬合性。綜合比較上述指標,在八種算法中,XGBoost算法預測結果更準確,運行速度快,其穩定性相比其他算法更加優秀。
3.6 模型超參數優化、泛化分析和算法耗時分析
XGBoost算法有general、booster和learning task參數三種參數類型。本文選擇對模型影響較大的learning_rate、n_estimators、max_depth、subsample 、min_child_weight 、colsample_bytree、gamma、reg_alpha和reg_lambda參數進行調優。XGBoost支持GridSearchCV,即網格參數搜索驗證,選取accuracy作為評價指標。根據2.1節中提出的最優P-XGBoost模型的算法流程,得到的調優結果如表9所示。
對比表10所示的默認參數組合和調參后的最佳參數組合,其中P-XGBoost為調優后的模型。實驗結果表明,調優后的模型在各項指標上均有不同程度的提升,準確率、靈敏度、特異度、F1-score和AUC值分別為0.994 2、0.991 3、0.997 0、0.994 1和0.999 8,只在靈敏度指標上略低于隨機森林的0.993 8。參數優化后的XGBoost模型性能提升約0.04%~0.49%,原因在于恰當的參數組合降低了模型運算復雜度和過擬合,提高了模型運算精度。
判斷一個模型的好壞,除了在測試集上表現得好,還應該在整個數據集上著較好的泛化表現,并且算法耗時應盡可能低。模型的泛化能力,即模型在新數據集上具有的良好適應能力。對此,本文采用學習曲線及其算法耗時分析對模型的泛化能力和運算速度進行分析。
由圖8的對比實驗結果可以看出,隨著樣本量的增加,三種模型均趨向收斂。對比隨機森林模型,XGBoost模型用較小的樣本量達到了更好的擬合結果,在樣本量為15 000時取得了較高的交叉驗證得分,原因在于其采用了帶權重的分位數略圖算法并且做到了對稀疏數據的優化。
在算法耗時方面,在相同樣本量為40 000時XGBoost耗時最低,且達到0.98以上的交叉驗證得分只用了隨機森林幾乎一半的時間,并且隨著樣本量的增大,隨機森林和支持向量機算法耗時呈現直線上升的趨勢,而此時XGBoost算法耗時呈現平穩趨勢,這得益于XGBoost算法獨特的block結構,可預先對特征值進行排序,并且支持并行計算,大大減少了訓練時間。模型收斂性以及算法耗時分析表明,本文建立的基于XGBoost的集成模型在泛化性以及算法耗時方面表現更優。
4 基于SHAP的模型解釋性分析
本章主要集中在基于SHAP框架對冠心病預測模型結果進行解釋性分析。圖9為SHAP特征摘要,該圖根據特征重要性對影響冠心病的因素進行分析。
如圖9所示,bilirubin(膽紅素)、basophils(嗜堿性粒細胞計數)、age(年齡)、gender(性別)等特征對模型影響較為顯著。SHAP給出的影響冠心病的最重要特征為bilirubin,該特征對冠心病的影響呈現出復雜的非線性關系,即剛開始隨著膽紅素指標的增加,患冠心病的風險會變大,但增加到一定程度后會降低冠心病的風險。年齡對冠心病也有著顯著影響,隨著年齡的增大,患冠心病的風險不斷增大;在其他情況相近的條件下,男性比女性更容易患心臟病。這也得到了現有文獻的支持,文獻[31]研究了患冠心病的終生風險,發現在男性和女性中,患冠心病累積風險隨著年齡的增長而上升,尤其是在60歲后急劇上升,直到大約90歲,之后累積風險似乎趨于平緩,并且男性患上冠心病的風險比女性更高。除了年齡和性別兩個人口統計指標以外,其余人口統計指標對冠心病也有影響。如購買醫療保險的人群似乎有更低的幾率患上冠心??;身高越高的人群患上冠心病的幾率也更低。
圖10給出了SHAP依賴圖,分別選取對模型影響顯著的bilirubin、age、basophils和cholesterol四個特征繪制SHAP特征依賴圖,其中依賴圖第三軸為分類變量。依賴圖顯示:隨著basophils和bilirubin值的增加,SHAP值呈現先上升再下降趨勢;隨著年齡的增大,不同性別患者的SHAP值也隨之增大,這表明年齡對患冠心病有正向影響;在不同年齡段,cholesterol與冠心病之間均存在反向關系,膽固醇增大,SHAP值降低。
除此以外,SHAP還可對個體患冠心病的影響因素進行分析,分別選取一名預測為冠心病和非冠心病患者進行個體影響因素分析。圖11為一名預測為冠心病患者的SHAP特征貢獻圖,紅色部分表示被預測為冠心病的影響因素(見電子版),其被預測為冠心病的原因是年齡較大、嗜堿性粒細胞較多、膽紅素指標超標等原因。
圖12為一名預測為非冠心病患者的SHAP特征貢獻圖,藍色部分表示被預測為非冠心病的影響因素(見電子版),其被預測為非冠心病的原因是年齡較小、嗜堿性粒細胞較少、膽紅素指標正常等原因。
圖13給出了四種模型的特征重要性排名,其排名順序不完全一樣,可以得出影響較為顯著的因素包括bilirubin、age、gender、basophils等。其中LightGBM和SHAP都將bilirubin排到了第一位,可見膽紅素指標是影響冠心病的重要因素;相關研究也表明在一定范圍內,膽紅素含量升高對于心臟具有保護作用[32]。值得關注的是,在XGBoost模型中指出,moderate-work(中等強度工作)是導致冠心病的重要影響因素;文獻[33]研究了工作壓力和冠心病的關系,指出工作壓力是導致冠心病的危險因素,這也為冠心病的預防提供了參考性辦法。此外,模型給出的其他因素需要在診斷時進行綜合分析。
5 結束語
冠心病是困擾人類的世紀重大疾病之一。中國目前有1 100萬冠心病患者,作為冠心病的高發病地區,為提高冠心病的診斷準確率、降低診斷費用,將機器學習算法應用到冠心病預測、輔助醫生作出精確診斷具有重大的現實意義。本文基于集成學習算法XGBoost構建了冠心病預測模型,并且使用SHAP增強模型可解釋性。首先,在對數據集進行異常值處理、特征選擇和不平衡數據處理的基礎上,基于XGBoost算法建立冠心病的預測模型,通過與已有研究提出的六種模型以及邏輯回歸(Logistic)、支持向量機(SVM)、決策樹、隨機森林、Gradient Boost、AdaBoost、LightGBM和MLP等主流機器學習算法的對比實驗,證明了本文模型在預測冠心病上的性能優越性。最后引入SHAP增強模型可解釋性,識別了引起冠心病的主要因素是膽紅素、嗜堿粒細胞計數、年齡、性別等。模型預測性能的提高,可解釋性的增強對于增加冠心病的診斷準確率、降低診斷費用具有重要的應用價值。
在未來的臨床實踐中,一方面,模型還需要與患者的歷史病例數據相融合,從而進一步提高模型的預測性能;另一方面,還需根據真實應用場景的需求,對特征工程、模型訓練、超參數優化、誤差及偏差分析等內容進一步優化。
參考文獻:
[1]梁禮明,黃朝林,石霏,等.融合形狀先驗的水平集眼底圖像血管分割[J].計算機學報,2018,41(7):1678-1692.(Liang Liming, Huang Chaolin, Shi Fei, et al. The image of the horizontal fundus with shape prior fusion[J].Chinese Journal of Computers,2018,41(7):1678-1692.)
[2]張曉宇,王彬,安衛超,等.基于融合損失函數的3D U-Net+腦膠質瘤分割網絡[J].計算機科學,2021,48(9):187-193.(Zhang Xiaoyu, Wang Bin, An Weichao, et al. Glioma segmentation network based on 3D U-Net+with fusion loss function[J].Computer Science,2021,48(9):187-193.)
[3]馬超,譚旭.基于樽海鞘算法優化的帕金森病早期診斷模型研究與并行優化[J].計算機應用研究,2021,38(9):2726-2731.(Ma Chao, Tan Xu. Research and parallel optimization of Parkinson’s disease early diagnosis model based on improved salp swarm algorithm[J].Application Research of Computers,2021,38(9):2726-2731.)
[4]劉平平,張文華,盧振泰,等.基于放射組學特征的胃腸道間質瘤的分類預測[J].計算機科學,2019,46(1):285-290.(Liu Pingping, Zhang Wenhua, Lu Zhentai, et al. Classification and prediction of gastrointestinal stromal tumors based on radiomics[J].Computer Science,2019,46(1):285-290.)
[5]胡滿滿,楊杰,楊焱,等.基于動態采樣和遷移學習的疾病預測模型[J].計算機學報,2019,42(10):2339-2354.(Hu Manman, Yang Jie, Yang Yan, et al. Disease prediction model based on dynamic sampling and transfer learning[J].Journal of Chinese Computer Systems,2019,42(10):2339-2354.)
[6]Baudet M, Daugareil C, Laulom P, et al. Therapeutic education in primary cardiovascular prevention[J].Ann Cardiol Angeiol (Paris),2019,68(1):49-52.
[7]中國心血管健康與疾病報告編寫組.中國心血管健康與疾病報告2019概要[J].中國循環雜志,2020,35(9):833-854.(China Cardiovascular Health and Disease Report Writing Group. Summary of China cardiovascular health and disease report 2019[J].Chinese Circulation Journal,2020,35(9):833-854.)
[8]Haq A U, Li Jianping, Memon M H. A hybrid intelligent system framework for the prediction of heart disease using machine learning algorithms[J].Mobile Information Systems,2018,2018:article ID 3860146.
[9]Liu Xinyun, Jiang Jicheng, Wei Lili, et al. Prediction of all-cause mortality in coronary artery disease patients with atrial fibrillation based on machine learning models[J].BMC Cardiovascular Disorders,2021,21:article No.499.
[10]Alizadehsani R, Hosseini M J, Khosravi A,et al. Non-invasive detection of coronary artery disease in high-risk patients based on the stenosis prediction of separate coronary arteries[J].Computer Methods and Programs in Biomedicine,2018,162(1):119-127.
[11]Babi F, Olejár J, Vantová Z, et al. Predictive and descriptive analy-sis for heart disease diagnosis[C]//Proc of Federated Conference on Computer Science and Information Systems. Piscataway, NJ: IEEE Press,2017:155-163.
[12]Palaniappan S, Awang R. Intelligent heart disease prediction system using data mining techniques[C]//Proc of IEEE/ACS International Conference on Computer Systems and Applications. Piscataway,NJ:IEEE Press,2008:108-115.
[13]王曉麗,施天行,彭德榮,等.兩種機器學習算法構建老年冠心病患病風險評估模型的效能比較研究[J].中華全科醫學,2021,19(4):523-527.(Wang Xiaoli, Shi Tianxing, Peng Derong, et al. Comparative study on the effectiveness of two machine learning algorithms in building risk assessment model of coronary heart disease in the elderly[J].Chinese Journal of General Practice,2021,19(4):523-527.)
[14]朱剛明,譚源滿,陶娟,等.基于深度學習的冠狀動脈CTA人工智能后處理對疑似冠心病患者診斷價值的初步研究[J].臨床放射學雜志,2021,40(11):2128-2133.(Zhu Gangming, Tan Yuanman, Tao Juan, et al. The value of artificial intelligence of coronary CTA based on deep learning in suspected coronary arteriosclerotic heart disease patients[J].Journal of Clinical Radiology,2021,40(11):2128-2133.)
[15]Dutta A, Batabyal T, Basu M,et al. An efficient convolutional neural network for coronary heart disease prediction[J].Expert Systems with Applications,2020,159(2):113408.
[16]孟輝,張加宏,李敏,等.基于IPSO-BP神經網絡與BCG信號的冠心病預測分類研究[J].傳感技術學報,2020,33(10):1379-1385.(Meng Hui, Zhang Jiahong, Li Min, et al. Prediction and classification of coronary heart disease based on IPSO-BP neural network and BCG signal[J].Journal of Sensing Technology,2020,33(10):1379-1385.)
[17]Jahmunah V, Ng E, San T R,et al. Automated detection of coronary artery disease, myocardial infarction and congestive heart failure using GaborCNN model with ECG signals[J].Computers in Biology and Medicine,2021,134:104457.
[18]Shu L O, Vicnesh J, Ru S T,et al. Comprehensive electrocardiographic diagnosis based on deep learning[J].Artificial Intelligence in Medicine,2020,103:101789.
[19]Feng Kai, Pi Xitian, Liu Hongying, et al. Myocardial infarction classification based on convolutional neural network and recurrent neural network[J].Applied Sciences,2019,9(9):1879.
[20]Wang Tao, Lu Changhua, Sun Yining, et al. Automatic ECG classification using continuous wavelet transform and convolutional neural network[J].Entropy,2021,23(1):119.
[21]Chen Tianqi, Guestrin C. XGBoost: a scalable tree boosting system[C]//Proc of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining.New York:ACM Press,2016:785-794.
[22]潘進,丁強,江愛朋,等.基于XGBoost的冷水機組不平衡數據故障診斷[J].機械強度,2021,43(1):27-33.(Pan Jin, Ding Qiang, Jiang Aipeng, et al. Fault diagnosis of water chiller unbalance data based on XGBoost[J].Journal of Mechanical Strength,2021,43(1):27-33.
[23]Lundberg S M, Lee S I. A unified approach to interpreting model predictions[C]//Proc of the 31st International Conference on Neural Information Processing Systems.2017:4768-4777.
[24]Wojtuch A, Jankowski R, Podlewska S. How can SHAP values help to shape metabolic stability of chemical compounds?[J].Journal of Cheminformatics,2021,13(1):1-20.
[25]Parsa A B, Movahedi A, Taghipour H,et al. Toward safer highways, application of XGBoost and SHAP for real-time accident detection and feature analysis[J].Accident Analysis amp; Prevention,2020,136:105405.
[26]Huan Liu, Lei Yu. Toward integrating feature selection algorithms for classification and clustering[J].IEEE Trans on Knowledge and Data Engineering,2005,17(4):491-502.
[27]Yu Lei, Liu Huan. Efficient feature selection via analysis of relevance and redundancy[J].The Journal of Machine Learning Research,2004,5:1205-1224.
[28]Oh I S, Lee J S, Moon B R. Hybrid genetic algorithms for feature selection[J].IEEE Trans on Pattern Analysis and Machine Intelligence,2004,26(11):1424-1437.
[29]Blagus R, Lusa L. Class prediction for high-dimensional class-imba-lanced data[J].BMC Bioinformatics,2010,11:article No.523.
[30]Chawla N V, Bowyer K W, Hall L O,et al. SMOTE: synthetic minority over-sampling technique[J].Journal of Artificial Intelligence Research,2002,16(1):321-357.
[31]Sanchez-Delgado E, Liechti H. Lifetime risk of developing coronary heart disease[J].Lancet,1999,353(9156):924-925.
[32]張紅民,沈建國,邵回龍,等.冠心病與膽紅素關系探討[J].中國急救醫學,2002(1):46-47.(Zhang Hongmin, Shen Jianguo, Shao Huilong, et al. Relationship between coronary heart disease and bilirubin[J].Chinese Journal of Critical Care Medicine,2002(1):46-47.)
[33]Kivimki M, Nyberg S T, Batty G D, et al. Job strain as a risk factor for coronary heart disease: a collaborative meta-analysis of individualparticipant data[J].Lancet,2012,380(9852):1491-1497.