龙空技术网

基于Wasserstein距离的生成对抗网络研究#AI#Mila

AITIME论道 224

前言:

如今小伙伴们对“alm算法”都比较重视,各位老铁们都需要剖析一些“alm算法”的相关资讯。那么小编同时在网上搜集了一些关于“alm算法””的相关内容,希望小伙伴们能喜欢,我们快快来了解一下吧!

从理论上来说,建立在Wasserstein距离的Kantorovich-Rubinstein (KR)二元性基础上的Wasserstein GANs (WGANs)模型是最完善的GAN模型之一。但在实践过程中,我们发现,该模型并不总是优于其他GANs模型的变体。这主要是由于KR二元性所要求的Lipschitz条件的实现不完善。针对这一问题,目前不少研究针对Lipschitz约束进行了不同的实现,但这在实践中仍然难以完美的满足该限制条件。相比之下,一篇来自AAAI2021的题为《Towards Generalized Implementation of Wasserstein Distance in GANs》的研究提出了一种新颖的想法:强Lipschitz约束对于优化可能是不必要的,有没有可能通过放松Lipschitz约束对现有方法进行改进。在理论上,该研究首先证明了Wasserstein距离的一般二元形式,称为Sobolev二元性,它放松了Lipschitz约束,但仍然保持了Wasserstein距离的有利梯度属性。此外,该研究还表明,KR二元性实际上是Sobolev二元性的一种特殊情况。基于放宽的二元性,文章提出了一种广义的WGAN训练方案,命名为Sobolev Wasserstein GAN (SWGAN),并通过大量实验证明了SWGAN优于现存方法。

本期AI TIME PhD直播间,我们有幸邀请到了该论文的作者,来自蒙特利尔大学算法研究所的博士生徐民凯,为大家分享这项研究工作!

徐民凯:蒙特利尔大学算法研究所(Mila)在读博士生,导师为Jian Tang教授。研究兴趣主要在于深层生成模型及其在离散数据上的应用,重点关注图形(药物研发)和文本(自然语言生成)。本科毕业于上海交通大学,师从Weinan Zhang教授和Yong Yu教授。曾在ByteDance(TikTok)AI实验室实习,并与Lei Li博士和Mingxuan Wang博士合作。其研究成果在ICML等多个顶级会议上发表,并担任ICML和AAAI等会议的程序委员会委员。

一、背景

1.1 生成对抗网络(GAN)的核心

生成对抗网络的本质是训练一个生成模型。图1是生成对抗网络的原理概念图。从图中我们可以看到两个神经网络,分别是Generator(生成器)和Discriminator(判别器)。绿色背景的生成器以随机噪声作为输入,输出一个图片。红色背景的判别器既可以输入真实世界的图片,也可以输入生成器生成的图片。判别器的作用是通过分类尽量分辨出输入的图片是来自真实世界的还是由生成器生成的。

GAN的对抗点在于,判别器需要尽量通过分类来分辨出哪些图片是真实的,哪些图片是由生成器生成的,而生成器需要尽量骗过判别器,以达到以假乱真的效果。

图1. GANs原理概念图

1.2 什么是生成模型?

图2展示了生成模型的工作原理。在GANs中,生成器首先从先验分布P(Z)中进行采样,(高斯分布或均匀分布等)。然后输入到生成器的神经网络中,进而输出到更高维的空间中。生成器的经历可以通过函数表示,该函数将先验分布映射到和真实样本空间一样大的空间里,得到更复杂的分布,即PG(x)。生成器要求PG(x)和真实数据的分布Pdata(x)越接近越好。

生成模型核心的问题是:怎样衡量生成器定义的分布和真实世界的分布之间的距离,并使该距离最小化?最小化距离的过程也是训练生成器的过程。与此同时,判别器可以帮助生成器进行训练,其目标就是进行二分类。

图2. 生成模型的工作原理

1.3 为什么生成器可以将分布收敛到真实的分布?

图3中蓝色曲线代表生成器生成的分布,绿色曲线代表真实数据的分布,D(x)代表判别器对输入数据的评价。

