
假设y为真实分布,x为预测分布。这个函数的正确打开方式应该是下面这样。
import torch.nn.functional as F kl = F.kl_div(x.softmax(dim=-1).log(), y.softmax(dim=-1), reduction='sum')
这里有一些细节需要注意,第一个参数与第二个参数都要进行softmax(dim=-1),目的是使两个概率分布的所有值之和都为1,若不进行此操作,如果x或y概率分布所有值的和大于1,则可能会使计算的KL为负数。softmax接收一个参数dim,dim=-1表示在最后一维进行softmax操作。除此之外,第一个参数还要进行log()操作(至于为什么,大概是为了方便pytorch的代码组织,pytorch定义的损失函数都调用handle_torch_function函数,方便权重控制等),才能得到正确结果。 — 转载自知乎-KL散度理解以及使用pytorch计算KL散度。