3.7 示例:将GAN应用于Atari图像
几乎每本有关DL的书都使用MNIST数据集来展示DL功能,多年来,该数据集都变得无聊了,就像遗传研究人员的果蝇一样。为了打破这一传统,并添加更多乐趣,我尝试避免沿用以前的方法,而使用其他方法说明PyTorch。本章前面简要提到了GAN,它们是由伊恩·古德费洛(Ian Goodfellow)发明和推广的。本示例中将训练GAN生成各种Atari游戏的屏幕截图。
最简单的GAN架构有两个网络,第一个网络充当“欺骗者”(也称为生成器),另一个网络充当“侦探”(另一个名称是判别器)。两个网络相互竞争,生成器试图生成伪造的数据,这些数据使判别器也难以将它与原数据集区分开,判别器试图检测生成的数据样本。随着时间的流逝,两个网络都提高了技能,生成器生成越来越多的真实数据样本,而判别器发明了更复杂的方法来区分伪造的数据。
GAN的实际应用包括改善图像质量、逼真图像生成和特征学习。在本示例中,实用性几乎为零,但这将是一个很好的示例,可以说明对于相当复杂的模型而言,PyTorch代码可以很简洁。
整个示例代码在文件Chapter03/03_atari_gan.py
中。这里将给出一些重要的代码,不包括import
部分和常量声明:
此类是Gym游戏的包装器,其中包括以下几种转换:
- 将输入图像的尺寸从210×160(标准Atari分辨率)调整为正方形尺寸64×64。
- 将图像的颜色平面从最后一个位置移到第一个位置,以满足PyTorch卷积层的约定,该卷积层输入包含形状为通道、高度和宽度的张量。
- 将图像从
bytes
转换为float
。
然后,定义两个nn.Module
类:Discriminator
和Generator
。第一种将经过缩放的彩色图像作为输入,并通过应用五层卷积,再使用Sigmoid
进行非线性变换将数据转换为数字。Sigmoid
的输出被解释为:判别器认为输入图像来自真实数据集的概率。
Generator
将随机数向量(隐向量)作为输入,并使用“转置卷积”操作(也称为deconvolution
)将该向量转换为原始分辨率的彩色图像。这里不会介绍这些类,因为它们很冗长且与示例无关,你可以在完整的示例文件中找到它们。
我们让几个随机智能体同时玩Atari游戏,并将游戏截图作为输入。图3.6是输入数据的示例,它是由以下函数生成的:
图3.6 三种Atari游戏的屏幕截图示例
从提供的数组中对环境进行无限采样,发出随机动作,并在batch
列表中记录观察结果。当批满足所需大小时,将图像归一化,将其转换为张量,然后从生成器中yield
出来。由于其中一个游戏存在问题,因此需要检查观察值均值非零,以防止图像闪烁。
现在,我们看一下主函数,它包括准备模型并运行训练循环。
在此,我们处理命令行参数(只有一个可选参数--cuda
,启用GPU计算模式),创建环境池并用包装器包装。该环境数组将传递给iterate_batches
函数以生成训练数据。
上面的代码创建了几个类:一个Summary Writer
、两个网络、一个损失函数和两个优化器。为什么是两个?因为这就是GAN训练的方式:要训练判别器,需要用适当的标签(1代表真实的,0代表伪造的)来向它展示真实和伪造的数据样本。在此过程中,仅更新判别器的参数。
此后,再次将真实和伪造样本都通过判别器,但是这次,所有样本的标签均为1,并且仅更新生成器的权重。第二遍告诉生成器如何欺骗判别器,并将真实样本与生成的样本混淆起来。
这段代码定义了数组(用于累积损失)、迭代器计数器以及带有真假标签的变量。
在训练循环开始前,生成一个随机向量并将其传递给Generator
网络。
首先,通过两批数据来训练判别器,即分别应用于真实数据样本和生成的样本。我们需要在生成器的输出上调用detach()
函数,以防止此次训练的梯度流入生成器(detach()
是tensor
的方法,该方法可以复制张量而不与原始张量的操作关联)。
以上代码用于生成器的训练。将生成器的输出传递给判别器,但是现在不停止梯度。相反,我们将目标函数与True标签一起应用。它将使生成器向生成可欺骗判别器的样本的方向发展。
那是与训练相关的代码,接下来的两行代码会上报损失,并将图像样本输入给TensorBoard:
这个例子的训练是一个漫长的过程。在GTX 1080 GPU上,100次迭代大约需要40秒。最初,生成的图像完全是随机噪声,但是在经过1万~2万次迭代后,生成器变得越来越熟练,并且生成次图像越来越类似于真实游戏的屏幕截图。
经过4万~5万次训练迭代后(在GPU上几个小时),实验给出了以下图像(见图3.7)。
图3.7 生成器网络产生的样例图片