真实情况:在生成器的分布处,D(x)的值低;在判别器的分布处,D(x)的值高(如图3a所示)。

训练生成器:生成器会顺着D(x)的梯度向着更高评价处移动(如图3b所示)。

训练判别器:判别器会降低生成器所获得的评价(如图3c所示)。

训练生成器:此时生成器会继续向左移动,即向着D(x)值高的方向移动。

通过一系列交替迭代训练,希望会达到如图3d所示的理想的收敛状态(如图3d所示),即生成器所生成的分布和真实分布重叠,使得判别器无法再进行区分。

图3(a, b, c and d). 生成对抗网络的训练过程

1.4 存在问题

在绝大多数的情况下,生成数据和真实数据在空间中并没有重叠。

原因1:数据自身的性质决定的。真实数据的维度和图片本身的维度并不完全一致,真实数据的维度更偏向于低维流形。如图4(a)所示。真实数据与生成数据仅有两个点是重叠的,并不存在任何分布上的交集。

原因2:即使真实数据与生成数据是有重叠的,但训练过程是通过采样得到的。所以,只要我们采样的数据是有限的,则一定存在可以区分真实数据和生成数据的方法,即图4(b)中的黑线。

图4(a and b). 存在问题分析

二、理论方法介绍

2.1 Wasserstein GAN

为了使训练后得到的D(x)呈现出线性增长的形式,文章采用的解决方法是将jensen shannon divergence换成wasserstein distance。具体的做法是,在判别器中, 的得分更高,而的得分更低,与此同时,分类器能满足1- Lipschitz约束。

(1) Lipschitz 约束

输出函数的变化要小于等于常数K与输入数据变化的乘积。当K=1时,该约束被称为“1- Lipschitz”。

(2) Wasserstein distance (W-distance或推土机距离)

通过最简单的方式将下图中的P分布变为Q分布,文章中使用移动体积乘以移动距离的方法,使得最终付出的搬运是最少的。这样的方式称为推土机距离,即W-distance。

上述移动计划可以转化为图5 joint space中的矩阵,矩阵中的每个值代表P分布有多少应该移动到Q分布。其中,最优的移动计划就是最小化移动计划中的平均距离。

图5. 移动计划矩阵图

2.3 Improved WGAN

有关Improved WGAN的研究证明了在最优的距离差异下,两个点在Lipschitz 约束条件下的输出变化与输入变化相同。也就是说,输入一个生成器的生成数据,判别器可以判别它可以移动到哪个真实数据,判别器给生成器中生成数据的梯度是指向真实数据的。生成的分布只要顺着梯度,就可以收敛到真实数据的分布。在实际优化过程中,每个点的梯度越靠近1越好。具体过程如下图6所示。

图6. Improved WGAN原理图

训练过程为什么会收敛呢?本文是基于梯度的角度进行解释的,但WGAN的原始论文中是从两个分布并不是完全重叠的角度进行解释的。本文的研究给出了解释这个问题的全新角度。

三、Sobolev Wasserstein GAN (SWGAN)

能否找到更泛化的约束满足收敛的性质?作者通过一系列理论证明,提出一种新的约束 (Sobolev约束),该约束希望任何一个生成的数据和真实数据的连线的积分小于等于1。显然,Sobolev约束比Lipschitz 约束更加泛化。当采用Lipschitz 约束时,每一个点的梯度都小于等于1,即它们在连线上的积分都小于等于1;反之,当采用Sobolev约束时,两点之间在连线上的积分都小于等于1,但并不一定满足Lipschitz 约束。具体公式如图7所示。由此发现,SWGAN是GAN更泛化的形式。与此同时,作者的研究也证明了在Sobolev约束下,梯度具有较好的指向性(指向真实数据的梯度),指导每一个生成的数据收敛到真实数据。

图7. Sobolev Wasserstein GAN原理图

四、如何优化SWGAN算法?

SWGAN的计算方法:

约束条件:

那么约束就是限制Ωij≥0。由此,作者通过采用增广拉格朗日不等式正则化的方法,在训练判别器时,将正则化项(如下式)加入到损失中,进而对判别器进行惩罚计算,使其满足要求。

