王紅霞,顧 鵬,李枝峻,寧樞麟
(沈陽理工大學(xué)信息科學(xué)與工程學(xué)院,遼寧 沈陽 110159)
復(fù)雜場(chǎng)景中準(zhǔn)確地對(duì)行人軌跡進(jìn)行預(yù)測(cè)對(duì)自動(dòng)駕駛系統(tǒng)至關(guān)重要[1]。行人作為自動(dòng)駕駛的主要參與者,其軌跡更加自由、靈活、復(fù)雜,準(zhǔn)確地對(duì)行人軌跡進(jìn)行預(yù)測(cè)可以規(guī)避潛在的風(fēng)險(xiǎn)。行人軌跡預(yù)測(cè)的難點(diǎn)在于行人的軌跡受主觀因素的影響,同時(shí)也受到周圍行人運(yùn)動(dòng)的影響,對(duì)其精準(zhǔn)預(yù)測(cè)有較高的難度。
傳統(tǒng)的行人軌跡預(yù)測(cè)通過手工提取行人交互特征,不僅有交互不足的缺點(diǎn),而且缺少適配的數(shù)據(jù)集。隨著深度學(xué)習(xí)的興起,長(zhǎng)短期記憶網(wǎng)絡(luò)(Long Short-Term Memory,LSTM)被用于序列預(yù)測(cè),通過復(fù)雜的網(wǎng)絡(luò)模型再搭配更適配的數(shù)據(jù)集得到了更好的預(yù)測(cè)效果。ALAHI等[2]提出采用LSTM網(wǎng)絡(luò)提取行人運(yùn)動(dòng)軌跡信息,加入池化模塊,在池化模塊范圍內(nèi)的行人共享信息,從而達(dá)到提取社交關(guān)系的目的。相對(duì)于單獨(dú)使用RNN網(wǎng)絡(luò)進(jìn)行預(yù)測(cè),考慮了行人交互,取得了更好的預(yù)測(cè)效果,缺點(diǎn)是池化操作不能區(qū)別對(duì)待行人之間的關(guān)系。生成對(duì)抗網(wǎng)絡(luò)[3](Generative Adversarial Networks,GAN)不需要復(fù)雜的模型,通過對(duì)抗學(xué)習(xí)達(dá)到更好的效果。GUPTA等[4]基于GAN網(wǎng)絡(luò)提出了(Social GAN,SGAN)模型,采用生成器-判別器模式訓(xùn)練模型,提出新的損失函數(shù)鼓勵(lì)網(wǎng)絡(luò)預(yù)測(cè)多條軌跡,相比先前模型只是預(yù)測(cè)一條“平均好”的軌跡,在預(yù)測(cè)精度上有了很大提升,缺點(diǎn)是采用全局池化會(huì)加入無效交互且增大運(yùn)算量,并且GAN模型不易訓(xùn)練。為了解決交互不足的缺點(diǎn),有學(xué)者提出采用注意力機(jī)制模塊提取信息[5-6],相對(duì)全局池化,在提取交互關(guān)系上有了一定改進(jìn),缺點(diǎn)仍是不能區(qū)別對(duì)待行人關(guān)系。
在對(duì)行人交互信息提取中,多數(shù)研究采用池化層進(jìn)行交互,缺點(diǎn)是損失了大量信息,融入了無效交互。圖神經(jīng)網(wǎng)絡(luò)(Graph Neural Network,GNN)的出現(xiàn),將行人關(guān)系映射到圖中進(jìn)行建模,更符合社交關(guān)系網(wǎng)絡(luò)。圖注意力網(wǎng)絡(luò)(Graph Attention Network,GAT)是GNN的變體。通過注意力機(jī)制(Attention Mechanism)對(duì)鄰居節(jié)點(diǎn)做聚合操作。HUANG等[7]基于GAT模型為每個(gè)鄰居節(jié)點(diǎn)分配注意力系數(shù),然后通過加權(quán)求和得到節(jié)點(diǎn)新的特征向量,缺點(diǎn)是當(dāng)前基于圖神經(jīng)網(wǎng)絡(luò)的模型在提取行人之間交互信息所考慮的信息是片面的,忽略了邊特征在圖中的作用。
針對(duì)上述問題,本文提出一種結(jié)合邊特征的時(shí)空?qǐng)D自注意力預(yù)測(cè)模型(Spatial-Temporal Graph Network with Edge Feature Generative Adversarial Networks,STGEF-GATv2)對(duì)行人軌跡進(jìn)行預(yù)測(cè)。采用編碼器-解碼器結(jié)構(gòu)作為主體結(jié)構(gòu),引入邊特征提取模塊,將行人歐式距離構(gòu)建鄰接矩陣,輸入到特征融合層,將輸出結(jié)果用作邊特征,增加了GAT可學(xué)習(xí)信息。空間交互采用更加有效的GATv2模塊[8]替換GAT模塊,采用更少的頭,提升精度的同時(shí)降低了模型復(fù)雜度。最后,STGEF-GATv2模型采用模塊時(shí)空信息融合模塊[7],可以更好地提取行人間的時(shí)空交互信息,從而提高預(yù)測(cè)模型的精確度。

