相关文章推荐
暴走的小熊猫  ·  postgresql pivot - CSDN文库·  3 周前    · 
想出国的勺子  ·  sqlserver split 多行 - ...·  3 月前    · 
幸福的帽子  ·  How To Get ...·  6 月前    · 
独立的冲锋衣  ·  [复选框] ...·  9 月前    · 
失恋的稀饭  ·  Power BI 和 Excel ...·  1 年前    · 
黄小天 黄小天 翻译

通过PyTorch实现对抗自编码器

「大多数人类和动物学习是无监督学习。如果智能是一块蛋糕,无监督学习是蛋糕的坯子,有监督学习是蛋糕上的糖衣,而强化学习则是蛋糕上的樱桃。我们知道如何做糖衣和樱桃,但我们不知道如何做蛋糕。」


Facebook 人工智能研究部门负责人 Yann LeCun 教授在讲话中多次提及这一类比。对于无监督学习,他引用了「机器对环境进行建模、预测可能的未来、并通过观察和行动来了解世界如何运作的能力」。


深度生成模型(deep generative model)是尝试解决机器学习中无监督学习问题的技术之一。在此框架下,需要一个机器学习系统来发现未标记数据中的隐藏结构。深度生成模型在许多应用中有许多广泛的应用,如密度估计、图像/音频去噪、压缩、场景理解(scene understanding)、表征学习(representation learning)和半监督分类(semi-supervised classification)。


变分自编码器(Variational Autoencoder/VAE)使得我们可以在概率图形模型(probabilistic graphical model)的框架下将这个问题形式化,在此框架下我们可以最大化数据的对数似然值的下界。在本文中,我们将介绍一种最新开发的架构,即对抗自编码器(Adversarial Autoencoder),它由 VAE 启发,但它在数据到潜在维度的映射方式中(如果现在还不清楚,不要担心,我们将在本文中重新提到这个想法)有更大的灵活性。关于对抗自编码器最有趣的想法之一是如何通过使用对抗学习(adversarial learning)将先验分布(prior distribution)运用到神经网络的输出中。

