Gumbel-Softmax完全解析
写在前面
本文对大部分人来说可能仅仅起到科普的作用,因为Gumbel-Max仅在部分领域会用到,例如GAN、VAE等。笔者是在研究EMNLP上的一篇论文时,看到其中有用Gumbel-Softmax公式解决对一个概率分布进行采样无法求导的问题,故想到对Gumbel-Softmax做一个总结,由此写下本文
为什么我们需要Gumbel-Softmax ?
假设现在我们有一个离散随机变量 Z 的分布
$$ p_1 = p(Z=1)=\pi_1\\ p_2 = p(Z=2) = \pi_2\\ p_3 = p(Z=3) = \pi_3\\ ...\\ p_x = p(Z=x) = \pi_x\\ $$
其中, \sum_i \pi_i=1 。我们想根据 p_1,p_2,...,p_x 的概率采样得到一系列离散 z 的值。但是这么做有一个问题,我们采样出来的 z 只有值,没有生成 z 的式子。例如我们要求 Z 的期望,那么就有公式
Z 对 p_1,p_2,...,p_x 的导数都很清楚。但是现在我们的需求是采样一些具体的 z 值,采样这个操作没有任何公式,因此也就无法求导。于是一个很自然的想法就产生了,我们能不能给一个以 p_1,p_2,...,p_z 为参数的公式,让这个公式返回的结果是 z 采样的结果呢?
Gumbel-Softmax
一般来说 \pi_i 是通过神经网络预测对于类别 i 的概率,这在分类问题中非常常见,假设我们将一个样本送入模型,最后输出的概率分布为 [0.2, 0.4,0.1,0.2,0.1] ,表明这是一个5分类问题,其中概率最大的是第2类,到这一步,我们直接通过argmax就能获得结果了,但现在我们不是预测问题,而是一个采样问题。对于模型来说,直接取出概率最大的就可以了,但对我们来说,每个类别都是有一定概率的,我们想根据这个概率来进行采样,而不是直接简单无脑的输出概率最大的值
最常见的采样 \mathbf{z} 的onehot公式为
其中 i=1,2,..,x 是类别的下标,随机变量 u 服从均匀分布 U(0,1)
上面这个过程实际上是很巧妙的,我们将概率分布从前往后不断加起来,当加到 \pi_i 时超过了某个随机值 0\leq u \leq 1 ,那么这一次随机采样过程, z 就被随机采样为第 i 类,最后通过一个onehot变换
但是上述公式存在一个致命的问题:max函数是不可导的
Gumbel-Max Trick
Gumbel-Max技巧就是解决max函数不可导问题的,我们可以用argmax替换max,即
其中, g_i=-\log(-\log(u_i)), u_i \sim U(0,1) ,这一项名为Gumbel噪声,或者叫Gumbel分布,目的是使得 \mathbf{z} 的返回结果不固定
可以看到式 (2) 的整个过程中,不可导的部分只有argmax,实际上我们可以用可导的softmax函数,在参数 \tau 的控制下逼近argmax,最终 z_i 的公式为
其中, \tau 越小 (\tau \to 0) ,整个softmax越光滑逼近argmax,并且 \mathbf{z} = \{z_i\mid i=1,2,...,x\} 也越接近onehot向量; \tau 越大 (\tau \to \infty) , \mathbf{z} 向量越接近于均匀分布
总结
整个过程相当于我们把不可导的取样过程,从 \mathbf{z} 本身转移到了求 \mathbf{z} 的公式中的一项 g_i 中,而 g_i 本身不依赖 p_1,..,p_x ,所以 z 对 p_1,...,p_x 就可以到了,而且我们得到的 \mathbf{z} 仍然是离散概率分布的采样。这种采样过程转嫁的技巧有一个专有名词,叫重参数化技巧(Reparameterization Trick)
References
原始发表:13, 2021, 如有侵权请联系 cloudcommunity@tencent.com 删除
社区
活动
资源
关于
腾讯云开发者
扫码关注腾讯云开发者
领取腾讯云代金券
热门产品
热门推荐
更多推荐
Copyright © 2013 - 2023 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号: 粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287