STGEF-GAT模型的架構(gòu)如圖1所示,主要采用編碼器-解碼器模型。本文在編碼器采用LSTM模型學(xué)習(xí)行人運(yùn)動(dòng)軌跡隱藏狀態(tài),將行人的相對(duì)位置坐標(biāo)輸入特征提取模塊后與LSTM網(wǎng)絡(luò)的隱藏狀態(tài)進(jìn)行融合,采用GATv2來提取行人空間上的交互特征。另外采用一個(gè)LSTM學(xué)習(xí)行人空間交互的隱藏信息,解碼器采用LSTM網(wǎng)絡(luò),以時(shí)間特征、空間特征、交互特征[10]結(jié)合高斯噪聲作為輸入,未來的軌跡作為輸出,最終通過反向傳播進(jìn)行模型訓(xùn)練。

圖1 行人軌跡預(yù)測(cè)模型架構(gòu)圖
原始GAT結(jié)構(gòu)有一個(gè)缺陷,在學(xué)習(xí)注意力機(jī)制時(shí)不使用任何邊的可用信息,為解決這一缺點(diǎn)本算法通過添加歸一化層,如圖2所示,輸入行人軌跡相對(duì)位置坐標(biāo),構(gòu)建歐式距離鄰接矩陣,歸一化后進(jìn)行維度映射。最后與相對(duì)位置坐標(biāo)經(jīng)過LSTM后的隱藏特征進(jìn)行融合,增加了節(jié)點(diǎn)邊特征向量的信息,提供GAT層更豐富的學(xué)習(xí)信息。上述過程的計(jì)算表達(dá)式為:

圖2 邊特征融合模塊示意圖
(1)

進(jìn)行歸一化的原因是為了減少過擬合,加快模型收斂,此處采用GCN(Graph Convolutional Neural Network)中的歸一化方法,保證了圖的對(duì)稱性,并且有利于神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)。
最后進(jìn)行相加操作,相較于拼接操作和乘法操作的特征融合,采用加法融合使維度更低,減少運(yùn)算量,降低模型復(fù)雜度,同時(shí)預(yù)測(cè)精度并沒有降低。
在圖結(jié)構(gòu)中行人交互關(guān)系如圖3所示。GAT是GNN模型中重要的一種變體,多頭注意力機(jī)制能進(jìn)一步提升注意力層的表達(dá)能力,如圖4所示[11]。

圖3 行人交互關(guān)系示意圖

圖4 多頭注意力機(jī)制示意圖
首先計(jì)算鄰居節(jié)點(diǎn)及自身的注意力系數(shù),然而BRODY等[8]發(fā)現(xiàn)GAT的注意力有很強(qiáng)的局限性,在GAT中每個(gè)節(jié)點(diǎn)都只關(guān)心鄰居節(jié)點(diǎn),BRODY認(rèn)為不論自身節(jié)點(diǎn)特征怎么變,得到的注意力權(quán)重的計(jì)算結(jié)果都是相同的,將這種注意力稱為靜態(tài)注意力。BRODY通過改進(jìn)GAT提出了GATv2,由公式(2)演變到公式(3)。經(jīng)過實(shí)驗(yàn)可知,采用兩個(gè)頭的實(shí)驗(yàn)結(jié)果比原有GAT四個(gè)頭甚至更多頭的實(shí)驗(yàn)效果好,降低了模型復(fù)雜度,減少了運(yùn)算量,最后通過公式(4)對(duì)注意力系數(shù)進(jìn)行加權(quán)求和,得到新的節(jié)點(diǎn)特征向量。
(2)
其中,‖表示拼接操作,wgat表示共享權(quán)重矩陣,hi,t表示節(jié)點(diǎn)的特征,aT是一個(gè)注意力核函數(shù),目的是進(jìn)行維度映射。L為激活函數(shù),αij,t為t時(shí)刻權(quán)重系數(shù)。
(3)
與式(2)不同,式(3)先進(jìn)行特征拼接[hi,t‖hj,t],然后再與權(quán)重矩陣wgat相乘。經(jīng)過激活函數(shù)L,再與注意力核函數(shù)相乘。
對(duì)得到的注意力系數(shù)進(jìn)行加權(quán)求和,得到新特征向量,為加強(qiáng)GAT學(xué)習(xí)能力,采用多頭注意力機(jī)制,如公式(4)所示。
(4)


