曹爽



摘? 要: 為了提升合成表格數(shù)據(jù)的質量,提出一種簡單的方法生成每個類的數(shù)據(jù),使用度量損失控制每一類結構化數(shù)據(jù)的生成,將此方法命名為SCGAN。文章用此方法在二分類問題上進行了嘗試。使用三種不同的度量損失在三個真實的數(shù)據(jù)集上訓練生成對抗網(wǎng)絡:逐次對每一類數(shù)據(jù)進行合成,利用合成數(shù)據(jù)訓練分類器模型,使用gmean來評估模型的性能。結果表明,單獨生成每一類數(shù)據(jù)能夠提升模型的分類性能。
關鍵詞: 合成數(shù)據(jù); 度量損失; 生成對抗網(wǎng)絡; 分類器
中圖分類號:TP391? ? ? ? ? 文獻標識碼:A? ? ?文章編號:1006-8228(2021)04-25-03
Abstract: In order to improve the quality of tabular data synthesis, a simple method to generate data of each category is proposed, and it is named SCGAN and uses metrics loss to control the generation of structured data of each category. In this paper, the binary classification problem is tried to be solved by this method. By using three different metrics losses, the generative adversarial network is trained on three real datasets that each category of data are synthesized one by one, the classifier model are trained with the synthesized data, and gmean is used to evaluate the performance of the model. The results show that generating each category of data separately can improve the classification performance of the model.
Key words: synthesized data; metrics loss; generative adversarial networks; classifier
0 引言
近年來,生成對抗網(wǎng)絡在生成高質量合成圖像方面取得了很大的成功[1]。多種數(shù)據(jù)類型,數(shù)據(jù)分布不確定,多模態(tài)分布,數(shù)據(jù)不均衡等特點對生成表格型數(shù)據(jù)帶來了挑戰(zhàn)[2]。MedGAN提出醫(yī)學生成對抗網(wǎng)絡,來生成逼真的合成病歷[3]。TableGAN使用生成對抗網(wǎng)絡來合成假表,這些假表在統(tǒng)計上類似于原始表[4]。CTGAN對連續(xù)數(shù)據(jù)進行建模,對離散數(shù)據(jù)增加條件損失來合成高質量數(shù)據(jù)[2]。
本文在CTGAN的基礎上提出一種無監(jiān)督的生成對抗網(wǎng)絡方法,將衡量指標FID[5],MMD[6],最小二乘作為度量模塊應用到生成對抗網(wǎng)絡模型中,利用單個類別的數(shù)據(jù)訓練模型生成大量的合成數(shù)據(jù),利用梯度懲罰[7]和譜歸一化方法[8]來增強模型訓練的穩(wěn)定性。在三個真實的數(shù)據(jù)集上選取相同數(shù)量的生成數(shù)據(jù)對三種度量方法做了比較,實驗結果顯示,本文提出的方法能夠提升生成數(shù)據(jù)的質量,提升模型分類的性能。
1 SCGAN
1.1 生成對抗網(wǎng)絡
生成對抗網(wǎng)絡是一種生成模型[1],包含生成器(G)和判別器(D)兩部分。生成器目的是生成逼真的合成數(shù)據(jù)以最大程度的騙過判別器來達到損失的最小化,判別器爭取將真實數(shù)據(jù)和合成數(shù)據(jù)分別開來[9]。以下為生成對抗網(wǎng)絡的一般形式:
其中[z]是隨機輸入的噪聲,一般為高斯分布中的隨機采樣點,[pz]是潛在向量[z]的先驗分布,[G?]是生成器函數(shù),[D?]是判別器函數(shù)。
1.2 度量損失
為了保證生成數(shù)據(jù)的質量,將三種度量損失:FID,MMD,最小二乘等加入到生成對抗網(wǎng)絡模型中,由于最小二乘比較簡單,在此我們著重介紹前兩種方法。
⑴ Frechet Inception Distance (FID)
FID[5]常用于評估生成器最終生成的圖像質量,計算真實數(shù)據(jù)和合成數(shù)據(jù)在特征層面的距離,距離越小,說明合成數(shù)據(jù)與真實數(shù)據(jù)越相似,以下是FID的計算公式:
其中[Pr],[Pg]分別表示真實數(shù)據(jù)和生成數(shù)據(jù),[C]表示數(shù)據(jù)的協(xié)方差矩陣,[u]表示數(shù)據(jù)的均值,我們將這種評估方式應用到生成表格數(shù)據(jù)的生成對抗模型中,參與生成器模型的訓練,鼓勵生成器學習真實數(shù)據(jù)的分布。
⑵ Maximum Mean Discrepancy (MMD)
MMD[6]是一種基于最大均方差的統(tǒng)計檢驗來優(yōu)化兩類樣本的分布,常用于評估生成圖像的質量。此處,我們使用MMD衡量生成的結構化數(shù)據(jù),定義如下:
給定兩類結構化數(shù)據(jù)集,[V=v1,v2,…vm]和[W=w1,w2,…wm],以下為MMD計算公式:
其中[k?]是高斯核函數(shù)。
1.3 SCGAN整體流程
整體流程如圖1所示。我們使用生成對抗網(wǎng)絡對劃分好的訓練集進行訓練,生成指定類別的合成數(shù)據(jù),TrainData0表示第一類數(shù)據(jù)對應生成數(shù)據(jù)Fake0,TrainData1表示第二類數(shù)據(jù)對應生成數(shù)據(jù)Fake1,在G,D網(wǎng)絡中我們遵循了CTGAN的網(wǎng)絡結構,但是由于我們是生成指定類別的數(shù)據(jù),所以在生成器和判別器中去除了條件輸入,在G中加入了3種度量損失函數(shù)。當生成指定類別的數(shù)據(jù)后,對生成的數(shù)據(jù)每個類分別選取500個和1000個樣本,最終組成1000和2000大小的訓練集,訓練分類器(SVM,RF,DT)模型,使用gmean[10]評估分類器的性能。
2 實驗
2.1 數(shù)據(jù)集介紹
本文研究的數(shù)據(jù)集來自于①Covtype,用來預測森林覆蓋類型的多分類數(shù)據(jù)集,我們選擇了Ponderosa Pine,Krummholz這兩類數(shù)據(jù)來測試我們的模型。②Adult是一個從人口普查數(shù)據(jù)庫中提取的個人信息記錄的數(shù)據(jù)集,我們將收入是否超過50k,作為分類的二進制標簽。③BitcoinHeist是一個有關比特幣交易圖的數(shù)據(jù)集,簡記為Bit,從中選取了princetonCerber和montrealCryptoLocker類別的數(shù)據(jù),對數(shù)據(jù)進行二分類。
2.2 方法比較
在我們的SCGAN中,我們對比了使用不同度量下生成樣本的質量,而且也與不加度量損失的生成對抗網(wǎng)絡和原始的CTGAN進行了對比。SCGAN-FID表示在生成器上使用FID作為度量損失,SCGAN-MMD表示在生成器上使用MMD作為度量損失,SCGAN-LS表示在生成器上使用最小二乘作為度量損失,GAN表示沒有加度量損失。值得注意的一點,在三種度量方法和沒有使用度量方法的GAN中,除了損失函數(shù)的差異,其他迭代次數(shù)和網(wǎng)絡都是一致的。
2.3 實驗結果
在實驗中,我們記錄了每一種方法以及每一種數(shù)據(jù)集在每一種基分類器實驗結果,為了顯現(xiàn)整體的有效性,表1至表3是每一種方法在三個基分類器上的平均結果。從表1和表2中可以看到,在三個真實的數(shù)據(jù)集上,本文提出的SCGAN整體優(yōu)于CTGAN,另外,在表3中,我們記錄了不使用度量損失下的GAN模型的性能,根據(jù)在gmean指標上的評估可以看到,進一步說明了度量損失的有效性。
3 總結
本文提出的SCGAN,分別進行每一類別的數(shù)據(jù)合成,通過實驗表明能夠提升模型的分類性能。我們只在二分類問題上進行了嘗試,將此方法應用到多類不均衡數(shù)據(jù)集中是我們接下來的研究重點。
參考文獻(References):
[1] Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems,2014:2672-2680
[2] Xu L, Skoularidou M, Cuesta-Infante A, et al. Modeling tabular data using conditional gan[C]//Advances in Neural Information Processing Systems,2019:7335-7345
[3] Choi E, Biswal S, Malin B, et al. Generating multi-label discrete patient records using generative adversarial networks[J]. arXiv preprint arXiv:1703.06490,2017.
[4] Park N, Mohammadi M, Gorde K, et al. Data synthesis based on generative adversarial networks[J].arXiv preprint arXiv:1806.03384,2018.
[5] Heusel M, Ramsauer H, Unterthiner T, et al. Gans trained by a two time-scale update rule converge to a local nash equilibrium[J]. Advances in neural information processing systems,2017.30: 6626-6637
[6] Sutherland D J, Tung H Y, Strathmann H, et al.Generative models and model criticism via optimized maximum mean discrepancy[J]. arXiv preprint arXiv:1611.04488,2016.
[7] Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems,2017.30: 5767-5777
[8] Miyato T, Kataoka T, Koyama M, et al. Spectral normalization for generative adversarial networks[J].arXiv preprint arXiv:1802.05957,2018
[9] 張重生著.人工智能 人臉識別與搜索[M].電子工業(yè)出版社,2020.
[10] Leevy J L, Khoshgoftaar T M, Bauder R A, et al. A survey on addressing high-class imbalance in big data[J]. Journal of Big Data,2018.5(1):42