侍海峰,何良華,盧劍
(1.同濟大學計算機系,上海201804;2.北京大學第三醫院,北京100083)
生成式對抗網絡(Generative Adversarial Nets,GAN)[1]是一類基于對抗訓練和深度神經網絡的無監督生成式模型,由一個生成器網絡和一個判別器網絡構成,可以生成服從訓練集分布的無限多的樣本。自2014年被提出后,迅速成為深度學習、人工智能領域的研究熱點之一,在圖像生成、圖像風格變換、圖像超分辨率、視頻生成等領域應用廣泛。
然而GAN存在訓練不穩定,容易對抗崩潰的問題。此外,生成分布還存在“模式丟失(mode dropping)”問題[2],只能生成訓練集分布的一個子集,多樣性不足。WGAN-GP[3]改進了原始GAN的目標函數,解決了訓練不穩定的問題。混合生成式對抗網絡(Mixture GAN,MGAN)[4]是一種集成模型,通過混合多個生成器的分布來改善模式丟失問題,增加生成樣本的多樣性。然而,MGAN的多個生成器的混合權重被設置為均等值,不適合類別不平衡且比例未知的數據集。
基于WGAN-GP和MGAN,本文提出模式分工型混合生成式對抗網絡(Mode-Splitting MGAN,MSMGAN),向MGAN的訓練算法中加入了生成器混合權重學習環節,提高了MGAN在類別不平衡數據集上的生成效果,能促使多個生成器分別學習訓練集中不同的模式,即“模式分工”。此外,替換MGAN的原始GAN目標函數為WGAN-GP的目標函數,使訓練更穩定。在由UTKFace[5]和Toronto Face Dataset[6]混合而成的多模態不平衡人臉圖像數據集上的實驗表明,MSMGAN生成分布具有更低的Frechet Inception Distance(FID)[7],且支持按類別生成圖片。
原始生成式對抗網絡[1]結構如圖1所示,由生成器和判別器兩個神經網絡構成。生成器將輸入的噪聲向量z映射為生成樣本xF(又稱假樣本);判別器接收生成的假樣本和來自訓練集的真實樣本x,輸出樣本為真實樣本的概率。
噪聲向量z的各分量通常是相互獨立的高斯噪聲。生成器的訓練目標是優化自身參數,盡可能使判別器誤把假樣本判別成真樣本;判別器的訓練目標則是優化自身參數,盡可能準確地區分真實樣本和假樣本。生成器和判別器訓練的目標函數(損失函數)分別為:


圖1 生成式對抗網絡結構圖
文獻[1]證明了在理想的條件下,對抗訓練將達到納什均衡,生成器生成的樣本xF的分布Pg將和訓練集真實分布Pr相同,判別器將無法區分其輸入樣本的來源,輸出恒定為0.5。此時的生成器即可用于生成能以假亂真的樣本。
原始GAN存在訓練不穩定,容易崩潰的問題。針對該問題,文獻[3]提出WGAN-GP模型,模型的判別器損失函數為:


WGAN-GP訓練穩定性和生成分布的多樣性均優于原始GAN,本文提出的MS-MGAN模型亦采用該損失函數和訓練方式。
為了解決GAN存在的“模式丟失”問題,增加生成樣本的多樣性,一些基于集成模型思路的集成類GAN模型被提出,Mixture GAN(MGAN)便是其中的典型。MGAN由K個生成器網絡、一個判別器網絡D和一個分類器C構成。分類器預測生成樣本來源于哪一個生成器,判別器預測生成樣本來源于真實分布還是生成分布。MGAN的生成器、判別器和分類器進行如下的最小-最大博弈:

