栏目分类:
子分类:
返回
终身学习网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
终身学习网 > IT > 软件开发 > 后端开发 > Python

变分自编码器

Python 更新时间:发布时间: 百科书网 趣学号

 变分自编码器(vae)这个东西知道很久了,不过一直理解不是很深刻,现在总结一下查阅到的文档,同时记录一下自己的一些问题。

1. pytorch实现
class Encoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.relu(self.linear2(x))


class Decoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.relu(self.linear2(x))


class VAE(torch.nn.Module):
    latent_dim = 8

    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._enc_mu = torch.nn.Linear(100, 8)
        self._enc_log_sigma = torch.nn.Linear(100, 8)

    def _sample_latent(self, h_enc):
        """
        Return the latent normal sample z ~ N(mu, sigma^2)
        """
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()

        self.z_mean = mu
        self.z_sigma = sigma

        return mu + sigma * Variable(std_z, requires_grad=False)  # Reparameterization trick

    def forward(self, state):
        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)


def latent_loss(z_mean, z_stddev):
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)
  
criterion = nn.MSELoss()
dec = vae(inputs)
loss = criterion(dec, inputs) + latent_loss(vae.z_mean, vae.z_sigma) # 重建损失与kl距离
2. 记录点

X(样本空间),Z(latent space)

2.1 保证Z满足标准高斯分布(独立,多元),如何保证呢?

 只要保证 p ( Z ∣ X ) p(Z|X) p(Z∣X)满足 N ( 0 , I ) mathcal{N}(0,I) N(0,I)即可,这就是latent_loss为何要这么设计。
p ( Z ) = ∑ X p ( Z ∣ X ) p ( X ) = ∑ X N ( 0 , I ) p ( X ) = N ( 0 , I ) ∑ X p ( X ) = N ( 0 , I ) p(Z)=sum_X p(Z|X)p(X)=sum_X mathcal{N}(0,I)p(X)=mathcal{N}(0,I) sum_X p(X) = mathcal{N}(0,I) p(Z)=X∑​p(Z∣X)p(X)=X∑​N(0,I)p(X)=N(0,I)X∑​p(X)=N(0,I)

2.2 latent loss的推导

 由于我们考虑的是各分量独立的多元正态分布,因此只需要推导一元正态分布的情形即可,
K L ( N ( μ , σ 2 ) ∥ N ( 0 , 1 ) ) = ∫ 1 2 π σ 2 e − ( x − μ ) 2 / 2 σ 2 ( log ⁡ e − ( x − μ ) 2 / 2 σ 2 / 2 π σ 2 e − x 2 / 2 / 2 π ) d x = ∫ 1 2 π σ 2 e − ( x − μ ) 2 / 2 σ 2 log ⁡ { 1 σ 2 exp ⁡ { 1 2 [ x 2 − ( x − μ ) 2 / σ 2 ] } } d x = 1 2 ∫ 1 2 π σ 2 e − ( x − μ ) 2 / 2 σ 2 [ − log ⁡ σ 2 + x 2 − ( x − μ ) 2 / σ 2 ] d x begin{aligned}&KLBig(N(mu,sigma^2)BigVert N(0,1)Big)\ =&int frac{1}{sqrt{2pisigma^2}}e^{-(x-mu)^2/2sigma^2} left(log frac{e^{-(x-mu)^2/2sigma^2}/sqrt{2pisigma^2}}{e^{-x^2/2}/sqrt{2pi}}right)dx\ =&int frac{1}{sqrt{2pisigma^2}}e^{-(x-mu)^2/2sigma^2} log left{frac{1}{sqrt{sigma^2}}expleft{frac{1}{2}big[x^2-(x-mu)^2/sigma^2big]right} right}dx\ =&frac{1}{2}int frac{1}{sqrt{2pisigma^2}}e^{-(x-mu)^2/2sigma^2} Big[-log sigma^2+x^2-(x-mu)^2/sigma^2 Big] dxend{aligned} ===​KL(N(μ,σ2)∥∥∥​N(0,1))∫2πσ2 ​1​e−(x−μ)2/2σ2(loge−x2/2/2π ​e−(x−μ)2/2σ2/2πσ2 ​​)dx∫2πσ2 ​1​e−(x−μ)2/2σ2log{σ2 ​1​exp{21​[x2−(x−μ)2/σ2]}}dx21​∫2πσ2 ​1​e−(x−μ)2/2σ2[−logσ2+x2−(x−μ)2/σ2]dx​
整个结果分为三项积分,第一项实际上就是 − log ⁡ σ 2 -log sigma^2 −logσ2乘以概率密度的积分(也就是1),所以结果是 − log ⁡ σ 2 -log sigma^2 −logσ2;第二项实际是正态分布的二阶矩,熟悉正态分布的朋友应该都清楚正态分布的二阶矩为 μ 2 + σ 2 mu^2+sigma^2 μ2+σ2;而根据定义,第三项实际上就是“-方差除以方差=-1”。所以总结果就是:
K L ( N ( μ , σ 2 ) ∥ N ( 0 , 1 ) ) = 1 2 ( − log ⁡ σ 2 + μ 2 + σ 2 − 1 ) KLBig(N(mu,sigma^2)BigVert N(0,1)Big)=frac{1}{2}Big(-log sigma^2+mu^2+sigma^2-1Big) KL(N(μ,σ2)∥∥∥​N(0,1))=21​(−logσ2+μ2+σ2−1)