(5)

將節(jié)點(diǎn)特征輸入長(zhǎng)短期記憶網(wǎng)絡(luò)模型,學(xué)習(xí)行人運(yùn)動(dòng)的隱藏信息,數(shù)學(xué)表達(dá)式如下:
mi,t=T(mi,t-1,ei,t;Wm),
(6)
其中,Wm為L(zhǎng)STM權(quán)重系數(shù),mi,t為t-1時(shí)刻運(yùn)動(dòng)隱藏信息,ei,t為節(jié)點(diǎn)特征,T表示循環(huán)神經(jīng)網(wǎng)絡(luò)。
將運(yùn)動(dòng)隱藏特征和邊特征融合模塊進(jìn)行相加操作:
Mi,t=mi,t+H,
(7)
其中,mi,t表示運(yùn)動(dòng)隱藏信息,H表示邊特征提取結(jié)果,來自公式(1),此融合操作對(duì)應(yīng)圖1中的C模塊。
學(xué)習(xí)行人空間交互信息可表示為:
gi,t=T(gi,t-1,Gi,t;W),
(8)
其中,Gi,t為節(jié)點(diǎn)經(jīng)過多頭圖注意力機(jī)制的向量,gi,t-1為歷史空間交互向量,W為長(zhǎng)短期記憶網(wǎng)絡(luò)的權(quán)重。
將時(shí)間空間信息進(jìn)行融合:
hi,t=σ(Mi,t)‖σ(gi,t),
(9)
其中,σ(·)為不同的多層感知器,目的是在相同維度進(jìn)行拼接。
最后將結(jié)果拼接高斯噪聲,提升模型的魯棒性和泛化能力。
Ti,Tobs=hi,t‖z,
(10)
其中,z表示高斯噪聲,hi,t為由式(9)學(xué)習(xí)到的時(shí)空信息的融合結(jié)果,Ti,Tobs為融合后的特征向量。
軌跡預(yù)測(cè)是通過行人相對(duì)位置8個(gè)步長(zhǎng)信息,結(jié)合學(xué)習(xí)到的時(shí)空交互信息,預(yù)測(cè)下一個(gè)步長(zhǎng)位置信息。然后將預(yù)測(cè)位置信息結(jié)合前7個(gè)步長(zhǎng)進(jìn)行新的預(yù)測(cè),以此類推預(yù)測(cè)12個(gè)步長(zhǎng)信息。預(yù)測(cè)的相對(duì)位置信息由公式(12)得到。
di.Tobs+1=T(di,Tobs,ei,Tobs;W),
(11)
(xi,t+1,yi,t+1)=M(di,Tobs+1,Wσ),
(12)
其中,ei,Tobs表示相對(duì)位置特征向量,di,Tobs為上一步長(zhǎng)隱藏信息,di,Tobs+1為預(yù)測(cè)下一步長(zhǎng)信息,W為L(zhǎng)STM的權(quán)重,Wσ為MLP權(quán)重,經(jīng)過嵌入函數(shù)M,得到下一相對(duì)位置坐標(biāo)。
采用ETH[12]和UCY[13]兩個(gè)公共數(shù)據(jù)集進(jìn)行模型評(píng)估。數(shù)據(jù)集由實(shí)際生活中行人豐富的交互信息組成,其中ETH數(shù)據(jù)集包含兩個(gè)場(chǎng)景:UNIV和HOTEL。UCY數(shù)據(jù)集包含三個(gè)場(chǎng)景:ZARA01、ZARA02和UNIV。數(shù)據(jù)集將行人真實(shí)位置轉(zhuǎn)換為世界坐標(biāo)系下的位置信息。實(shí)驗(yàn)采用留一法,使用其中4個(gè)數(shù)據(jù)集進(jìn)行訓(xùn)練和驗(yàn)證,在剩余1個(gè)數(shù)據(jù)集進(jìn)行測(cè)試。
超參數(shù)設(shè)置:使用Adam優(yōu)化器進(jìn)行參數(shù)優(yōu)化,學(xué)習(xí)率設(shè)為0.001,批處理大小設(shè)為64,訓(xùn)練輪數(shù)為600。實(shí)驗(yàn)環(huán)境:操作系統(tǒng)為Ubuntu 20.04,處理器顯卡型號(hào)為2080Ti,PyTorch版本為1.2,Cuda版本為11.3,所有實(shí)驗(yàn)都是在相同的硬件環(huán)境下進(jìn)行。
實(shí)驗(yàn)時(shí),在模型中輸入8個(gè)時(shí)間步長(zhǎng)(3.2 s)的行人真實(shí)軌跡,輸出預(yù)測(cè)的未來12個(gè)時(shí)間步長(zhǎng)(4.8 s)的行人軌跡。與之前的研究相同,本文使用兩個(gè)指標(biāo)來評(píng)估預(yù)測(cè)誤差。
3.2.1 平均位置誤差
平均位置誤差(Average Displacement Error,ADE)為每一時(shí)間步的預(yù)測(cè)坐標(biāo)與真實(shí)坐標(biāo)之間的均方誤差,計(jì)算公式如下:
(13)