可見生成器的目標有兩部分,既含有原始GAN目標函數中對抗判別器的項,又包含迎合分類器分類的項。后者含有超參數β,用于平衡目標函數中二者的比例。MGAN中各生成器G1,G2,…,Gk的混合權重被π1,π2,…,πK設定為1/K,即K個生成器的分布均勻混合。MGAN的目標函數能直接迫使多個生成器生成不同模式的樣本,以便于分類器區分,適合學習由若干個良好分離的分布等概率混合而成的分布。
基于WGAN-GP和MGAN模型,本文提出模式分工混合生成對抗網絡(MS-MGAN),以更好地學習和生成類別不均衡的數據分布。原始GAN和WGAN-GP模型都只具有一個生成器,讓單個生成器網絡學習復雜的多模態圖像數據分布是比較困難的,易導致生成的圖像質量欠佳。MGAN采用多個生成器分工學習復雜的數據分布,其實驗表明[4],算法提高了生成樣本的質量和多樣性,但是其超參數β對數據集比較敏感,需要精心調節,增加了算法的調參難度。此外,MGAN的多個生成器的混合權重π1,π2,…,πK被設定為均勻分布,但現實中的數據集往往各類別(模式)的占比不均勻,導致MGAN的各生成器會出現不合理的分工,影響生成質量。
本文提出的MS-MGAN舍棄了MGAN中的分類器,從而從模型中去除了敏感的超參數 β。采用WGAN-GP的訓練目標代替了MGAN中使用的原始GAN目標函數,提高了模型的訓練穩定性。此外,增加了多個生成器混合權重的學習環節,能根據訓練分布中不同模式樣本數量的占比分配各生成器對應得權重,使得各生成器合理分工學習訓練集中不同得模式,即使沒有額外分類器的促使作用。MS-MGAN的判別器損失函數同WGAN-GP的判別器損失函數,即式(3)。而生成器損失函數則被修改為判別器對多個生成器樣本評價的加權值:

MS-MGAN的訓練算法在WGAN-GP的基礎上增加了對π1,π2,…,πk的梯度下降法更新過程。在一次GAN訓練迭代中,除了原有的①固定生成器網絡,訓練判別器;②固定判別器網絡,訓練生成器;這兩個步驟以外,加入混合權重學習環節③固定判別器網絡和各生成器網絡,并從每一個生成器各采樣一個mini-batch的生成樣本,計算判別器對各生成器分布的期望評價并作為常數帶入(6)式,再將生成器的損失函數對π1,π2,…,πK,求梯度,更新混合權重。每一個迭代中進行上述3個步驟的計算,可以使多個生成器按訓練集中不同模式的比重合理分工。
為了測試MS-MGAN在類別不均衡數據集上的表現,將UTKFace和Toronto Face Dataset(TFD)數據集中的人臉圖像混合成一個訓練集。UTKFace數據集提供了23708張分辨率為200×200的剪裁并對齊了的彩色人臉圖像,TFD包含102236張分辨率96×96的灰度人臉圖像。由于UTKFace中的圖像數量較少,將每一張人臉圖像都水平翻轉,以將數據集圖像數量倍增至47416。TFD數據集中的圖像的灰度通道則被復制為3通道彩色圖像。所有人臉圖像均縮放至64×64分辨率。因此可知不平衡混合人臉數據集中UTKFace和TFD這兩種模式的比例為47416:102236=0.3168:0.6832。
本文MS-MGAN實驗程序基于WGAN-GP的官方開源代碼修改而成,生成器網絡和判別器網絡均選用廣泛使用的類DCGAN[8]的網絡結構。實驗使用NVIDIA GTX 1080Ti GPU和TensorFlow 1.12進行訓練,操作系統為Ubuntu 18.04。
實驗評價除了采用直接觀察生成圖像的定性方法外,還采用廣泛使用的定量指標Frechet Inception Distance(FID)。FID使用Inception[9]模型的中間編碼層的特征向量,對訓練集真實圖像和生成器合成的圖像的Inception編碼層特征分別回歸成多元高斯分布,然后計算這兩個多元高斯分布之間的Frechet距離,計算公式如下:

其中mr,mF分別是真實圖像和生成圖像輸入Inception模型得到的編碼層向量的均值,Cr和CF分別是協方差。FID值越低,生成分布就更接近真實分布。
實驗考察了單生成器的WGAN-GP模型、具有2個生成器的MGAN模型和具有2個生成器的MSMGAN模型。為了公平比較,單生成器的WGAN-GP模型的生成器規模等比例放大到MGAN和MS-MGAN生成器的2倍。三個模型的實驗的批大小均為64,訓練迭代200000次(每次迭代中生成器被訓練一次,判別器被訓練5次)。每個WGAN-GP模型和MGAN模型訓練一次花費GPU約26小時,MS-MGAN由于增加了混合權重學習環節,訓練一次花費27小時左右。為了避免神經網絡類方法固有的隨機誤差,每個模型均隨機初始化訓練了4次。訓練過程中,每10000次迭代,采樣50000個生成樣本,計算FID值。每個模型在一次訓練過程中達到的最小FID值的平均值、標準差如表1所示。表中生成器參數量是指模型包含的所有生成器網絡參數的總和。所有模型都只有一個相同結構的判別器,判別器參數量均為4.317M。

表1 不同模型生成質量(FID)對比表
從表1可以看出,WGAN-GP、MGAN和MSMGAN的FID值依次降低,表明后二者生成分布質量均優于WGAN-GP,即使單生成器模型的生成器尺寸已經翻倍。支持生成器混合權重學習的MS-MGAN取得了最低的FID值,表明MS-MGAN更能適應現實中更為普遍的類別不均衡數據集。
圖2為MS-MGAN模型在200000次GAN訓練迭代中,兩個生成器的混合權重的變化趨勢。橫軸為迭代次數,為展現訓練早期的曲線變化,采用對數坐標。可見訓練初期兩個生成器的混合權重波動較為劇烈,因為此時模型剛初始化,見到的訓練集樣本不多,生成分布和判別器判別都不太準確。但是迭代次數超過10000次后,混合權重已經基本穩定在真實值附近小幅波動了。因此,本文提出的MS-MGAN可以快速學習出合理的生成器混合權重,其實可以在后95%的訓練迭代中的固定混合權重不再學習,以加快訓練速度。

圖2 MS-MGAN生成器混合權重變化圖
表2為MS-MGAN在4次隨機初始化實驗中學習到的兩個生成器的混合權重。由于兩個生成器符號具有輪換對稱性,我們約定訓練結束后混合權重較小的生成器為G1,對應權重為π1。混合權重較大的則為G2,對應權重為π2。可見4次實驗中MS-GAN均基本準確地向兩個生成器分配了符合訓練集兩種模式占比(0.3168:0.6832)的混合權重。因此,本文提出的混合權重學習算法是穩定而精確的。

表2 MS-MGAN混合權重學習結果表
從MS-MGAN和MGAN的兩個生成器隨機采樣的一部分樣本如圖3所示,前2行(紅線上方)分別是MS-MGAN的兩個生成器生成的樣本,后兩行(紅線下方)分別是MGAN的兩個生成器生成的樣本。通過比較可以發現,MS-MGAN的兩個生成器G1,G2分別學習生成了混合人臉數據集中UTKFace(彩色)和TFD(灰度)這兩種模式的樣本,而MGAN的生成器G2負責生成TFD的人臉,G1既負責生成一部分UTKFace的彩色人臉,又負責生成一部分TFD人臉。這是因為MGAN的兩個生成器的混合權重被固定為0.5,而訓練集中兩種模式的比例分別為0.3168:0.6832,導致必須有一個生成器負責兩種模式樣本的生成,才能使生成分布服從真實分布。MS-GAN由于具有混合權重學習環節,兩個生成器的混合權重能快速收斂至訓練集中兩種模式的混合比例,使生成器更合理的分工,從而可以生成比MGAN具有更少瑕疵、失真的人臉樣本,并可以通過選擇不同的生成器以實現按類別采樣,增加了采樣的可操控性。

圖3 MS-MGAN和MGAN隨機生成樣本圖
本文提出了模式分工型混合生成式對抗網絡(MS-MGAN),向MGAN訓練算法中加入混合比例學習環節,在類別分布不平衡數據集上的生成質量優于MGAN,且各生成器能分工學習訓練集中不同的模式,從而支持按類別采樣生成。