五、试验结果分析

5.1 一维数据模型

图7展示了synthetic data在一维训练集上的对比结果。上半部分是采用WGAN方法得到的结果,下半部分是采用SWGAN方法得到的结果。从图中可以看出,原始WGAN算法的优化过程非常缓慢,约束更加泛化的SWGAN方法的优化过程更快且稳定

图7. 一维数据训练对比结果

5.2 Level sets of the critic

为了证明SWGAN critic应该能够对更具挑战性的真实和虚假数据进行建模,并提供有意义的梯度。图8展示了在几个二维数据集分布上训练SWGAN critic以使其达到最优的平面图,从图可以看出,经过SWGAN模型分析后,分布具有良好的适配性。

图8. 二维数据集的可视化

5.3 对比训练结果

表1统计了CIFAR-10和Tiny-ImageNet上的FID和IS得分。图9绘制了CIFAR-10上FID的训练曲线,以显示不同GAN的训练过程。从表1和图9可以看出,SWGANs的效果普遍好于baseline模型。WGAN和SGAN在使用ALM或采样更多的插值点时,性能往往略好。但是,与SWGAN-AL和SWGAN-GP相比,它们的性能仍然没有足够的竞争力。这说明较大的采样量和ALM优化算法并不是提升SWGAN性能的关键因素,而是Sobolev二元性中的宽松约束导致了性能的提高,即更宽松的约束将简化约束优化问题,并导致更强的GAN模型。

图9. CIFAR-10的训练曲线

表1. GAN在CIFAR-10和Tiny-ImageNet上的性能对比

六、后续工作

(1) 作者发现了比Lipschitz 约束更泛化条件的约束,是否还存在其它更泛化的约束条件,在后续的工作中可以继续深入探究。

(2) 在真实世界的数据中,点与点之间不一定是一对一的对应关系,可能是多对一或一对多的关系。所以如何解决一对多、多对一或多对多的transport plan问题,也可以在后续工作中继续深入探究。

相关资料

论文原文和链接:

Minkai Xu, et al. Towards Generalized Implementation of Wasserstein Distance in GANs. AAAI Conference on Artificial Intelligence (AAAI), 2021.

Reference:

Marc G Bellemare, et al. The cramer distance as a solution to biased wasserstein gradients. arXiv preprint arXiv:1705.10743, 2017.

David Berthelot, et al. Began: Boundary equilibrium generative adversarial networks. arXiv preprint arXiv:1703.10717, 2017.

Tong Che, et al. Mode regularized generative adversarial networks. arXiv preprint arXiv:1612.02136, 2016.

Jia Deng, et al. Imagenet: A large-scale hierarchical image database. In CVPR. Ieee, 2009.

Farzan Farnia and David Tse. A convex duality framework for gans. In NIPS, 2018.

William Fedus, et al. Many paths to equilibrium: Gans do not need to decrease a divergence at every step. arXiv preprint arXiv:1710.08446, 2017.

Ian Goodfellow, et al. Generative adversarial nets. In NIPS, 2014.

Ian Goodfellow. Nips 2016 tutorial: Generative adversarial networks. arXiv preprint arXiv:1701.00160, 2016.

Ishaan Gulrajani, et al. Improved training of wasserstein gans. In NIPS, 2017.

Kaiming He, et al. Deep residual learning for image recognition. In CVPR, 2016.

Martin Heusel, et al. Gans trained by a two time-scale update rule converge to a local nash equilibrium. In NIPS, 2017.

Martin Heusel, et al. Gans trained by a two time-scale update rule converge to a nash equilibrium. arXiv preprint arXiv:1706.08500, 2017.

Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.

Naveen Kodali, et al. How to train your dragan. arXiv preprint arXiv:1705.07215, 2(4), 2017.

Naveen Kodali, et al. On convergence and stability of gans. arXiv preprint arXiv:1705.07215, 2017.

本文所引用的图片均来自讲者徐民凯的PPT.

标签: #alm算法