如果想将深入了解 Pytorch 代码,请访问 GitHub repo(https://github.com/fducau/AAE_pytorch)。在本系列中,我们将首先介绍降噪自编码器和变分自编码器的一些背景,然后转到对抗自编码器,之后是 Pytorch 实现和训练过程以及 MNIST 数据集使用过程中一些关于消纠缠(disentanglement)和半监督学习的实验。


背景


降噪自编码器(DAE)


我们可在自编码器(autoencoder)的最简版本之中训练一个网络以重建其输入。换句话说,我们希望网络以某种方式学习恒等函数(identity function)f(x)= x。为了简化这个问题,我们将此条件通过一个中间层(潜在空间)施加于网络,这个中间层的维度远低于输入的维度。有了这个瓶颈条件,网络必须压缩输入信息。因此,网络分为两部分:「编码器」用于接收输入并创建一个「潜在」或「隐藏」的表征(representation);「解码器」使用这个中间表征,并重建输入。自编码器的损失函数称为「重建损失函数(reconstruction loss)」,它可以简单地定义为输入和生成样本之间的平方误差:

640.png

当输入标准化为在 [0,1] N 范围内时,另一种广泛使用的重建损失函数是交叉熵(cross-entropy loss)。


变分自编码器(VAE)


变分自编码器对如何构造隐藏表征施加了第二个约束。现在,潜在代码的先验分布由设计好的某概率函数 p(x)定义。换句话说,编码器不能自由地使用整个潜在空间,而是必须限制产生的隐藏代码,使其可能服从先验分布 p(x)。例如,如果潜在代码上的先验分布是具有平均值 0 和标准差 1 的高斯分布,则生成值为 1000 的潜在代码应该是不可能的。

这可以被看作是可以存储在潜在代码中的信息量的第二类正则化。这样做的好处是现在我们可以作为一个生成模型使用该系统。为了创建一个服从数据分布 p(x)的新样本,我们只需要从 p(z)进行采样,并通过解码器来运行该样本以重建一个新图像。如果不施加这种条件,则潜在代码在潜在空间中的分布是随意的,因此不可能采样出有效的潜在代码来直接产生输出。

为了强制执行此属性,将第二项以先验分布与编码器建立分布之间的 KL 散度(Kullback-Liebler divergence)的形式添加到损失函数中。由于 VAE 基于概率解释,所使用的重建损失函数是前面提到的交叉熵损失函数。把它们放在一起我们有:

640-2.png


640-3.png

其中 q(z|x) 是我们网络的编码器,p(z) 是施加在潜在代码上的先验分布。现在这个架构可以使用反向传播(backpropagation)联合训练。


对抗自编码器(AAE)


作为生成模型的对抗自编码器


变分自编码器的主要缺点之一是,除了少数分布之外,KL 散度项的积分不具有封闭形式的分析解法。此外,对于潜在代码 z 使用离散分布并不直接。这是因为通过离散变量的反向传播通常是不可能的,使得模型难以有效地训练。这篇论文介绍了在 VAE 环境中执行此操作的一种方法(https://arxiv.org/abs/1609.02200)。

对抗自编码器通过使用对抗学习(adversarial learning)避免了使用 KL 散度。在该架构中,训练一个新网络来有区分地预测样本是来自自编码器的隐藏代码还是来自用户确定的先验分布 p(z)。编码器的损失函数现在由重建损失函数与判别器网络(discriminator network)的损失函数组成。

图中显示了当我们在潜在代码中使用高斯先验(尽管该方法是通用的并且可以使用任何分布)时 AAE 的工作原理。最上面一行相当于 VAE。首先,根据生成网络 q(z|x) 抽取样本 z,然后将该样本发送到根据 z 产生 x' 的解码器。在 x 和 x' 之间计算重建损失函数,并且相应地通过 p 和 q 反向推导梯度,并更新其权重。

640-6.jpeg

图 1. AAE 的基本架构最上面一行是自编码器,而最下面一行是对抗网络,迫使到编码器的输出服从分布 p(z)。


在对抗正则化部分,判别器收到来自分布为 q(z|x)的 z 和来自真实先验 p(z) 的 z' 采样,并为每个来自 p(z)的样本附加概率。发生的损失函数通过判别器反向传播,以更新其权重。然后重复该过程,同时生成器更新其参数。

我们现在可以使用对抗网络(它是自编码器的编码器)的生成器产生的损失函数而不是 KL 散度,以便学习如何根据分布 p(z)生成样本。这种修改使我们能够使用更广泛的分布作为潜在代码的先验。

判别器的损失函数是

640-4.png

其中 m 是微批尺寸(minibatch size),z 由编码器生成,z' 是来自真实先验的样本。

对于对抗生成器,我们有

640-5.png

通过查看方程式和曲线,你应该明白,以这种方式定义的损失函数将强制判别器能够识别假样本,同时推动生成器欺骗判别器。


定义网络


在进入这个模型的训练过程之前,我们来看一下如何在 Pytorch 中实现我们现在所做的工作。对于编码器、解码器和判别器网络,我们将使用 3 个带有 ReLU 非线性函数与概率为 0.2 的 dropout 的 1000 隐藏状态层的简单前馈神经网络(feed forward neural network)。

在进入这个模型的训练过程之前,我们来看一下如何在 Pytorch 中实现我们现在所做的工作。对于编码器、解码器和判别器网络,我们将使用 3 个带有 ReLU 非线性函数与概率为 0.2 的 dropout 的 1000 隐藏状态层的简单前馈神经网络(feed forward neural network)。

#Encoderclass Q_net(nn.Module):      def __init__(self):        super(Q_net, self).__init__()        self.lin1 = nn.Linear(X_dim, N)        self.lin2 = nn.Linear(N, N)        self.lin3gauss = nn.Linear(N, z_dim)    def forward(self, x):        x = F.droppout(self.lin1(x), p=0.25, training=self.training)        x = F.relu(x)        x = F.droppout(self.lin2(x), p=0.25, training=self.training)        x = F.relu(x)        xgauss = self.lin3gauss(x)        return xgauss

# Decoderclass P_net(nn.Module):      def __init__(self):        super(P_net, self).__init__()        self.lin1 = nn.Linear(z_dim, N)        self.lin2 = nn.Linear(N, N)        self.lin3 = nn.Linear(N, X_dim)    def forward(self, x):        x = self.lin1(x)        x = F.dropout(x, p=0.25, training=self.training)        x = F.relu(x)        x = self.lin2(x)        x = F.dropout(x, p=0.25, training=self.training)        x = self.lin3(x)        return F.sigmoid(x)

# Discriminatorclass D_net_gauss(nn.Module):      def __init__(self):        super(D_net_gauss, self).__init__()        self.lin1 = nn.Linear(z_dim, N)        self.lin2 = nn.Linear(N, N)        self.lin3 = nn.Linear(N, 1)    def forward(self, x):        x = F.dropout(self.lin1(x), p=0.2, training=self.training)        x = F.relu(x)        x = F.dropout(self.lin2(x), p=0.2, training=self.training)        x = F.relu(x)        return F.sigmoid(self.lin3(x))

从这个定义可以注意到一些事情。首先,由于编码器的输出必须服从高斯分布,我们在最后一层不使用任何非线性定义。解码器的输出具有 S 形非线性,这是因为我们使用以其值在 0 和 1 范围内的标准化输入。判别器网络的输出仅为 0 和 1 之间的一个数字,表示来自真正先验分布的输入概率。

一旦网络的类(class)定义完成,我们创建每个类的实例并定义要使用的优化器。为了在编码器(这也是对抗网络的生成器)的优化过程中具有独立性,我们为网络的这一部分定义了两个优化器,如下所示:

torch.manual_seed(10)   Q, P = Q_net() = Q_net(), P_net(0)     # Encoder/Decoder  D_gauss = D_net_gauss()                # Discriminator adversarial  if torch.cuda.is_available():      Q = Q.cuda()    P = P.cuda()    D_cat = D_gauss.cuda()    D_gauss = D_net_gauss().cuda()# Set learning ratesgen_lr, reg_lr = 0.0006, 0.0008  # Set optimizatorsP_decoder = optim.Adam(P.parameters(), lr=gen_lr)   Q_encoder = optim.Adam(Q.parameters(), lr=gen_lr)   Q_generator = optim.Adam(Q.parameters(), lr=reg_lr)   D_gauss_solver = optim.Adam(D_gauss.parameters(), lr=reg_lr)


训练步骤


每个微批处理的架构的训练步骤如下:

1)通过编码器/解码器部分进行前向路径(forward path)计算,计算重建损失并更新编码器 Q 和解码器 P 网络的参数。

z_sample = Q(X)    X_sample = P(z_sample)    recon_loss = F.binary_cross_entropy(X_sample + TINY,                                        X.resize(train_batch_size, X_dim) + TINY)    recon_loss.backward()    P_decoder.step()    Q_encoder.step()

2)创建潜在表征 z = Q(x),并从先验函数的 p(z) 取样本 z',通过判别器运行每个样本,并计算分配给每个 (D(z) 和 D(z')) 的分数。

Q.eval()        z_real_gauss = Variable(torch.randn(train_batch_size, z_dim) * 5)   # Sample from N(0,5)    if torch.cuda.is_available():        z_real_gauss = z_real_gauss.cuda()    z_fake_gauss = Q(X)

3)计算判别器的损失函数,并通过判别器网络反向传播更新其权重。在代码中,

# Compute discriminator outputs and loss