深度强化学习实践(原书第2版)
上QQ阅读APP看书,第一时间看更新

3.7 示例:将GAN应用于Atari图像

几乎每本有关DL的书都使用MNIST数据集来展示DL功能,多年来,该数据集都变得无聊了,就像遗传研究人员的果蝇一样。为了打破这一传统,并添加更多乐趣,我尝试避免沿用以前的方法,而使用其他方法说明PyTorch。本章前面简要提到了GAN,它们是由伊恩·古德费洛(Ian Goodfellow)发明和推广的。本示例中将训练GAN生成各种Atari游戏的屏幕截图。

最简单的GAN架构有两个网络,第一个网络充当“欺骗者”(也称为生成器),另一个网络充当“侦探”(另一个名称是判别器)。两个网络相互竞争,生成器试图生成伪造的数据,这些数据使判别器也难以将它与原数据集区分开,判别器试图检测生成的数据样本。随着时间的流逝,两个网络都提高了技能,生成器生成越来越多的真实数据样本,而判别器发明了更复杂的方法来区分伪造的数据。

GAN的实际应用包括改善图像质量、逼真图像生成和特征学习。在本示例中,实用性几乎为零,但这将是一个很好的示例,可以说明对于相当复杂的模型而言,PyTorch代码可以很简洁。

整个示例代码在文件Chapter03/03_atari_gan.py中。这里将给出一些重要的代码,不包括import部分和常量声明:

070-02

此类是Gym游戏的包装器,其中包括以下几种转换:

  • 将输入图像的尺寸从210×160(标准Atari分辨率)调整为正方形尺寸64×64。
  • 将图像的颜色平面从最后一个位置移到第一个位置,以满足PyTorch卷积层的约定,该卷积层输入包含形状为通道、高度和宽度的张量。
  • 将图像从bytes转换为float

然后,定义两个nn.Module类:DiscriminatorGenerator。第一种将经过缩放的彩色图像作为输入,并通过应用五层卷积,再使用Sigmoid进行非线性变换将数据转换为数字。Sigmoid的输出被解释为:判别器认为输入图像来自真实数据集的概率。

Generator将随机数向量(隐向量)作为输入,并使用“转置卷积”操作(也称为deconvolution)将该向量转换为原始分辨率的彩色图像。这里不会介绍这些类,因为它们很冗长且与示例无关,你可以在完整的示例文件中找到它们。

我们让几个随机智能体同时玩Atari游戏,并将游戏截图作为输入。图3.6是输入数据的示例,它是由以下函数生成的:

071-01
072-01

图3.6 三种Atari游戏的屏幕截图示例

从提供的数组中对环境进行无限采样,发出随机动作,并在batch列表中记录观察结果。当批满足所需大小时,将图像归一化,将其转换为张量,然后从生成器中yield出来。由于其中一个游戏存在问题,因此需要检查观察值均值非零,以防止图像闪烁。

现在,我们看一下主函数,它包括准备模型并运行训练循环。

072-02

在此,我们处理命令行参数(只有一个可选参数--cuda,启用GPU计算模式),创建环境池并用包装器包装。该环境数组将传递给iterate_batches函数以生成训练数据。

072-03

上面的代码创建了几个类:一个Summary Writer、两个网络、一个损失函数和两个优化器。为什么是两个?因为这就是GAN训练的方式:要训练判别器,需要用适当的标签(1代表真实的,0代表伪造的)来向它展示真实和伪造的数据样本。在此过程中,仅更新判别器的参数。

此后,再次将真实和伪造样本都通过判别器,但是这次,所有样本的标签均为1,并且仅更新生成器的权重。第二遍告诉生成器如何欺骗判别器,并将真实样本与生成的样本混淆起来。

073-01

这段代码定义了数组(用于累积损失)、迭代器计数器以及带有真假标签的变量。

073-02

在训练循环开始前,生成一个随机向量并将其传递给Generator网络。

073-03

首先,通过两批数据来训练判别器,即分别应用于真实数据样本和生成的样本。我们需要在生成器的输出上调用detach()函数,以防止此次训练的梯度流入生成器(detach()tensor的方法,该方法可以复制张量而不与原始张量的操作关联)。

073-04

以上代码用于生成器的训练。将生成器的输出传递给判别器,但是现在不停止梯度。相反,我们将目标函数与True标签一起应用。它将使生成器向生成可欺骗判别器的样本的方向发展。

那是与训练相关的代码,接下来的两行代码会上报损失,并将图像样本输入给TensorBoard:

074-01

这个例子的训练是一个漫长的过程。在GTX 1080 GPU上,100次迭代大约需要40秒。最初,生成的图像完全是随机噪声,但是在经过1万~2万次迭代后,生成器变得越来越熟练,并且生成次图像越来越类似于真实游戏的屏幕截图。

经过4万~5万次训练迭代后(在GPU上几个小时),实验给出了以下图像(见图3.7)。

074-02

图3.7 生成器网络产生的样例图片