中國人民大學應用統計科學研究中心,統計學院(100872) 李 嶸 張文麗 李 揚 林存潔
【提 要】 目的 將深度學習方法應用在大規模腫瘤數據中,并預測腫瘤患者的個體生存情況,提升預測精度,為個體化治療方案提供參考。方法 以老年乳腺癌數據為例,將生存時間劃分成離散區間,通過神經網絡方法預測患者在各離散區間內的死亡概率,實現個體生存函數的預測。結果 對于19576例老年女性乳腺癌的個體生存函數預測情況,本文提出的方法預測效果好于其他的模型,表現在有更大的c-index指標和更大的log-rank統計量值。結論 基于深度學習的生存函數預測有較大的靈活性,不受Cox模型比例風險假設的限制,能夠處理大規模數據,并且對個體生存函數的預測更加準確。
全球癌癥負擔日益加重,腫瘤的發病率和死亡率日益增加,已成為威脅人類健康的主要危險因素。隨著電子病歷(EMR)和腫瘤基因組學的發展與普及,腫瘤病人的相關臨床數據量不斷增加,而大規模的腫瘤數據為精準醫學提供了良好的研究基礎[1]。精準醫學自提出以來一直被廣泛重視,2015年3月,我國科技部首次召開國家精準醫學戰略專家會議,計劃啟動中國的精準醫學計劃,隨后精準醫學被列入國家重點研發項目并正式進入啟動階段[2-3]。精準醫學根據患者的特異性進行個性化的預防或治療干預,通過預測腫瘤患者的個體生存情況來確定個體化治療方案。本文的研究對象為老年乳腺癌患者,乳腺癌是女性最常見的惡性腫瘤之一,隨著人口老齡化及女性平均壽命的延長,老年乳腺癌(以大于65歲為界限)發病率明顯增多。由于老年病人的體質和健康狀況差異較大,尚無規范的治療模式,因此對于老年乳腺癌的治療應該按照個體化原則確定治療方案[4]。
預測生存函數是生存分析中的重要任務,而大規模腫瘤數據為研究建立了基礎的同時也帶來了挑戰。龐大的數據量使得經典的Cox模型難以計算,另外,Cox模型假設風險函數的對數是解釋變量的線性組合且解釋變量的影響不隨時間變化,該比例風險假設在實際問題中難以被滿足。近年來,隨著機器學習的發展,利用深度學習方法處理生存數據的研究也取得了一些進展,突出的方法包括Cox-nnet[5]、DeepSurv[6]和Nnet-Survival[7]。其中Cox-nnet方法利用一層神經網絡進行降維后將輸出的結果作為解釋變量擬合Cox模型,DeepSurv方法則是基于Cox模型的部分似然函數利用深度學習模型擬合風險函數。但是Cox-nnet和DeepSurv這兩種方法仍在不同程度上保留了Cox模型的假設,因此具有一定的局限性。而Nnet-survival方法則是將生存時間離散化,然后估計各區間的條件風險函數。在本文中,我們借鑒Nnet-Survival的思想,但是更加關注每個離散區間上生存函數的估計,把生存分析問題轉化成深度學習問題,進而提高生存函數的預測精度。該方法完全摒棄了Cox模型的假設,能夠更加靈活地處理生存數據,給出更加準確的預測結果,同時保持了深度學習算法對大規模數據的有效性,因此能夠更好地適用于大規模腫瘤數據的生存分析。
假設我們的觀測數據是右刪失數據,即存在部分樣本,截止到觀測時間結束,感興趣的事件(例如死亡事件)仍沒有發生。不妨設觀測樣本為:(Ti,Zi,δi),i=1,…,n。其中Ti=min(Xi,Ci),Xi表示個體i的真實生存時間,Ci表示個體i的刪失時間,Ti即為可觀察到的兩者中的最小值,Zi表示p維協變量。δi是指示變量,δi=0表示數據刪失(即Ci pj=P(tj-1 其中S(t)=P(T>t)表示生存函數。如圖1所示,在第j個區間終點tj處的生存函數為: 通過估計離散區間端點處生存函數的值就可以很好地近似完整的生存函數曲線,因此對于某個個體而言,其生存函數可以對應到一組長度為M的向量p=(p1,p2,…,pM),這里p表示M個離散區間中死亡事件發生的概率。從而對于生存函數的估計就轉化成對于p的估計??紤]到解釋變量對p的影響可能是復雜的非線性關系,也可能隨著時間進展而變化,因此,采用深度學習對p進行估計。 圖1 離散區間結構 1.神經網絡的結構 神經網絡的結構包括輸入層、隱藏層和輸出層。本文采用全連接神經網絡,即層與層之間每個神經元都有連接。 (1)輸入層 輸入層是影響生存時間的解釋變量Z,輸入層神經元個數等于解釋變量的維數。 (2)隱藏層 隱藏層的層數和各層神經元的個數可以自行選擇。隱藏層采用sigmoid激活函數。隱藏層中每一個神經元的輸出值是所有連接到該神經元的輸入值的線性組合再經過sigmoid激活函數非線性處理后的結果。 (3)輸出層 一般地,離散區間的個數M可取15~40個且模型的表現對離散區間的選擇比較穩定,本文通過下式確定前疏后密的區間端點[7]: 其中,t*=0.27tmax,tmax為區間終點。 2.神經網絡的訓練 (1)構建目標函數 其中,第二項為正則項,wk記為神經網絡中的參數,λ為調節系數控制懲罰力度的大小,通過對參數添加L2懲罰以防止模型過擬合。 (2)Minibatch梯度下降算法 求解神經網絡以使得目標函數最小化,通過反向傳播算法對目標函數進行求導,然后采用Minibatch梯度下降算法對網絡中參數進行更新。Minibatch梯度下降法適用于大規模數據集,由于個體似然函數間互相獨立,因此可以將大規模數據集拆分成多個小樣本集,在每個小樣本集中更新參數[8-9]。首先將全部樣本劃分為訓練集和測試集,記訓練集中的樣本可以劃分為B個小樣本集,每個小樣本集中包含的樣本點個數為nb(b=1,…,B)。在每個小樣本集中通過以下的方式依次更新參數: 其中η(b)表示第b次迭代中的步長,也稱作學習率,w(b)-w(b-1)表示動量,記錄了上一次迭代時系數改變的方向,增加動量項可以在一定程度上避免陷入局部最優點及大幅度震蕩。B次更新記作一代訓練,一代是指遍歷了訓練集一次,本文中一代訓練內采用相同的步長,設置步長的初始值為0.005。再將上述一代訓練重復至收斂,本文為防止過擬合,設置停止準則為連續300代更新之后測試集上的目標函數沒有減少則停止訓練。為提高收斂效率,設置步長為每100代訓練以0.8倍減小。 (3)超參數選擇 上述Minibatch梯度下降算法中包含一系列超參數,包括目標函數中正則項的調節系數λ,神經網絡的隱藏層數,各隱藏層神經元個數及更新準則中的動量項參數α。本文通過比較各組超參數組合下測試集的目標函數值以確定使得測試集目標函數值最小的超參數組合。 本文通過深度學習預測老年乳腺癌患者的生存函數,數據來源于美國國立癌癥研究所SEER(Surveillance,Epidemiology,and End Results Program)數據庫,分析1994-2003年年齡大于等于65歲的19576例女性乳腺癌患者的病歷資料。通過預測其生存函數來了解患者的生存情況以便更好地做出治療決策。 1.數據描述 該數據中記錄病例的生存時間的中位數是119個月,觀測到的最長生存時間為263個月,刪失率為5.9%。連續變量中只有腫塊大小存在缺失,缺失比例為11.92%,采用中位數插補。為分類變量添加虛擬變量,其中關于腫瘤位置只設置一個虛擬變量以防止共線性。參考已有文獻中對乳腺癌危險因素的討論[10-11],最終從26個解釋變量中選擇出8個變量納入分析,各變量的描述如表1。 表1 解釋變量統計表 繪制KM曲線擬合整體的生存函數如圖2,總體生存函數在150個月之前下降速度略慢于150個月之后,表示后期風險略大于前期。 圖2 老年乳腺癌患者KM生存曲線 2.預測結果比較 本文基于深度學習預測老年乳腺癌患者的個體生存函數,劃分36個離散生存區間,通過估計各區間內的死亡概率得到各區間終點處的生存函數的估計,將該方法記為DL-Survival?,F有的生存函數估計方法包括Cox模型,以及利用深度學習處理生存數據的Cox-nnet、DeepSurv和Nnet-Survival。分別采用這五種方法預測老年乳腺癌的生存函數,并通過c-index和log-rank兩個指標評價各種方法的預測準確性,這兩個指標均是生存分析中常用的評價指標[12-13]。c-index計算所有可比的個體對中估計結果的相對關系和實際相對關系一致的比例,是衡量生存分析模型表現的常用指標,其大小在0到1之間,越接近1表示方法的預測精度越高。log-rank檢驗統計量的原理是先根據預測結果把人群按照中位數分為高風險人群和低風險人群,然后對這兩組人群的KM估計曲線進行log-rank檢驗。log-rank檢驗統計量值越大表示方法區分高風險和低風險人群的效果越好。各方法的比較結果如表2所示,本文提出的DL-Survival方法在個體生存函數的預測中表現最好。 表2 各方法對老年乳腺癌患者生存函數預測結果比較 對于個體生存函數的預測有助于掌握患者的生存情況,以便優化信息和決策。本文采用深度學習的方法,通過估計離散區間的死亡概率預測個體的生存函數。不同于KM方法對群體生存情況的估計,本文基于個體特征對每個患者的生存函數進行預測。同時本文提出的深度學習算法摒棄了Cox模型中比例風險假設,在實際應用中會更加靈活。在滿足等比例風險的條件下與基于Cox的方法能達到相同的效果;在不滿足等比例風險的條件下能夠優于基于Cox的方法。而相比于其他不受比例風險限制的機器學習方法,本文提出的方法更加直觀地預測生存函數,并且可以處理較大規模的數據,其適用性更加廣泛。然而在實際應用中運用哪種方法需要綜合考慮,例如,當樣本量較小時,深度學習方法由于訓練樣本量不足易產生過擬合,預測結果不一定優于Cox模型。 對大規模腫瘤數據仍然需要更多探索,大規模數據的特點通常包括樣本量大,變量維數多以及數據來源多樣化。對于更大樣本量的數據,基于個體似然函數相互獨立,可以考慮分治法(divide and conquer)以降低計算成本。另外,本文中對于SEER老年乳腺癌患者的分析涉及到的解釋變量個數不多,當數據中變量維數較多時,可以考慮在神經網絡中加入稀疏層,在預測生存函數的同時進行變量選擇[14],以尋找影響老年乳腺癌患者生存情況的風險因素。為充分利用不同實驗室或研究機構的數據來源,還可以考慮整合分析方法,探索數據集間的關聯性和差異性,有助于精準醫學對于不同亞群患者的治療和決策。


實例分析



討 論