3.2.2 最終位置誤差
最終位置誤差(Final Displacement Error,FDE)為在預(yù)測(cè)的最后一個(gè)時(shí)間步T,預(yù)測(cè)坐標(biāo)與真實(shí)坐標(biāo)的誤差。計(jì)算公式如下:
(14)

為驗(yàn)證所提出的邊特征融合模塊和GATv2的有效性,采取調(diào)整算法模塊的方法。并在公開數(shù)據(jù)集UCY和ETH上對(duì)ADE和FDE兩個(gè)指標(biāo)進(jìn)行對(duì)比,如表1和表2所示,其中加粗黑體為最好的預(yù)測(cè)結(jié)果。算法1為基礎(chǔ)算法,只采用GAT進(jìn)行特征提取,不添加任何改進(jìn);算法2在基礎(chǔ)算法上采用邊特征融合模塊;算法3在基礎(chǔ)算法上將GAT模塊替換成GATv2模塊;算法4在基礎(chǔ)算法上同時(shí)加入邊特征融合模塊和GATv2模塊。

表1 本文算法消融實(shí)驗(yàn)ADE指標(biāo)

表2 本文算法消融實(shí)驗(yàn)FDE指標(biāo)
由表1和表2可以看出,當(dāng)基礎(chǔ)算法增加邊特征后在五個(gè)數(shù)據(jù)集下比較兩個(gè)指標(biāo),若有一定的提升,或者不變,則表明該模型能有效地利用邊的信息。在基礎(chǔ)算法上替換GATv2模塊,除了ZARA1的ADE略有下降,FDE不變,其他指標(biāo)都有提升,且單獨(dú)使用GATv2模塊的效果要好于邊特征融合,提升效果明顯。最后融合算法2和算法3后,兩種評(píng)估指標(biāo)在ETH、HOTEL、UNIV上都有很好的提升,在ZARA1上的效果和邊特征融合效果一樣,在ZARA2上指標(biāo)略有下降。分析原因是ZARA1、ZARA2數(shù)據(jù)集中行人密度小,行人交互性不強(qiáng),增加太多交互關(guān)系,產(chǎn)生了一定的過擬合問題。實(shí)驗(yàn)結(jié)果證明,邊特征融合和GATv2模塊的應(yīng)用均可提升預(yù)測(cè)精度。
為了評(píng)估STGEF-GAT算法的性能,本文選取了七種算法(Linear、LSTM、S-LSTM、S-GAN、Sophie[14]、STGAT、STGEF-GATv2)進(jìn)行ADE和FDE的對(duì)比,如表3和表4所示。所有算法的行人觀測(cè)時(shí)間為3.2 s,行人預(yù)測(cè)時(shí)間為4.8 s。表中黑體為最好預(yù)測(cè)結(jié)果。

