







摘 要:離散序列生成廣泛應用于文本生成、序列推薦等領域。目前的研究工作主要集中在提高序列生成的準確性,卻忽略了生成的多樣性。針對該現象,提出了一種自適應序列生成方法ECoT,設置兩層元控制器,在數據層面,使用元控制器實現自適應可學習采樣,自動平衡真實數據與生成數據分布得到混合數據分布;在模型層面,添加多樣性約束項,并使用元控制器自適應學習最優(yōu)更新梯度,提升生成模型生成多樣性。此外,進一步提出融合協同訓練和對抗學習的方法,提升生成模型生成準確性。與目前的主流模型進行對比實驗,結果表明,在生成準確性和多樣性上,自適應協同訓練序列生成方法具有更均衡的準確性和多樣性,同時有效緩解了生成模型的模式崩潰問題。
關鍵詞:深度學習;機器學習;序列生成;協同訓練;對抗學習
中圖分類號:TP391 文獻標志碼:A
文章編號:1001-3695(2022)07-025-2081-06
doi:10.19734/j.issn.1001-3695.2021.12.0681
基金項目:國家社會科學基金重大資助項目(13amp;ZD091,18ZDA200);河北省重點研發(fā)計劃項目(20370301D);河北師范大學重大關鍵技術攻關項目(L2020K01)
作者簡介:張寶奇(1996-),男,河北承德人,碩士研究生,主要研究方向為機器學習、智能信息處理;趙書良(1967-),男(通信作者),河北滄州人,教授,博導,主要研究方向為機器學習、智能信息處理(zhaoshuliang@sina.com);張劍(1991-),男,河北石家莊人,碩士研究生,主要研究方向為機器學習、智能信息處理;呂曉鋒(1996-),男,河北石家莊人,碩士研究生,主要研究方向為機器學習、智能信息處理.
Sequence generation method based on adaptive learning
Zhang Baoqia,b,c,Zhao Shulianga,b,c?,Zhang Jiana,b,c,Lyu Xiaofenga,b,c
(a.College of Computer amp; Cyber Security,b.Hebei Provincial Engineering Research Center for Supply Chain Big Data Analytics amp; Data Security,c.Hebei Provincial Key Laboratory of Network amp; Information Security,Hebei Normal University,Shijiazhuang 050024,China)
Abstract:Discrete sequence generation is widely used in text generation,sequence recommendation and other fields.The current research work mainly focuses on improving the accuracy of sequence generation,but ignores the diversity of generation.To address this phenomenon,this paper proposed an adaptive sequence generation method (ECoT),and designed a two-layer meta controller.In the data layer,the function of meta controller was to realize adaptive learning sampling,automatically balance the distribution of real data and generated data,and obtain mixed data distribution.At the model level,this paper added diversity constraints.The function of the meta controller was to adaptively learn the optimal update gradient to improve the generation diversity of the generation model.In addition,in order to improve the accuracy of the generation model,this paper proposed a method combining cooperative training and adversarial learning.Compared with the current mainstream models,the results show that the adaptive cooperative training sequence generation method has more balanced accuracy and diversity in terms of generation accuracy and diversity,and can effectively alleviate the pattern collapse of the generation model.
Key words:deep learning;machine learning;sequence generation;cooperative training;adversarial learning
0 引言
序列生成模型廣泛應用于自然語言生成(natural language generation,NLG)任務[1~6]、推薦系統(recommendation system,RS)[7~10]等諸多領域。基于極大似然估計(maximum likelihood estimation,MLE)的神經網絡模型是序列生成的基本方法,繼神經網絡反向傳播算法出現之后,基于極大似然估計的前饋神經網絡和循環(huán)神經網絡的序列生成,雖然可以得到與訓練集樣本相似的序列數據,但基于極大似然估計方法生成的樣本在質量上良莠不齊。首先,在模型訓練階段,生成序列每一步所需的序列項均是來自于訓練集,測試階段輸入序列項均是由生成模型生成而來,使得訓練階段未曾發(fā)現的錯誤學習結果在測試階段快速積累而導致曝光偏差 (exposure bias)問題[11]。其次,序列生成過程中使用的數據均為離散數據,且每次只生成一個序列項,訓練過程選用交叉熵函數作為模型損失函數,BLEU等離散型評價指標對模型進行評價,結果表明,各種離散型指標的不可微性最終導致了損失函數優(yōu)化難的問題。隨后,生成對抗網絡(generative adversarial networks,GAN)在連續(xù)數據生成領域展現了強大的性能,并擁有可靠的理論基礎,逐漸成為流行的生成模型之一[2]。該模型基于零和博弈的思想,生成器生成偽數據去混淆判別器,使得判別器無法判斷輸入數據是否來自于訓練集;判別器為了能夠正確判別數據真?zhèn)芜M行不斷訓練,然后判別器通過訓練將學習到的知識通過梯度的形式傳遞給生成器,引導生成器進行訓練。但該模型在序列離散數據生成上仍具有目標函數不可微、優(yōu)化難等問題。
現階段,基于上述問題提出了兩類解決方法,即強化學習和策略梯度方法、改進的生成對抗網絡方法。在第一類強化學習和策略梯度方法中,最為經典的是序列生成對抗網絡模型(sequence generative adversarial networks,SeqGAN)[12]。該模型使用策略梯度算法優(yōu)化目標,在很多場景下都有出色的表現,但該模型在不同場景下生成結果具有不穩(wěn)定、不可靠的缺點。第二類方法中,主要使用Gumbel-softmax[13],得到一個近似連續(xù)分布的序列離散數據分布,使得模型在訓練過程中的目標函數是可微的,該方法極大地增強了訓練過程的穩(wěn)定性[13]。上述各種生成模型算法中,其更多地注重對模型生成準確度的改進,卻忽視了模型生成的多樣性。為此,提出一種協同訓練方法,即離散數據生成的協同訓練(cooperative training,CoT)[14]。CoT模型通過直接優(yōu)化Jensen-Shannon散度來進行針對離散數據生成模型的訓練,并且改進了優(yōu)化過程,引入了中間模型,在一定程度上提升了模型生成數據的多樣性,但是仍存在優(yōu)化空間。
大多數研究者的工作主要集中在連續(xù)數據上,而對于生成序列離散數據的研究較少,并且生成式對抗網絡在序列生成過程中忽略了生成的多樣性,為此本文提出了自適應序列生成方法ECoT,使用兩層元控制器。在數據層面平衡真實數據以及生成數據分布,第一層元控制器調節(jié)真實樣本分布和生成樣本分布的混合程度,得到混合數據分布。在模型層面優(yōu)化模型訓練的過程,第二層元控制器調節(jié)生成器更新梯度,進而尋找樣本生成質量和多樣性之間的最佳平衡,在保證其生成樣本準確性的同時最大限度地提高生成的多樣性。
1 相關工作
1.1 符號定義
本文中所用到的符號定義如表1所示。
1.2 序列生成對抗網絡
2014年,Goodfellow等人[15]將零和博弈的對抗學習思想與深度學習相結合,提出了生成對抗網絡。在生成對抗網絡中,生成器Gθ接收真實樣本數據來生成序列,其本質是對序列進行特征提取,然后根據提取到的特征學習樣本真實分布來混淆判別器Dψ。而判別器Dψ的目標則是能夠完全區(qū)分數據的真?zhèn)巍.斉袆e器Dψ達到無法準確區(qū)分真實的序列樣本和生成的序列樣本狀態(tài),便是生成對抗網絡的理想狀態(tài),這種狀態(tài)被稱之為納什平衡狀態(tài)[15]。生成對抗網絡的學習過程實際上是尋找極大極小值問題,生成對抗網絡的目標函數如式(1)所示。
對式(1)進行一步的推導可以觀察到,生成對抗網絡中生成器Gθ的訓練目標實際上是求解Jensen-Shannon散度(JS散度)的最小值。JS散度的定義如式(2)所示。
其中:M12(Pdata+G),系數12為生成器和真實數據分布的影響力權重。
為了解決序列數據的離散且不可微問題,序列對抗生成模型SeqGAN引入了強化學習的思想。該模型將強化學習中的隨機策略運用到生成對抗網絡中,通過策略梯度更新模型的參數來解決離散數據生成過程中的不可微問題。在SeqGAN中,將生成器模型的目標函數進行了改進,具體如式(3)所示。
其中:s表示生成器所生成的一個完整序列;Qt(st,xt)是狀態(tài)—動作值函數,表示從狀態(tài)st開始,采取行動xt的累計獎勵。SeqGAN將判別器對生成器生成序列的評分作為強化學習中的獎勵。為了解決訓練過程中存在的曝光偏差問題,在生成器訓練過程中,使用蒙特卡羅搜索和roll-out策略對未確定的T-t個序列項進行采樣,來評估中間狀態(tài)的動作值。狀態(tài)—動作值函數定義如式(4)所示。
SeqGAN模型的判別器,其目標函數如式(5)所示。
由于SeqGAN對于強化學習思想的依賴,導致了生成對抗網絡中的模式崩潰問題更為嚴重。換而言之,SeqGAN雖然獲得了與訓練集真實數據概率分布具有較高相似度的生成概率分布,但是SeqGAN的高準確度是以犧牲生成數據的多樣性作為代價。
1.3 協同訓練生成模型
SeqGAN模型中,生成器在對抗學習過程中,隨著訓練次數的增多,其生成樣本多樣性逐漸減小。針對SeqGAN模型存在生成數據缺少多樣性的問題,CoT模型引入了中間協同訓練模型M來引導協助生成器的訓練。該模型的最終目標函數如式(6)所示。該模型將生成對抗網絡中的最大最小化目標函數轉換成了完全的最大化問題。
2 自適應序列生成方法
在生成對抗學習中引入強化學習方法,盡管在生成準確度上有了明顯的效果提升,但是生成對抗網絡在對抗訓練階段時常會出現模式崩潰現象。在CoT模型提出后,模式崩潰問題得到了緩解。本章將通過建立元學習任務來進一步緩解模式崩潰問題,尋找生成器生成高質量序列和多樣性序列之間的最佳平衡,在保證準確率的同時,提升生成樣本多樣性。為此,本文分別在中間輔助生成模型訓練階段和生成器訓練階段中設置M-元控制器和G-元控制器兩個元控制器。其中,M-元控制器在數據層面控制輸入到中間輔助生成器M的生成樣本和真實樣本與的混合程度,G-元控制器在模型層面控制生成器訓練過程梯度的更新。ECoT模型的整體框架如圖1所示。
2.1 M-元控制器
中間輔助生成模型的訓練階段,中間輔助生成模型由生成模型和真實樣本分布的混合概率分布函數組成。其中,12(Pdata+Gθ)表示混合數據分布M*,將生成器和訓練集的采樣結果作為中間輔助生成器的輸入,這種方式在一定程度上緩解了曝光偏差問題[11,13]。CoT模型中,中間輔助生成器的目標函數如式(7)所示。
其中:使用來自生成器的樣本計算生成模型與中間輔助生成模型的KL散度值,使用來自訓練集的樣本計算真實數據分布Pdata與中間輔助生成模型M?的KL散度值,最后計算兩個KL散度平均值,式(7)實質上是混合數據分布和中間輔助生成器對應分布的KL散度。模型訓練過程中,中間輔助生成模型擬合的分布函數向真實數據分布靠攏,同時使用生成器來控制中間輔助生成模型每次向真實數據分布靠攏的程度。式(7)中將生成器和真實數據分布視為同等影響力的設置,對生成多樣性的提高具有局限性。針對該問題,將學習動態(tài)權重設置為第一元學習任務,也是M-元控制器的主要內容。
在本文中,從數據采樣的角度分析中間輔助生成器對生成器生成多樣性的影響。在多數生成對抗網絡模型中對真實數據、生成數據均采用等比例分層采樣的方法,采樣出等量的真實樣本和生成樣本,這種做法存在兩個問題:a)在訓練初期,生成器并不具備良好的生成效果,所以生成的數據對中間輔助器的訓練協同互助的作用有限;b)在訓練后期,生成器生成與真實數據十分相似的數據,該期間中間輔助生成器需要增加生成的多樣性,但是由于生成樣本與真實樣本的等比例分層采樣出的輸入數據無法為生成器提供更多生成多樣性學習的引導。對于上述兩個問題,本文利用M-元控制器學習長尾分布,控制中間輔助生成器輸入數據的混合程度,有側重地為中間輔助生成器提供訓練樣本,過程如圖2所示。M-元控制器的實現如式(8)所示,中間輔助生成器的目標函數如式(9)所示。
其中:λM值由式(10)所得。
M-元控制器在數據層面動態(tài)調控真實樣本分布和生成樣本分布對中間輔助生成器模型訓練的影響。在M-元控制器的調節(jié)下,真實數據分布以及生成數據分布動態(tài)引導中間輔助生成模型訓練,為中間輔助生成器多樣性的提高隱式地提供了方向。式(10)的MLP選擇sigmoid作為輸出單元激活函數。
2.2 G-元控制器
中間輔助生成器訓練后轉而進行生成器的訓練,現有的方法均是通過最小化JS散度作為生成器目標函數,具體如式(11)所示。
最小化JS散度實際上是通過中間輔助生成器間接引導生成器Gθ訓練,具體過程如圖3(a)~(d)所示。假設真實數據服從高斯分布的概率密度函數擬合,其中“--”為
ECoT模型生成器的概率密度函數
,“··”為真實數據概率密度函數,“__”為中間輔助生成器的概率密度函數。中間輔助生成器指導訓練的過程中隱式地提升生成器的生成多樣性,但是這種隱式提升生成多樣性的方法仍存在局限性,對多樣性的提高具有不確定性。為了改進這種情況,同時緩解對抗模型中常見的模式崩潰問題,本文添加了多樣性約束項,該約束項的目的是提供可顯式優(yōu)化損失部分,在目標函數中體現生成多樣性的訓練,該約束項指導生成器逼近均勻分布,進而提升生成的多樣性,在生成器的訓練過程中添加更多的隨機性。
至此,生成器的目標函數已經具備可以顯式優(yōu)化生成準確性和多樣性的能力。但是準確性目標項和多樣性目標項彼此是相互對立、相互競爭的關系,在訓練過程中難以達到預期的效果,甚至會導致目標函數難以收斂。為了解決該問題,設置G-元控制器顯式地控制生成器生成準確性以及多樣性的訓練與學習,生成器的最終目標函數如式(12)所示。
其中: U表示均勻分布U(0,SG);SG表示訓練集樣本數量;Cs表示訓練數據集中樣本種類數;λG由式(13)所得。
其中:D(·)表示判別器的判別結果;ξ表示判別器質量系數。
G-元控制器控制目標函數中準確性目標項和多樣性目標項對生成器訓練的影響,進而指導訓練的方向。接下來,本文從梯度更新角度對多樣性訓練進行分析,設FA表示目標函數中的準確性目標項,FD表示目標函數中多樣性目標項。在訓練初期,隨機初始化后,準確性目標項的值偏低,多樣性占據優(yōu)勢,G-元控制器在該階段提高準確性目標項的值,即增加FA的模長,減少FD的模長,如圖4(a)所示;而且在模型的最低點附近,準確性目標項展現優(yōu)勢,G-元控制器減少FA的模長,增加FD的模長,如圖4(b)所示。通過對梯度下降方向的控制,G-元控制器在模型層面尋找最佳更新梯度。尋找準確性和多樣性的最優(yōu)解,以緩解模式崩潰問題。
為了進一步保證生成模型的準確性,防止由于過分追求生成多樣性導致的生成準確性丟失,本文設置對抗判別器Dψ。對抗判別器主要用于輔助生成器矯正準確性訓練方向和為G-元控制器提供先驗知識輸入。判別器的目標函數如式(14)所示,判別式的輸出表示判別樣本為訓練集真實樣本的概率值。
在對抗學習過程,生成器選用交叉熵作為損失函數,專注于準確性的提升。在對抗學習結束后,判別器的輸出和判別器的交叉熵損失值組合作為G-元控制器輸入。判別器輸出對輸入樣本的判別概率,由于判別器處于動態(tài)訓練中,所以需要計算判別器的判別質量,稱為判別器質量系數,如式(15)所示,判別質量系數與判別器對生成樣本的判別輸出向量進行元素相乘,作為G-元控制器部分輸入。
其中:loss表示判別器的交叉熵損失值。
使用M-元控制器和G-元控制器對生成器Gθ和中間輔助生成器M?的學習過程進行兩層控制調節(jié),在數據和模型兩個層面進行學習,元控制器優(yōu)化生成器Gθ和中間輔助生成器M?向真實分布學習的過程,使模型能夠在保證生成樣本準確度的同時,提升模型生成樣本的多樣性,同時,增加對抗學習器以保證生成器準確率的穩(wěn)定訓練。算法1中給出了自適應序列生成方法的完整算法過程。
算法1 自適應序列生成方法
輸入: Gθ,Dψ,M?;從真實樣本分布Pdata采樣樣本。
輸出:生成器Gθ。
initialize Gθ,Dψ,M?,λM,λG with random weights θ,ψ,?,ωM,ωG
pretrain Gθ with samples from Pdata
while not done do
for Nm steps do
sample sg from Gθ and" Sp from Pdata
compute loss of" M? and λM
update ?,ωM
end for
samples S from Gθ
compute loss of Gθ and λG
update θ,ωG
end while
3 實驗
本章使用簡寫ECoT表示自適應序列生成方法。首先,本文分別在SeqGAN[12]和Meta-CoTGAN[1]中引入的合成離散序列生成數據集以及常用的文本生成數據集COCO圖像字幕數據集和EMNLP2017 WMT News數據集上進行對比實驗。在三個數據集的實驗中,均使用TensorFlow深度學習框架和Texygen框架進行訓練與模型的評估[16],在本文實驗環(huán)節(jié)中,將生成對抗網絡中的對抗學習過程和協同訓練的過程相結合,設置生成器Gθ、判別器Dψ、中間輔助生成模型M?,設置判別器的目標函數為JD(ψ)。
3.1 實驗評價指標
在實驗中,從生成樣本的準確性以及生成樣本的多樣性兩個方面對生成數據進行評估。現有的大部分文本序列生成工作使用BLEU分數度量生成器在真實數據集上學習后所生成的樣本質量,而在合成數據集上對模型進行評估時,通過計算
值來評估生成樣本準確性,使用NLLtest評估生成樣本多樣性[1]。設計圖靈測試任務,依據專家網絡提供的先驗知識計算負對數似然函數NLLoracle,其具體計算如式(16)所示,當NLLoracle越小時,表示生成模型的準確度越高。
NLLtest指的是從專家網絡額外抽取樣本,計算生成器的負對數似然,NLLtest是用于評估模型擬合真實測試數據能力的簡單指標[16],通過評估生成模型在真實數據概率密度上的覆蓋范圍來評估生成樣本的多樣性以及生成模型的抗模式崩潰能力。如果模型在真實數據空間中具有更廣泛的覆蓋范圍,則生成的樣本將具有更好的多樣性,對應的損失會更低。相反地,如果模型存在嚴重的模式崩潰問題,那么模型在真實數據空間中覆蓋范圍較小,模型將不能很好地代表真實數據,并且會得到較高的損失[1]。其具體計算如式(17)所示。
BLEU最初被應用于機器翻譯系統,是用來衡量機器翻譯的結果和人工翻譯的差異的指標。假設輸入標準的人工翻譯結果,生成器模型生成相應的翻譯結果,將句子長度設為n,在生成器模型生成的翻譯結果中存在m個單詞是在標準的人工翻譯結果中重復出現的,那么稱得到的m/n就是BLEU的1-gram值。通過這種計算方法來衡量生成結果的準確性。根據k-gram中k的取值不同,BLEU-k的結果也有所不同,更高階的BLEU衡量生成序列的準確性,同時衡量生成序列的流暢性。
3.2 實驗數據與結果分析
3.2.1 專家網絡合成數據集
本文使用已訓練好的長短期記憶網絡(long short-term memory,LSTM)模型作為專家網絡,其生成的數據作為訓練集,專家網絡不僅為模型的訓練提供訓練數據集,還提供訓練樣本的先驗知識。在實驗部分,優(yōu)先在專家網絡上合成的數據集進行模型評價,在SeqGAN模型實驗中首次使用專家網絡合成的數據集。實驗中設置專家網絡模擬現實世界中的序列數據,生成長度為20并且序列項總數為5 000的訓練數據,總共生成10 000個序列用做訓練。訓練前,對所有生成器的參數進行初始化,且參數使用正態(tài)分布初始化器初始化,且在所有的生成器預訓練階段均選擇極大似然估計作為預訓練過程[7],選擇LSTM作為生成器與判別器模型,Meta-CoTGAN溫度設置為1 000。在預訓練階段中,本文首先訓練生成器80個epoch,然后訓練判別器80個epoch,隨后進入到對抗訓練階段。每次對抗階段,本文更新一次生成器后,判別器進行15次小批量梯度更新[7]。此外,在LeakGAN模型的訓練過程中,每10次對抗訓練后,生成器和判別器將進行5次極大似然估計訓練[7]。本次實驗除了將ECoT與CoT及Meta-CotGAN模型等協同訓練模型進行效果對比實驗之外,還進行了同極大似然估計模型MLE以及融合了強化學習思想的序列生成對抗網絡SeqGAN及其變體(如MaliGAN[11]、RankGAN[17]、LeakGAN[18]模型等)實驗結果的對比展示。
在表2和圖5中,將ECoT與MLE、GAN及其強化學習變體、CoT模型及其變體在NLLoracle和NLLtest兩個指標上的結果進行比較展示。在圖5中可以清楚地發(fā)現,采用MLE作為基礎生成器的CoT模型和極大似然估計MLE模型生成樣本的準確率是相對最低的,NLLoracle值達到了9以上,而SeqGAN模型及其引入強化學習變體的模型方法,如RankGAN、MaliGAN、LeakGAN的NLLoracle值分別為8.74、8.40、8.91、8.57,整體取值在8.4~8.9浮動。對比NLLtest值可以發(fā)現,LeakGAN的NLLtest值最小為4.54,也就是說在SeqGAN模型引入強化學習變體的模型方法中,LeakGAN的準確率及多樣性是最高的。采用LeakGAN模型的架構作為基礎生成器架構的CoT-Strong模型,其NLLoracle值減小到了8.24,NLLtest減小到4.36。在梯度中建立元學習任務的Meta-CoTGAN的NLLoracle值為8.18,NLLtest值減小到了4.30。本文提出的ECoT模型的NLLoracle、NLLtest值分別為7.57、4.12,其中NLLtest值較Meta-CoTGAN降低了4.1%,而NLLoracle相比于Meta-CoTGAN更是降低了7.5%。由上述實驗數據可以看到,本文提出的ECoT模型在生成序列的準確度上得到了提升,甚至達到了最優(yōu),并且對比NLLtest值,實驗結果表明本文提出的ECoT模型在生成樣本的多樣性上具有更好的效果,抗模式崩潰能力更強。通過兩個方面的實驗對比可以發(fā)現,本文方法在生成的準確度以及生成樣本的多樣性上均達到了最優(yōu)。
3.2.2 COCO圖像字幕數據集和EMNLP2017 WMT News數據集
COCO(common objects in context)起源于微軟于2014年出資標注的Microsoft COCO數據集[19]。COCO數據集中的圖像分為訓練集、驗證集和測試集。COCO數據集是一個大型的、豐富的物體檢測、分割和字幕數據集。COCO數據集以對圖片中場景的理解作為其主要目標,工作的具體內容是從復雜的日常場景中截取圖像中的目標,通過精確的語義分割進行位置標定。圖片包括91種對象目標,328 000張圖片和2 500 000個標簽。目前為止是語義分割的最大數據集,提供的類別有80 類,超過33 萬張圖片,其中20 萬張有標注,整個數據集中個體的數目超過150 萬個。該數據集包含多組圖像描述對。本文將整個圖像數據集上的圖像標題作為要生成的文本,其中大多數句子約為10個單詞。因此,本文對數據集進行了一些預處理。COCO圖像字幕訓練數據集由20 734個單詞和417 126個句子組成。本文刪除頻率低于10的單詞以及包含它們的句子。經過預處理,最終數據集中包含了20 000個句子。EMNLP2017 WMT News數據集作為長文本語料庫進行實驗評估。EMNLP2017 WMT News是機器翻譯領域最重要的公開數據集,其數據規(guī)模較大,含有多種語言的文本,通常在百萬句到千萬句不等。這部分實驗中,首先從原始EMNLP2017 WMT News數據集中選擇部分新聞數據。選擇出來新聞數據集由6 459個單詞和397 726個句子組成。本文通過消除頻率低于4 050的單詞以及包含這些低頻單詞的句子對數據進行預處理。在本實驗中,刪除了長度小于20的句子,最終采樣出20 000個句子作為實驗數據集。
為了對比本文模型同生成對抗網絡與強化學習的組合在準確度和多樣性上的提升,驗證協同訓練方法比生成對抗網絡與強化學習組合更加適合序列離散數據的生成,在兩個數據集上進行實驗,使用BLEU對準確度進行評估,使用NLLtest評估生成樣本的多樣性。表3展示模型在COCO圖像字幕數據集以及EMNLP2017 WMT News數據集上的BLEU-2~BLEU-5評分,用于測量樣本質量的分數,
NLLtest值用于評估生成樣本的多樣性。為了更清楚地展示,使用柱狀圖進行實驗結果對比,如圖6所示,(a)表示在COCO圖像字幕數據集上的結果對比,(b)表示在EMNLP2017 WMT News數據集上的結果對比。在兩個真實數據集上的實驗結果表明,不論是在準確性上還是在多樣性上,協同訓練方法在離散數據的生成問題上比生成對抗網絡與強化學習組合更具優(yōu)勢。ECoT模型同協同訓練方法CoT模型相比,雖然在生成多樣性上與CoT接近,但是在生成準確性上,ECoT較CoT具有更優(yōu)的表現。實驗結果表明,ECoT具備尋找準確性和多樣性最優(yōu)組合的能力。
在表4中,將ECoT模型同MLE、CoT在COCO圖像字幕數據集訓練采樣結果進行對比。在表4中能夠直觀發(fā)現,本文提出的ECoT模型能夠生成表達更為人性化的語句,并且更具多樣性;而CoT模型雖然能夠生成表達多樣的語句序列,但是其生成的句子在正確性上略顯不足;MLE模型作為基礎的序列生成模型,在生成短句上具備更好的性能,但是在更為復雜的長句生成中,生成結果會出現難以意料的錯誤生成結果,并且MLE模型生成的語句較顯單一,并不具備較好的表達性。相比較而言,本文提出的ECoT模型在兼顧生成準確性的同時,能夠生成句式更加豐富,表達更為多樣性的語句序列。
4 結束語
本文提出了一種新的方法ECoT模型,以協調序列生成的準確性和多樣性,為提升生成樣本的多樣性提供了一種新思路。在協同訓練思想基礎之上,ECoT分別在中間輔助生成器的訓練階段引入了M-元控制器,在生成器的訓練階段引入了G-元控制器。通過對模型施加兩層控制,在訓練中尋找接近真實數據分布的同時,又能夠保持其生成樣本多樣性的平衡態(tài)模型,并且本文方法在面對模式崩潰問題時展現了有效性。實驗結果表明,本文方法在準確性和多樣性上優(yōu)于對比方法,同時在與LeakGAN等強化學習架構相結合時,本文方法能夠展現出更加優(yōu)秀的性能。在未來的研究中,將使用元學習方法與CoT深度結合,對比不同生成模型的生成效果,進一步提高生成序列的多樣性。
參考文獻:
[1]Yin Haiyan,Li Dingcheng,Li Xu,et al.Meta-CoTGAN:a meta coo-perative training paradigm for improving adversarial text generation[C]//Proc of AAAI Conference on Artificial Intelligence.Palo Alto,CA:AAAI Press,2020:9466-9473.
[2]Bahdanau D,Cho K H,Bengio Y.Neural machine translation by jointly learning to align and translate[C]//Proc of the 3rd International Conference on Learning Representations.2015:1-15.
[3]張涼,楊燕,陳成才,等.基于多視角對抗學習的開放域對話生成模型[J].計算機應用研究,2021,38(2):372-376.(Zhang Liang,Yang Yan,Chen Chengcai,et al.Open domain dialogue generation model based on multi-view adversarial learning[J].Application Research of Computers,2021,38(2):372-376.)
[4]LiuShuman,Chen Hongshen,Ren Zhaochun,et al.Knowledge diffusion for neural dialogue generation[C]//Proc of the 56th Annual Meeting of the Association for Computational Linguistics.Stroudsburg,PA:Association for Computational Linguistics,2018:1489-1498.
[5]Vaswani A,Shazeer N,Parmar N,et al.Attention is all you need[C]//Proc of the 31st International Conference on Neural Information Processing Systems.Red Hook,NY:Curran Associates Inc.,2017:5998-6008.
[6]Lin Junyang,Sun Xu,Ma Shuangming,et al.Global encoding for abstractive summarization[C]//Proc of the 56th Annual Meeting of the Association for Computational Linguistics.Stroudsburg,PA:Association for Computational Linguistics,2018:163-169.
[7]Wu Chaoyuan,Ahmed A,Beutel A,et al.Recurrent recommender networks[C]//Proc of the 10th ACM International Conference on Web Search and Data Mining.New York:ACM Press,2017:495-503.
[8]Tang Jiaxi,Belletti F,Jain S,et al.Towards neural mixture recommender for long range dependent user sequences[C]//Proc of the 28th International Conference on World Wide Web.New York:ACM Press,2019:1782-1793.
[9]Ying Haochao,Zhuang Fuzhen,Zhang Fu Zhang,et al.Sequential re-commender system based on hierarchical attention networks[C]//Proc of the 27th International Joint Conference on Artificial Intelligence.Palo Alto,CA:AAAI Press,2018:3926-3932.
[10]伍鑫,黃勃,方志軍,等.序列生成對抗網絡在推薦系統中的應用[J].計算機工程與應用,2020,56(23):175-179.(Wu Xin,Huang Bo,Fang Zhijun,et al.Application of sequence generative adversarial network in recommendation system[J].Computer Engineering and Applications,2020,56(23):175-179.)
[11]Che Tong,Li Yanren,Zhang Ruixiang,et al.Maximum-likelihood augmented discrete generative adversarial networks[EB/OL].(2017-02-26).https://arxiv .org/abs/1 702.07983.
[12]Yu Lantao,Zhang Weinan,Wang Jun,et al.SeqGAN:sequence gene-rative adversarial nets with policy gradient[C]//Proc of AAAI Confe-rence on Artificial Intelligence.Palo Alto,CA:AAAI Press,2017:2852-2858.
[13]Kusner M J,Hernández-Lobato J M.GANs for sequences of discrete elements with the Gumbel-softmax distribution[EB/OL].(2016-11-16).https://arxiv .org/abs/1611.04051v1.
[14]Lu Sidi,Yu Lantao,Feng Siyuan,et al.CoT:cooperative training for generative modeling of discrete data[C]//Proc of International Conference on Machine Learning.2019:4164-4172.
[15]Goodfellow I J,Pouget-Abadie J,Mirza M,et al.Generative adversarial nets[C]//Proc of the 27th International Conference on Neural Information Processing Systems.Cambridge,MA:MIT Press,2014:2672-2680.
[16]Zhu Yaoming,Lu Sidi,Zheng Lei,et al.Texygen:a benchmarking platform for text generation models[C]//Proc of the 41st International ACM SIGIR Conference on Research amp; Development in Information Retrieval.New York:ACM Press,2018:1097-1100.
[17]Lin K,Li Diang,He Xiaodong,et al.Adversarial ranking for language generation[C]//Proc of the 31st International Conference on Neural Information Processing Systems.Red Hook,NY:Curran Associates Inc.,2018:3158-3168.
[18]Guo Jiaxian,Lu Sidi,Cai Han,et al.Long text generation via adversa-rial training with leaked information[C]//Proc of AAAI Conference on Artificial Intelligence.Palo Alto,CA:AAAI Press,2018:5141-5148.
[19]Lin T Y,Maire M,Belongie S,et al.Microsoft COCO:common objects in context[C]//Proc of European Conference on Computer Vision.Cham:Springer,2014:740-755.