






收稿日期:2021-10-05;修回日期:2021-12-04" 基金項(xiàng)目:國家自然科學(xué)青年基金資助項(xiàng)目(41706198)
作者簡介:尹來國(1997-),男,山東濰坊人,碩士研究生,主要研究方向?yàn)樯疃葘W(xué)習(xí)、復(fù)雜網(wǎng)絡(luò);孫仁誠(1977-),男(通信作者),山東青島人,教授,碩導(dǎo),博士,主要研究方向?yàn)閿?shù)據(jù)挖掘、復(fù)雜網(wǎng)絡(luò)、深度學(xué)習(xí)等(qdsunstar@163.com);邵峰晶(1955-),女,山東青州人,教授,博導(dǎo),博士,主要研究方向?yàn)閿?shù)據(jù)挖掘、復(fù)雜網(wǎng)絡(luò)等;隋毅(1984-),女,副教授,碩導(dǎo),博士,主要研究方向?yàn)閺?fù)雜網(wǎng)絡(luò)、深度學(xué)習(xí)等;邢彤彤(1997-),女,山東青島人,碩士研究生,主要研究方向?yàn)樯疃葘W(xué)習(xí)等.
摘 要:生成對抗網(wǎng)絡(luò)已經(jīng)成為深度學(xué)習(xí)領(lǐng)域最熱門的研究方向之一,其最大的優(yōu)勢在于能夠以無監(jiān)督的方式來擬合一個(gè)未知的分布。目前,生成對抗網(wǎng)絡(luò)在圖像生成領(lǐng)域大放異彩,其能夠產(chǎn)生一些高質(zhì)量的圖像,但也暴露了一些弊端。在生成圖像的過程中,經(jīng)常會(huì)出現(xiàn)模式坍塌問題,從而導(dǎo)致生成的樣本過于單一。為了解決這個(gè)問題,對生成對抗網(wǎng)絡(luò)的模型結(jié)構(gòu)和損失函數(shù)加以改進(jìn),使判別器能夠從多個(gè)角度來度量生成數(shù)據(jù)的分布和真實(shí)數(shù)據(jù)的分布之間的差異,從而改善生成樣本的多樣性。通過在多個(gè)數(shù)據(jù)集上進(jìn)行實(shí)驗(yàn),結(jié)果顯示,提出的模型在很大程度上緩解了模式坍塌問題。
關(guān)鍵詞:生成對抗網(wǎng)絡(luò);圖像生成;模式坍塌
中圖分類號(hào):TP183"" 文獻(xiàn)標(biāo)志碼:A
文章編號(hào):1001-3695(2022)06-015-1689-05
doi:10.19734/j.issn.1001-3695.2021.10.0605
Multi-mode generative adversarial network
Yin Laiguo,Sun Rencheng,Shao Fengjing,Sui Yi,Xing Tongtong
(Dept. of Computer Science amp; Technology,Qingdao University,Qingdao Shandong 266071,China)
Abstract:Generative adversarial networks have become one of the most popular research directions in the field of deep lear-ning.Its main advantage is that it can fit unknown distribution in an unsupervised way.At present,the generative adversarial network is valuable in the field of image generation.It can generate some high-quality images,but it also exposes some disadvantages.In the process of image generation,the problem of mode collapse often occurs,which leads to the generated sample being too single.To solve this problem,this paper improved the model structure and loss function of the generative adversarial network,so that the discriminator could measure the difference of distribution between generated data and real data from many aspects,thus increasing the diversity of generated samples.Experiments on multiple data sets show that the proposed model alleviates the mode collapse to a large extent.
Key words:generative adversarial network;image generation;mode collapse
近年來,生成模型[1]在深度學(xué)習(xí)領(lǐng)域變得越來越熱門,而這些生成模型的本質(zhì)都是在擬合未知的真實(shí)數(shù)據(jù)的高維分布[2]。在音樂風(fēng)格遷移領(lǐng)域,生成模型可以輕易地將一種流派的音樂轉(zhuǎn)換為另一種流派的音樂[3];在人臉識(shí)別領(lǐng)域,人臉數(shù)據(jù)收集的困難性導(dǎo)致模型的精度難以提升,而生成模型可以提供大量的生成樣本來訓(xùn)練人臉識(shí)別模型[4],目前銀行等相關(guān)部門的基于生成樣本的人臉識(shí)別模型已經(jīng)達(dá)到了很高的精度。然而,由于真實(shí)數(shù)據(jù)分布的高維性和復(fù)雜性等特點(diǎn),生成模型仍然難以學(xué)習(xí)到真實(shí)數(shù)據(jù)的分布,更加可能的是學(xué)習(xí)到真實(shí)數(shù)據(jù)的部分分布[5],這極大地限制了模型的適用性。
在生成對抗網(wǎng)絡(luò)(generative adversarial network,GAN)[2]被提出之前,主流的生成模型包括變分自動(dòng)編碼器(variational autoencoder,VAE)[6]、自回歸模型(autoregressive model,AR),例如PixelRNN[7]等。VAE是在自動(dòng)編碼器的基礎(chǔ)上,添加了一些約束條件,使編碼器的輸出服從高斯分布。VAE能夠在具有潛變量的概率圖模型上學(xué)習(xí)到有效的貝葉斯推理。然而,VAE生成的樣本往往不夠清晰,很多情況下無法滿足環(huán)境的使用需求。PixelRNN只需要訓(xùn)練一個(gè)網(wǎng)絡(luò)就可以達(dá)到生成的目的,該網(wǎng)絡(luò)通過使用之前的像素信息來預(yù)測下一個(gè)像素值,這類似于將圖像的像素插入到Char-RNN[8]中。Char-RNN是在字符的維度上,通過觀測字符來預(yù)測下一個(gè)字符出現(xiàn)的概率。PixelRNN有一個(gè)簡單且穩(wěn)定的訓(xùn)練過程,并且生成的數(shù)據(jù)也比較合理。然而,該模型的采樣效率相對較低,并且不容易為圖像提供簡單的低維編碼。隨著GAN的興起,生成模型的發(fā)展迎來了新的機(jī)遇。GAN是由Goodfellow等人于2014年提出的,被廣泛地應(yīng)用于圖像領(lǐng)域,包括圖像生成[9~11]、圖像風(fēng)格遷移[12~15]、圖像的超分辨率[16~19]、圖像修復(fù)[20]等方面。在圖像生成領(lǐng)域,Gao等人[21]提出的ProGAN通過向生成器中添加層的方式生成了高質(zhì)量的圖像;Ledig等人[22]提出了SRGAN,通過利用低質(zhì)量圖像來生成高質(zhì)量的圖像。由于生成器和判別器之間博弈過程的不穩(wěn)定性,GAN往往難以學(xué)習(xí)到真實(shí)數(shù)據(jù)的完整分布,即發(fā)生了模式坍塌[23~25]的問題。
針對這種問題,本文將信息熵與生成對抗網(wǎng)絡(luò)的優(yōu)勢相結(jié)合,并融入了高斯混合模型(Gaussian mixed model,GMM)[26]擬合分布的思想,提出了多模式生成對抗網(wǎng)絡(luò)(multi-mode GAN,MM-GAN)。GMM理論上可以利用多個(gè)高斯分布來擬合任意的分布。信息熵是度量分類純度的指標(biāo),將其應(yīng)用于生成器的損失函數(shù)中,可以調(diào)整生成器的參數(shù),使生成器生成多樣性的樣本。
1 相關(guān)工作
模式坍塌問題已經(jīng)嚴(yán)重影響了生成模型的性能,使得模型的輸出缺少多樣性。究其原因,在于網(wǎng)絡(luò)的學(xué)習(xí)能力受限,在實(shí)際訓(xùn)練中不能有效地?cái)M合真實(shí)數(shù)據(jù)的分布。目前比較合理的解決方案是調(diào)整模型的網(wǎng)絡(luò)結(jié)構(gòu),強(qiáng)化網(wǎng)絡(luò)擬合分布的能力。 Liu等人[27]提出的CoGAN由兩個(gè)共享部分權(quán)重參數(shù)[28]的生成對抗網(wǎng)絡(luò)組成。權(quán)重共享能夠使得CoGAN去學(xué)習(xí)一個(gè)聯(lián)合分布,這種策略增強(qiáng)了生成器的分布擬合能力。Ghosh等人[29]提出的MAD-GAN整合了多個(gè)生成器和一個(gè)判別器。為了正確地識(shí)別生成偽樣本的生成器,判別器必須推動(dòng)不同的生成器學(xué)習(xí)不同的可識(shí)別模式,這會(huì)使得多個(gè)生成器盡可能擬合真實(shí)數(shù)據(jù)分布的不同部分。張龍等人[5]提出的協(xié)作式生成對抗網(wǎng)絡(luò)同樣由多個(gè)生成器和一個(gè)判別器組成。多個(gè)生成器共享一個(gè)輸入數(shù)據(jù)。協(xié)作式訓(xùn)練可以拉近不同生成器之間的分布,從而提高訓(xùn)練效率。與協(xié)作式生成對抗網(wǎng)絡(luò)的想法類似,本文認(rèn)為真實(shí)數(shù)據(jù)的不同模式之間既存在差異性也存在相似性。但是,協(xié)作式對抗網(wǎng)絡(luò)有一些不可忽視的問題。協(xié)作式生成對抗網(wǎng)絡(luò)的兩個(gè)生成器之間相互協(xié)作、相互拉近,最終形成的兩個(gè)生成分布依然是局部性的。這種模型雖然在一定程度上增強(qiáng)了生成器的擬合能力,但是生成器擬合能力的提升完全是依賴于生成器數(shù)量的堆積,協(xié)作式訓(xùn)練并不能有效地改善模式坍塌問題。并且生成器的協(xié)作并不總是正向的,反向協(xié)作可能會(huì)使得生成器花費(fèi)更多的時(shí)間來達(dá)到收斂狀態(tài)。此外,共享同一個(gè)輸入的兩個(gè)生成器在達(dá)到收斂狀態(tài)時(shí),接受一個(gè)新的輸入向量,兩個(gè)生成器會(huì)生成不同的樣本。這種擬合能力的增強(qiáng)不能表達(dá)太多的實(shí)際意義。本文吸取了協(xié)作式生成對抗網(wǎng)絡(luò)的經(jīng)驗(yàn),提出了多模式生成對抗網(wǎng)絡(luò),該模型能夠在保證生成樣本質(zhì)量的同時(shí),改善模式坍塌問題,增加生成樣本的多樣性,并且比較符合人類的認(rèn)知。
2 多模式生成對抗網(wǎng)絡(luò)
2.1 生成對抗網(wǎng)絡(luò)(GAN)
GAN是一種無監(jiān)督的深度學(xué)習(xí)模型,其模型結(jié)構(gòu)如圖1所示。該模型主要由生成器G和判別器D組成,生成器G的作用是產(chǎn)生一個(gè)分布來擬合真實(shí)數(shù)據(jù)的分布,判別器D主要是為了度量真實(shí)數(shù)據(jù)與生成數(shù)據(jù)的差異。生成器G和判別器D經(jīng)過對抗訓(xùn)練,最終達(dá)到一個(gè)納什均衡狀態(tài),即生成器G生成的數(shù)據(jù)非常接近真實(shí)數(shù)據(jù),而判別器D無法區(qū)分輸入的數(shù)據(jù)是來自生成器G生成的數(shù)據(jù)還是真實(shí)數(shù)據(jù),此時(shí)就可以認(rèn)為生成器G學(xué)到了真實(shí)數(shù)據(jù)的分布。此時(shí)得到最優(yōu)判別器為
D*(x)=Pdata(x)Pdata(x)+PG(x)(1)
其中:Pdata(x)為真實(shí)數(shù)據(jù)的分布;PG(x)為生成數(shù)據(jù)的分布。
當(dāng)判別器固定時(shí),生成器的目標(biāo)函數(shù)與真實(shí)數(shù)據(jù)x無關(guān),如式(2)所示。
LG=Ez~P(z)[log(D(G(z)))](2)
其中:z為從簡單分布P(z)中采樣的隨機(jī)噪聲;G(z)是將隨機(jī)噪聲z輸入到生成器中所生成的數(shù)據(jù);D(G(z))為判別器D判斷生成數(shù)據(jù)G(z)為真的概率。生成器通過欺騙判別器來提高自身的生成能力,所以在訓(xùn)練生成器時(shí),要不斷地調(diào)整參數(shù),使得判別器能夠同等看待生成器生成的數(shù)據(jù)和真實(shí)數(shù)據(jù)。為了清晰地揭示目標(biāo)函數(shù)的更多細(xì)節(jié),將式(1)帶入式(2)中,可得到
LG=-DKL(PG(x)‖Pdata(x))+2DJS(Pdata(x)‖PG(x))-
2log2-Ex~Pdata(x)[log D*(x)](3)
其中:DKL()和DJS()分別為KL散度和JS散度,兩者都是度量分布之間差異的指標(biāo)。由于式(3)的后兩項(xiàng)與生成器無關(guān),所以可以將生成器的優(yōu)化問題表示為
maxG(LG)=minG(DKL(PG(x)‖Pdata(x))-2DJS(Pdata(x)‖PG(x)))(4)
式(4)中暴露了兩個(gè)嚴(yán)重的問題:a)在最小化生成數(shù)據(jù)的分布與真實(shí)數(shù)據(jù)的分布的KL散度的同時(shí),需要最大化兩者之間的JS散度,這在直觀上是相互矛盾的,在數(shù)值上則會(huì)出現(xiàn)梯度不穩(wěn)定的情況;b)JS散度是一個(gè)對稱的度量指標(biāo),而KL散度是一個(gè)非對稱的度量指標(biāo),因此,LG主要受逆向KL散度DKL(PG(x)‖Pdata(x))的影響。
DKL(PG(x)‖Pdata(x))=∫PG(x)logPG(x)Pdata(x) dx(5)
其中:當(dāng)PG(x)gt;0,而Pdata(x)→0,DKL(PG(x)‖Pdata(x))→∞,即生成器G生成一些不屬于真實(shí)數(shù)據(jù)分布的樣本時(shí),KL散度的值會(huì)變得非常大;當(dāng)Pdata(x)gt;0,而PG(x)→0,DKL(PG(x)‖Pdata(x))→0,即生成器G生成一些接近部分真實(shí)數(shù)據(jù)的樣本時(shí),KL散度的值會(huì)變得很小,而無須考慮是否覆蓋了完整的真實(shí)數(shù)據(jù)的分布。
當(dāng)最小化生成器的損失時(shí),KL散度對兩種錯(cuò)誤的懲罰力度是不同的。當(dāng)生成器生成一些不屬于真實(shí)數(shù)據(jù)分布的樣本時(shí),KL散度的值會(huì)變得很大,不符合最小化損失函數(shù)的目標(biāo)。所以,模型在訓(xùn)練時(shí),不會(huì)傾向于去生成多樣性的樣本來盡可能擬合真實(shí)數(shù)據(jù)的完整的分布,而是會(huì)傾向于產(chǎn)生一些相對單一,但是接近部分真實(shí)分布的新樣本,這也就是發(fā)生了模式坍塌,如圖2所示。模式坍塌會(huì)導(dǎo)致生成樣本的多樣性不足、樣本質(zhì)量差等問題。在圖2(a)中,生成器生成的所有數(shù)字都與0和6有著密切的關(guān)系。從這個(gè)角度講,模式坍塌的形成并不是因?yàn)槟P偷膶W(xué)習(xí)能力不足,更確切的是,模型的表達(dá)能力不能被發(fā)揮。
2.2 多模式生成對抗網(wǎng)絡(luò)(MM-GAN)
MM-GAN是由多個(gè)生成器、一個(gè)分類器以及一個(gè)判別器組成。為了更加充分地體現(xiàn)MM-GAN的優(yōu)越性,并且與CoGAN和MAD-GAN等保持一致性,本文使用兩個(gè)生成器來實(shí)現(xiàn)MM-GAN,其模型結(jié)構(gòu)如圖3所示。生成器所形成的分布不僅僅與生成器的參數(shù)有關(guān),還與隨機(jī)向量z的分布有關(guān)。兩個(gè)生成器的輸入來自不同的先驗(yàn)分布Pz1和Pz2。生成分布中的樣本可以用G(z)表示,這意味著不同先驗(yàn)分布的輸入可能會(huì)使得生成器擬合真實(shí)數(shù)據(jù)分布的不同的部分。依據(jù)Hansen提出的GMM,多個(gè)高斯函數(shù)可以擬合任意的分布,所以,本文將兩個(gè)輸入數(shù)據(jù)的分布確定為具有不同的均值和方差的高斯分布。兩個(gè)生成器generator1和generator2之間共享部分權(quán)重參數(shù)是為了保證兩個(gè)輸入向量的每一維的意義在模型看來是相同的,這樣可以使得相同的輸入向量在兩個(gè)生成器的生成結(jié)果不會(huì)有太大差異。
另外,共享權(quán)重參數(shù)也能減少訓(xùn)練時(shí)間。判別器discriminator的作用是判斷輸入數(shù)據(jù)是來自真實(shí)分布Pdata,還是來自生成分布,其可以使得生成器生成更加真實(shí)的樣本。分類器classifier是為了對輸入數(shù)據(jù)進(jìn)行分類,分類結(jié)果的熵值可以指導(dǎo)生成器的訓(xùn)練。理想情況下,生成器所產(chǎn)生的一批樣本通過分類器之后,分類結(jié)果會(huì)是一個(gè)比較大的熵值。大的熵值表示分類結(jié)果是均勻的,即生成器生成的樣本的種類是多樣性的。
由于訓(xùn)練集的確定性和兩個(gè)高斯函數(shù)取值的重疊性,會(huì)導(dǎo)致兩個(gè)生成器所形成的生成分布有重合的部分。這是合理的,因?yàn)橥粋€(gè)數(shù)據(jù)集的不同模式之間存在著顯著的差異,同時(shí)也存在著聯(lián)系[5]。判別器discriminator的目標(biāo)函數(shù)為
LD=Ex~Pdata(x)ln D(x)+Ez~Pz1(z)ln(1-D(G1(z)))+Ez~Pz2(z)ln(1-D(G2(z)))(6)
其中:D(G1(z))與D(G2(z))分別為生成器G1與G2的生成樣本的判別結(jié)果。在優(yōu)化目標(biāo)函數(shù)的過程中,判別器會(huì)賦予真實(shí)數(shù)據(jù)比較高的置信度,會(huì)賦予生成數(shù)據(jù)比較低的置信度。判別器的主要作用是度量真實(shí)數(shù)據(jù)和生成數(shù)據(jù)的分布差異。而分類器classifier的目標(biāo)函數(shù)為
min E(x,y)~PdataDKL(y,C(x))(7)
其中:DKL()表示KL散度;(x,y)表示真實(shí)數(shù)據(jù)及對應(yīng)的標(biāo)簽。值得注意的是,分類器的訓(xùn)練僅僅依賴于真實(shí)數(shù)據(jù)及其標(biāo)簽。分類器是為了盡可能準(zhǔn)確地識(shí)別出樣本的標(biāo)簽。當(dāng)分類器達(dá)到一定的準(zhǔn)確度時(shí),將一個(gè)batch的數(shù)據(jù)輸入到分類器中,就會(huì)得到一個(gè)batch數(shù)據(jù)的分類結(jié)果。分類結(jié)果的熵值可以表示輸入數(shù)據(jù)的多樣性。熵值越大,則表示輸入數(shù)據(jù)的類別越平均;熵值越小,則表示輸入數(shù)據(jù)的類別比較單一。根據(jù)熵值的大小來調(diào)整生成輸入數(shù)據(jù)的生成器的參數(shù)是有價(jià)值的。另外,分類器classifier也可以是一個(gè)預(yù)訓(xùn)練的模型,其效果并沒有什么損耗。生成器的目標(biāo)函數(shù)定義為
max Ez~Pz1ln D(G1(z))+Ez~Pz2ln D(G2(z))+H(C(G1(Pz1)))+H(C(G2(Pz2)))(8)
其中:H()為香農(nóng)信息熵。生成器G1、G2生成的樣本作為分類器C的輸入,其分類結(jié)果的熵值越大,就意味著兩個(gè)生成器能夠擬合真實(shí)數(shù)據(jù)的更多模式。由于兩個(gè)生成器的輸入數(shù)據(jù)的分布以及網(wǎng)絡(luò)層參數(shù)的不同,兩個(gè)生成器會(huì)傾向于擬合真實(shí)數(shù)據(jù)分布的不同部分。當(dāng)計(jì)算分類模型的熵值時(shí),本文模型以一個(gè)batch為單位,計(jì)算的是每個(gè)batch的熵值大小。生成器的目標(biāo)是生成更加接近真實(shí)的并且具有多樣性的樣本。
MM-GAN的訓(xùn)練過程為
a)將從高斯分布Pz1和Pz2中采樣的噪聲z輸入到generator1和generator2中,得到偽樣本G1(z)和G2(z)。
b)將G1(z)、G2(z)以及真實(shí)樣本x輸入到判別器discriminator中,并利用式(6)中定義的損失函數(shù)來調(diào)整判別器的參數(shù)。
c)使用真實(shí)樣本x以及對應(yīng)的真實(shí)標(biāo)簽y來訓(xùn)練分類器,損失函數(shù)被定義在式(7)中。
d)保持判別器和分類器的參數(shù)不變。使用式(8)來同時(shí)優(yōu)化生成器generator1和generator2的參數(shù),并且保證兩個(gè)生成器之間共享部分權(quán)重參數(shù)。
e)以上四個(gè)步驟重復(fù)執(zhí)行,直至達(dá)到設(shè)定的epoch。
f)模型訓(xùn)練完成。
MM-GAN能夠在保證樣本質(zhì)量的同時(shí),提高樣本的多樣性。如圖4所示,generator1和generator2所生成的分布能夠分別擬合真實(shí)分布的不同部分,將兩者相互結(jié)合,就有可能擬合真實(shí)數(shù)據(jù)的完整分布。
3 實(shí)驗(yàn)結(jié)果
為驗(yàn)證本文模型的有效性,分別在MNIST(灰度圖)、CIFAR-10(RGB圖)、CIFAR-100(RGB圖)等數(shù)據(jù)集上進(jìn)行了實(shí)驗(yàn)。所有的實(shí)驗(yàn)均使用TensorFlow2.1實(shí)現(xiàn),運(yùn)行在兩塊具有11 GB顯存的RTX2080Ti上。
3.1 MNIST數(shù)字分類數(shù)據(jù)集
MNIST數(shù)據(jù)集[30]包含了從0~9共10個(gè)類別的70 000張灰度圖片。由于生成對抗網(wǎng)絡(luò)的不同變體在擬合真實(shí)數(shù)據(jù)分布時(shí),沒有本質(zhì)的差異,都會(huì)傾向于擬合真實(shí)數(shù)據(jù)分布的一部分,所以,本文將兩個(gè)生成器均設(shè)置為四層的全連接網(wǎng)絡(luò)并共享前兩層的神經(jīng)網(wǎng)絡(luò)參數(shù)。兩個(gè)生成器的輸入均為100維的向量。但是兩個(gè)輸入分別來自正態(tài)分布N(0,1)和N(6,4)。在訓(xùn)練模型之前,將灰度圖片的像素值標(biāo)準(zhǔn)化到[-1,1]。在訓(xùn)練階段,數(shù)據(jù)的加載方式是批處理,每個(gè)batch的大小為512,迭代次數(shù)為2 000,學(xué)習(xí)率為1E-3。為了清晰地對比標(biāo)準(zhǔn)生成對抗網(wǎng)絡(luò)與MM-GAN在擬合真實(shí)數(shù)據(jù)分布時(shí)的差異,本文選擇了兩個(gè)模型在相同迭代次數(shù)時(shí)的生成樣本,如圖5所示。從圖中可以發(fā)現(xiàn),在第80次迭代時(shí),標(biāo)準(zhǔn)生成對抗網(wǎng)絡(luò)生成的樣本具有多樣性。但隨著訓(xùn)練的繼續(xù)進(jìn)行,生成樣本的多樣性逐漸減少,樣本變得越來越單一。這種現(xiàn)象論證了對式(5)的分析,標(biāo)準(zhǔn)生成對抗網(wǎng)絡(luò)更加傾向于擬合真實(shí)分布的一部分。另外,這種擬合具有隨機(jī)性,重復(fù)多次實(shí)驗(yàn),結(jié)果會(huì)有所不同。如圖6所示,生成器更加傾向于擬合數(shù)字1,而非圖5所示的數(shù)字9。對于兩個(gè)正態(tài)分布中的均值和方差的取值問題,經(jīng)過多次實(shí)驗(yàn)發(fā)現(xiàn),兩個(gè)分布之間較大的耦合程度會(huì)使得兩個(gè)生成器的生成分布的相似性較大,生成結(jié)果沒有明顯改進(jìn),而兩個(gè)正態(tài)分布之間較小的耦合程度則會(huì)使得兩個(gè)生成器的生成分布之間具有明顯的邊界感,很難擬合真實(shí)數(shù)據(jù)不同模式之間的聯(lián)系,因此對于不同的數(shù)據(jù)集,正態(tài)分布中的均值和方差需要依據(jù)“6σ原則”反復(fù)調(diào)整。
3.2 CIFAR-10數(shù)據(jù)集和CIFAR-100數(shù)據(jù)集
CIFAR-10數(shù)據(jù)集[31]包含了airplane、automobile、bird等10個(gè)類別的60 000張彩色圖片,每張圖片的尺寸為32×32。為了充分驗(yàn)證本文模型的有效性,在多種數(shù)據(jù)集上進(jìn)行驗(yàn)證。彩色圖片是RGB三通道的圖片,與灰度圖片有所不同。因此,模型的生成器與判別器使用卷積層來取代全連接層。本文將生成器設(shè)置為6層的深度卷積神經(jīng)網(wǎng)絡(luò)[32],每層卷積核的數(shù)量分別為512、256、256、128、128、3,卷積核的大小為4×4,步長為2。兩個(gè)生成器的輸入均為100維的向量。通過多次實(shí)驗(yàn)發(fā)現(xiàn),對于CIFAR-10數(shù)據(jù)集,生成器的兩個(gè)輸入分別來自正態(tài)分布N(-1,1)和N(8,9)時(shí),生成分布對真實(shí)分布的擬合效果最佳。在訓(xùn)練階段,每個(gè)batch的大小為64,迭代次數(shù)為2 000,學(xué)習(xí)率為1E-3。訓(xùn)練結(jié)果如圖7所示,本文比較了深度卷積生成對抗網(wǎng)絡(luò)與MM-GAN在RGB圖像上的分布擬合狀況。在第900次迭代時(shí),深度卷積生成對抗網(wǎng)絡(luò)生成了較為清晰的圖像,并且生成的圖像也具備了一定的多樣性。隨著訓(xùn)練的進(jìn)行,在第1 000次迭代時(shí),樣本的多樣性開始減少,在第1 300次迭代時(shí),生成的樣本主要包括bird和ship,但依然可以發(fā)現(xiàn)一些其他類別的圖像,到第1 800次迭代時(shí),生成的樣本基本都退化為bird和ship。而MM-GAN生成的樣本自始至終都保持了樣本的多樣性,隨著訓(xùn)練的進(jìn)行,生成的圖像變得更加清晰。
CIFAR-100數(shù)據(jù)集[31]包含了mammals beaver、aquarium fish、orchids等100個(gè)類別的60 000張彩色圖片。CIFAR-100與CIFAR-10都是RGB三通道的圖像,但是它的類別更多,分布更加復(fù)雜。為了增強(qiáng)生成器的擬合能力,本文將生成器設(shè)置為八層的深度卷積神經(jīng)網(wǎng)絡(luò),每層卷積核的數(shù)量分別為512、256、256、256、128、128、128、3,其他超參數(shù)的設(shè)置原則與CIFAR-10時(shí)保持一致。在訓(xùn)練階段,每個(gè)batch的大小為512,迭代次數(shù)為2 000,學(xué)習(xí)率為1E-3。由于CIFAR-100的分布更加復(fù)雜,所以在相同的超參數(shù)下,模型更容易出現(xiàn)模式坍塌問題,如圖8所示。在第600次迭代時(shí),DCGAN生成了多樣性的樣本。隨著訓(xùn)練的進(jìn)行,在第900次迭代時(shí),樣本的多樣性明顯減少,到第1 100次迭代時(shí),生成的樣本基本都退化為sea和cloud。而MM-GAN自始至終都保持著生成樣本的多樣性。
對于不同的數(shù)據(jù)集,在訓(xùn)練階段,本文設(shè)置了不同尺寸的batch。MNIST數(shù)據(jù)集的batch尺寸為512,CIFAR-10和CIFAR-100的batch尺寸分別為64和512。由于信息熵約束作用于batch上,所以batch尺寸的選擇對于生成器的擬合能力有一定的影響。當(dāng)batch的尺寸小于數(shù)據(jù)集的類別時(shí),熵約束會(huì)使得batch中的數(shù)據(jù)局限于有限幾個(gè)類別上,無法考慮所有類別的差異,這會(huì)對生成器的擬合能力造成影響,即生成器難以擬合更多的模式。本文將batch的大小確定為類別總數(shù)的五倍左右,以確保每個(gè)類別都有一定數(shù)量的生成樣本,從而提高生成樣本的多樣性。MNIST和CIFAR-10都有10個(gè)類別,所以batch的尺寸至少為50。MNIST數(shù)據(jù)集是單通道的手寫數(shù)字?jǐn)?shù)據(jù)集,相比于三通道的圖像數(shù)據(jù),具有更少的特征,所以可以增加MNIST的batch的尺寸。最終,本文確定MNIST和CIFAR-10的batch尺寸分別為512和64。CIFAR-100的batch尺寸也遵循相同的原則。本文將迭代次數(shù)限制為2 000,學(xué)習(xí)率為1E-3,是因?yàn)樵? 000次迭代之內(nèi),模式坍塌現(xiàn)象就已經(jīng)比較明顯,如圖8所示。學(xué)習(xí)率為1E-3在多次實(shí)驗(yàn)中保持一致,更大的學(xué)習(xí)率會(huì)導(dǎo)致模式坍塌現(xiàn)象更早發(fā)生。
另外,本文對比了深度卷積生成對抗網(wǎng)絡(luò)與MM-GAN的生成器損失的變化趨勢,如圖9所示。深度卷積生成對抗網(wǎng)絡(luò)的生成器損失在前250次迭代過程中變化幅度非常大,并迅速變得平緩,這說明生成器在生成了一些噪聲之后,能夠迅速生成滿足判別器要求的樣本,與此相對應(yīng)的是生成的樣本逐漸變得單一。在同一迭代過程中,生成器總能生成個(gè)別置信度高的樣本,而深度卷積生成對抗網(wǎng)絡(luò)會(huì)將生成的分布向個(gè)別置信度高的樣本的分布偏移,最終導(dǎo)致模式坍塌問題。MM-GAN的生成器損失一直保持著一個(gè)比較大的振蕩幅度,即MM-GAN不會(huì)輕易向部分分布偏移,而會(huì)使生成分布擬合真實(shí)數(shù)據(jù)的完整分布。
3.3 不同模型發(fā)生模式坍塌的概率
為了進(jìn)一步驗(yàn)證本文模型能夠在一定程度上緩解模式坍塌問題,分別進(jìn)行了120次實(shí)驗(yàn),并統(tǒng)計(jì)了多種GAN變體在不同數(shù)據(jù)集上模式坍塌問題發(fā)生的概率,如表1所示。從表1中能夠發(fā)現(xiàn),DCGAN基本不具有擬合分布的能力,在訓(xùn)練過程中經(jīng)常會(huì)發(fā)生模式坍塌問題,這是由于DCGAN只能覆蓋很少的模式;ProGAN和SRGAN等在模式坍塌問題上也表現(xiàn)較差。ProGAN和SRGAN的主要目的在于提升圖像的質(zhì)量,而不是改善圖像的多樣性。從數(shù)學(xué)角度講,ProGAN和SRGAN在分布擬合能力上與DCGAN等沒有很大差異。從表中可以發(fā)現(xiàn),DCGAN、ProGAN和SRGAN在模式坍塌發(fā)生的概率上較為接近,也佐證了這一觀點(diǎn);而CoGAN和MAD-GAN相較于DCGAN等有了部分提升,但是這種效果的提升主要依賴于生成器的堆積。多個(gè)生成器在擬合分布時(shí),不同的初始化方式使得多個(gè)生成器傾向于擬合真實(shí)數(shù)據(jù)分布的不同部分,從而提升了模型的擬合能力。本文模型MM-GAN雖然也增加了生成器的數(shù)量,但是MM-GAN的分布擬合能力更強(qiáng),因?yàn)槠湓黾恿诵畔㈧丶s束以及改變了模型結(jié)構(gòu),使得生成器的運(yùn)用更加合理。
另外,數(shù)據(jù)集的大小對模式坍塌也有一定的影響。ImageNet[33]為具有1 000個(gè)類別的RGB圖像,而MNIST和CIFAR-100分別具有10個(gè)類別和100個(gè)類別。從表1中可以發(fā)現(xiàn),ImageNet具有更高的概率發(fā)生模式坍塌問題,而MNIST的發(fā)生概率相對較低。一種可能的解釋是,生成器在擬合真實(shí)分布時(shí),更多的類別會(huì)導(dǎo)致模型能夠更快地生成置信度高的樣本,這種生成帶有一定的隨機(jī)性,但是這依然會(huì)導(dǎo)致模型發(fā)生偏移。
4 結(jié)束語
為了改善生成模型中的模式坍塌問題,提高生成樣本的多樣性,本文借鑒了協(xié)作式生成對抗網(wǎng)絡(luò)的一些思想,提出了多模式生成對抗網(wǎng)絡(luò)。通過定義多個(gè)生成器以及單個(gè)分類器,并且引入共享權(quán)值參數(shù)的方法以及熵值約束損失等,使得模型能夠完整地?cái)M合真實(shí)數(shù)據(jù)的分布。實(shí)驗(yàn)證明,MM-GAN在多種數(shù)據(jù)集上有效地緩解了模式坍塌問題。未來將進(jìn)一步優(yōu)化網(wǎng)絡(luò)結(jié)構(gòu)和損失函數(shù),研究模式坍塌問題的解決方法,使生成模型能夠生成多樣性的樣本。
參考文獻(xiàn):
[1]Ng A Y,Jordan M I.On discriminative vs.generative classifiers:a comparison of logistic regression and naive Bayes[C]//Advances in Neural Information Processing Systems.2002:841-848.
[2]Goodfellow I,Pouget-Abadie J,Mirza M,et al.Generative adversarial nets[C]//Advances in Neural Information Processing Systems.2014.
[3]Brunner G,Wang Yuyi,Wattenhofer R,et al.Symbolic music genre transfer with CycleGAN[C]//Proc of the 30th IEEE International Conference on Tools with Artificial Intelligence.Piscataway,NJ:IEEE Press,2018:786-793.
[4]Ryu Y S,Oh S Y.Simple hybrid classifier for face recognition with adaptively generated virtual data[J].Pattern Recognition Letters,2002,23(7):833-841.
[5]張龍,趙杰煜,葉緒倫,等.協(xié)作式生成對抗網(wǎng)絡(luò)[J].自動(dòng)化學(xué)報(bào),2018,44(5):804-810.(Zhang Long,Zhao Jieyu,Ye Xulun,et al.Co-operative generative adversarial nets[J].Acta Automatica Sinica,2018,44(5):804-810.)
[6]Kingma D P,Welling M.Auto-encoding variational Bayes[EB/OL].(2013).https://arxiv.org/abs/ 1312.6114.
[7]Van Oord A,Kalchbrenner N,Kavukcuoglu K.Pixel recurrent neural networks[C]//Proc of International Conference on Machine Lear-ning.2016:1747-1756.
[8]Karpathy A.The unreasonable effectiveness of recurrent neural networks,2015[EB/OL].(2016).http://karpathy.github.io/2015/05/21/rnn-effectiveness.
[9]Gulrajani I,Ahmed F,Arjovsky M,et al.Improved training of Wasserstein GANs[EB/OL].(2017).https://arxiv.org/abs/1704.00028.
[10]Karras T,Aila T,Laine S,et al.Progressive growing of GANs for improved quality,stability,and variation[EB/OL].(2017).https://arxiv.org/abs/1710.10196.
[11]Karras T,Laine S,Aila T.A style-based generator architecture for generative adversarial networks[C]//Proc of IEEE/CVF Conference on Computer Vision and Pattern Recognition.Piscataway,NJ:IEEE Press,2019:4401-4410.
[12]Zhu Junyan,Park T,Isola P,et al.Unpaired image-to-image translation using cycle-consistent adversarial networks[C]//Proc of IEEE International Conference on Computer Vision.Piscataway,NJ:IEEE Press,2017:2223-2232.
[13]Isola P,Zhu Junyan,Zhou Tinghui,et al.Image-to-image translation with conditional adversarial networks[C]//Proc of IEEE Conference on Computer Vision and Pattern Recognition.Piscataway,NJ:IEEE Press,2017:1125-1134.
[14]Wang Tingchun,Liu Mingyu,Zhu Junyan,et al.High-resolution image synthesis and semantic manipulation with conditional GANs[C]//Proc of IEEE Conference on Computer Vision and Pattern Recognition.Piscataway,NJ:IEEE Press,2018:8798-8807.
[15]Choi Y,Choi M,Kim M,et al.StarGAN:unified generative adversarial networks for multi-domain image-to-image translation[C]//Proc of IEEE Conference on Computer Vision and Pattern Recognition.Piscataway,NJ:IEEE Press,2018:8789-8797.
[16]Bulat A,Yang Jing,Tzimiropoulos G.To learn image super-resolution,use a GAN to learn how to do image degradation first[C]//Proc of European Conference on Computer Vision.2018:185-200.
[17]Lugmayr A,Danelljan M,Timofte R.Unsupervised learning for real-world super-resolution[C]//Proc of IEEE/CVF International Confe-rence on Computer Vision Workshop.Piscataway,NJ:IEEE Press,2019:3408-3416.
[18]Yuan Yuan,Liu Siyuan,Zhang Jiawei,et al.Unsupervised image super-resolution using cycle-in-cycle generative adversarial networks[C]//Proc of IEEE Conference on Computer Vision and Pattern Recognition Workshops.Piscataway,NJ:IEEE Press,2018:701-710.
[19]Zhao Tianyu,Ren Wenqi,Zhang Changqing,et al.Unsupervised degradation learning for single image super-resolution[EB/OL].(2018-12-13).https://arxiv.org/abs/1812.04240.
[20]Demir U,Unal G.Patch-based image inpainting with generative adversarial networks[EB/OL].(2018).https://arxiv.org/abs/1803.07422.
[21]Gao Hongchang,Pei Jian,Huang Heng.ProGAN:network embedding via proximity generative adversarial network[C]//Proc of the 25th ACM SIGKDD International Conference on Knowledge Discovery amp; Data Mining.New York:ACM Press,2019:1308-1316.
[22]Ledig C,Theis L,Huszár F,et al.Photo-realistic single image super-resolution using a generative adversarial network[C]//Proc of IEEE Conference on Computer Vision and Pattern Recognition.Piscataway,NJ:IEEE Press,2017:4681-4690.
[23]Salimans T,Goodfellow I,Zaremba W,et al.Improved techniques for training GANs[C]//Advances in Neural Information Processing Systems.2016:2234-2242.
[24]Ghosh A,Kulharia V,Namboodiri V P,et al.Multi-agent diverse generative adversarial networks[C]//Proc of IEEE Conference on Computer Vision and Pattern Recognition.Piscataway,NJ:IEEE Press,2018:8513-8521.
[25]Arjovsky M,Bottou L.Towards principled methods for training generative adversarial networks[EB/OL].(2017).https://arxiv.org/abs/1701.04862.
[26]Hansen L P.Large sample properties of generalized method of moments estimators[J].Econometrica:Journal of the Econometric Society,1982,50(4):1029-1054.
[27]Liu Mingyu,Tuzel O.Coupled generative adversarial networks[C]//Advances in Neural Information Processing Systems.2016:469-477.
[28]Liu Mingyu,Breuel T,Kautz J.Unsupervised image-to-image translation networks[C]//Advances in Neural Information Processing Systems.2017:700-708.
[29]Ghosh A,Kulharia V,Namboodiri V P,et al.Multi-agent diverse generative adversarial networks[C]//Proc of IEEE Conference on Computer Vision and Pattern Recognition.Piscataway,NJ:IEEE Press,2018:8513-8521.
[30]LeCun Y,Bottou L,Bengio Y,et al.Gradient-based learning applied to document recognition[J].Proceedings of the IEEE,1998,86(11):2278-2324.
[31]Krizhevsky A,Hinton G.Learning multiple layers of features from tiny images[R].2009.
[32]Radford A,Metz L,Chintala S.Unsupervised representation learning with deep convolutional generative adversarial networks[EB/OL].(2015).https://arxiv.org/abs/1511.06434.
[33]Russakovsky O,Deng Jia,Su Hao,et al.Imagenet large scale visual recognition challenge[J].International Journal of Computer Vision,2015,115(3):211-252.