2.3 Evidence Lower Bound(ELBO)推导

 2.2部分是kl散度的推导,不过VAE的整个损失并不是只有这个,其损失函数是被称为ELBO的一个东西。因为我们想将Z变成一个 N ( 0 , I ) mathcal{N}(0,I) N(0,I)的分布,而我们又只有X,那么我们要做的就是使得 K L ( Q ( Z ) ∥ P ( Z ∣ X ) ) KLBig(Q(Z)BigVert P(Z|X)Big) KL(Q(Z)∥∥∥​P(Z∣X))最小化。
K L ( Q ( Z ) ∥ P ( Z ∣ X ) ) = E Z ∼ Q ( l o g Q ( Z ) − l o g P ( Z ∣ X ) ) = E Z ∼ Q ( l o g Q ( Z ) − l o g P ( X ∣ Z ) − l o g P ( Z ) ) + l o g P ( X ) begin{aligned}&KLBig(Q(Z)BigVert P(Z|X)Big)\=&E_{Zsim Q}Big(logQ(Z)-logP(Z|X)Big)\=&E_{Zsim Q}Big(logQ(Z)-logP(X|Z)-logP(Z)Big)+logP(X)end{aligned} ==​KL(Q(Z)∥∥∥​P(Z∣X))EZ∼Q​(logQ(Z)−logP(Z∣X))EZ∼Q​(logQ(Z)−logP(X∣Z)−logP(Z))+logP(X)​

 移项整理得到:
l o g P ( X ) − K L ( Q ( Z ) ∥ P ( Z ∣ X ) ) = E Z ∼ Q ( l o g P ( X ∣ Z ) ) − K L ( Q ( Z ) ∥ P ( Z ) ) logP(X)-KLBig(Q(Z)big Vert P(Z|X)Big)=E_{Zsim Q}Big(logP(X|Z)Big)-KLBig(Q(Z)Big Vert P(Z) Big) logP(X)−KL(Q(Z)∥∥​P(Z∣X))=EZ∼Q​(logP(X∣Z))−KL(Q(Z)∥∥∥​P(Z))
 将 Q ( Z ) Q(Z) Q(Z)替换为 Q ( Z ∣ X ) Q(Z|X) Q(Z∣X)得到:
l o g P ( X ) − K L ( Q ( Z ∣ X ) ∥ P ( Z ∣ X ) ) = E Z ∼ Q ( l o g P ( X ∣ Z ) ) − K L ( Q ( Z ∣ X ) ∥ P ( Z ) ) logP(X)-KLBig(Q(Z|X)big Vert P(Z|X)Big)=E_{Zsim Q}Big(logP(X|Z)Big)-KLBig(Q(Z|X)Big Vert P(Z) Big) logP(X)−KL(Q(Z∣X)∥∥​P(Z∣X))=EZ∼Q​(logP(X∣Z))−KL(Q(Z∣X)∥∥∥​P(Z))
 显然,左边两项就是我们要优化的项,左边两项越大越好。而右边两项则是可以计算的,右边第一项相当于一个decoder,而第二项相当于一个encoder,它也对应于2.2部分的推导,右边部分即被称为Evidence Lower Bound,一般在讨论VAE的时候我们用ELBO来指代它的cost function

2.4 KL vanish

KL vanish的出现可以从公式4得知,如果ZX相互独立,即X完全不依赖于Z,那么右边第二项KL损失就可以被优化为0,而仅仅第一项起作用,这时候KL就发生了vanish

参考资料:

  1. 变分自编码器(一):原来是这么一回事
  2. 重参数化技巧
  3. ELBO
转载请注明:文章转载自 www.051e.com
本文地址:http://www.051e.com/it/272770.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 ©2023-2025 051e.com

ICP备案号:京ICP备12030808号