表3 本文算法與其他算法的ADE指標(biāo)比較
表3和表4的實(shí)驗(yàn)結(jié)果證明了邊特征融合和GATv2模塊的有效性,在五個(gè)不同的數(shù)據(jù)集上有較好的表現(xiàn),除ZARA1數(shù)據(jù)集的ADE和FDE指標(biāo)低于Sophine模型,其余四個(gè)數(shù)據(jù)集的ADE和FDE指標(biāo)都要高于所比較的算法,最終平均值全部高于所比較的算法預(yù)測(cè),達(dá)到了提高模型精度的目的。
為了更好地展示模型的預(yù)測(cè)效果,對(duì)模型進(jìn)行軌跡可視化展示和權(quán)重分配可視化展示,以下兩種有效的可視化展示模塊來自于文獻(xiàn)[7]。在ZARA1數(shù)據(jù)集四個(gè)不同場(chǎng)景下的可視化軌跡預(yù)測(cè)如圖5所示,其中實(shí)線代表觀測(cè)軌跡,虛線代表預(yù)測(cè)軌跡,與虛線相近的實(shí)線表示真實(shí)軌跡。由圖5可以很好地看出STGEF-GATv2模型的有效性。

圖5 不同場(chǎng)景的可視化圖
從圖5的場(chǎng)景3可以看出,同向行走的預(yù)測(cè)軌跡和真實(shí)軌跡基本相符。逆向行走對(duì)行人預(yù)測(cè)軌跡影響較大,導(dǎo)致預(yù)測(cè)位置發(fā)生偏移,且距離位置越近,影響越大,如場(chǎng)景1和場(chǎng)景2相反行走的行人,當(dāng)兩人發(fā)生交互時(shí)對(duì)預(yù)測(cè)同樣產(chǎn)生較大的影響,如場(chǎng)景4。分析原因是當(dāng)兩人交互時(shí),行人之間交互影響增強(qiáng),從而導(dǎo)致預(yù)測(cè)位置發(fā)生偏移。綜上所述,預(yù)測(cè)可視化很好地展示了模型預(yù)測(cè)的精度。
為了更好體現(xiàn)模型是否與周圍行人產(chǎn)生聯(lián)系,對(duì)目標(biāo)行人為周圍行人分配的注意力權(quán)重進(jìn)行可視化分析,如圖6所示。軌跡上的黑點(diǎn)表示不同時(shí)間的步長(zhǎng),箭頭代表行人前進(jìn)的方向,沒有圓圈代表目標(biāo)行人的軌跡,圓圈的大小和權(quán)重分配成正比,權(quán)重越大圓圈越大。

圖6 權(quán)重分配可視化圖
在圖6的場(chǎng)景1和場(chǎng)景2中,同向行走的行人相較于逆向行駛的人會(huì)分配得到更低的權(quán)重,表明與預(yù)測(cè)行人同向行駛的人有較小的影響力,反之有較大的影響力;在場(chǎng)景3和場(chǎng)景4中,同向后方行走人員相較于同向前方行走人員對(duì)目標(biāo)行人影響更小,分配得到更少的權(quán)重。不足的是對(duì)于靜止不動(dòng)的行人(場(chǎng)景2),始終分配最大的權(quán)重。綜上所述,該算法可以有效分配目標(biāo)行人和周圍行人的權(quán)重信息,充分表明了該模型的有效性。
行人邊特征提取模塊能有效解決圖注意網(wǎng)絡(luò)邊信息的缺失問題,此模塊為圖自注意力層提供更多可學(xué)習(xí)信息。采用GAT的變體GATv2,可以用更少的頭和層數(shù)達(dá)到更好的預(yù)測(cè)效果,降低了模型復(fù)雜度,同時(shí)也提升了模型的預(yù)測(cè)準(zhǔn)確度。將預(yù)測(cè)結(jié)果進(jìn)行可視化研究分析,可以更直觀地感受該模型軌跡預(yù)測(cè)效果。為驗(yàn)證模型的有效性,在ETH和UCY兩個(gè)數(shù)據(jù)集上進(jìn)行實(shí)驗(yàn),結(jié)果顯示,平均位置誤差和最終位置誤差兩個(gè)指標(biāo)都有所提升。