Something is wrong with the batch axis.
In [32]: import torch
    ...: from pixyz.distributions import Normal
    ...: from pixyz.losses import Entropy  # AnalyticalEntropy is the class added in this PR
In [33]: loc = torch.tensor(0.)
...: scale = torch.tensor(1.)
...: prior = Normal(loc=loc, scale=scale, var=["z"], dim=64, name="p_prior")
In [34]: Entropy(prior).eval()
Out[34]: tensor([87.5932])
In [35]: AnalyticalEntropy(prior).eval()
Out[35]: tensor(1.4189)
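For reference, the analytical entropy of a univariate standard normal is 0.5·log(2πe) ≈ 1.4189, which matches Out[34] above, while the sampled Entropy returns an estimate over all 64 dimensions (roughly 64 × 1.42 ≈ 90.8). A quick sanity check with plain torch.distributions, independent of pixyz:
>>> import math, torch
>>> import torch.distributions as td
>>> td.Normal(0., 1.).entropy()                    # per-dimension entropy of N(0, 1)
tensor(1.4189)
>>> round(0.5 * math.log(2 * math.pi * math.e), 4)  # closed form: 0.5 * log(2*pi*e)
1.4189
>>> td.Independent(td.Normal(torch.zeros(64), torch.ones(64)), 1).entropy()  # 64 i.i.d. dims
tensor(90.8121)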
If there is a batch axis, there seems to be no problem:
In [47]: from pixyz.distributions import Bernoulli, Normal
    ...: import torch.nn as nn
    ...: import torch.nn.functional as F
    ...:
    ...: # inference model (encoder) q(z|x)
    ...: class Inference(Normal):
    ...:     def __init__(self):
    ...:         super(Inference, self).__init__(cond_var=["x"], var=["z"], name="q")  # var: variables of this distribution, cond_var: conditional variables.
    ...:
    ...:         self.fc1 = nn.Linear(784, 512)
    ...:         self.fc2 = nn.Linear(512, 512)
    ...:         self.fc31 = nn.Linear(512, 64)
    ...:         self.fc32 = nn.Linear(512, 64)
    ...:
    ...:     def forward(self, x):  # the name of this argument should be the same as cond_var.
    ...:         h = F.relu(self.fc1(x))
    ...:         h = F.relu(self.fc2(h))
    ...:         return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}  # return parameters of the normal distribution
    ...:
In [48]: q = Inference()
In [49]: x = torch.zeros(10,784)
In [50]: Entropy(q).eval({"x":x})
Out[50]:
tensor([58.0688, 61.0638, 70.3537, 56.1003, 62.3366, 60.3948, 59.5125, 81.9610,
69.0726, 65.0433], grad_fn=<NegBackward>)
In [51]: AnalyticalEntropy(q).eval({"x":x})
Out[51]:
tensor([67.4644, 67.4644, 67.4644, 67.4644, 67.4644, 67.4644, 67.4644, 67.4644,
67.4644, 67.4644], grad_fn=<SumBackward2>)
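For comparison, a minimal sketch of an analytical entropy with a batch axis using plain torch.distributions (not the pixyz Loss API; loc/scale of shape (batch, dim) are assumed): summing the per-dimension Gaussian entropies over the event dimension leaves one value per batch element.
>>> import torch
>>> import torch.distributions as td
>>> loc, scale = torch.zeros(10, 64), torch.ones(10, 64)   # stand-ins for the encoder's outputs
>>> td.Normal(loc, scale).entropy().sum(dim=-1).shape      # sum over the event dim -> one value per batch element
torch.Size([10])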
If the prior has a batch axis, this problem does not happen (not related to this implementation):
>>> loc = torch.zeros([1,64])
>>> scale = torch.ones([1,64])
>>> prior = Normal(loc=loc, scale=scale, var=["z"], dim=64, name="p_prior")
>>> AnalyticalEntropy(prior).eval()
tensor([90.8121])
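Consistency check: 90.8121 is just the per-dimension value 1.4189 from Out[34] summed over dim=64, which suggests the no-batch case above is missing the sum over the event dimension.
>>> import math
>>> round(64 * 0.5 * math.log(2 * math.pi * math.e), 4)    # 64 * per-dimension entropy of N(0, 1)
90.8121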
Thank you for your pull request!
Question: Is it better to use SetLoss instead of Loss? (I don't know whether it is suitable for this case.)
I think you don't need to inherit SetLoss in this case, because the entropy method is not implemented in the Distribution API for now.
@TMats Also, I would be glad if you could add the class you implemented to the following docs: https://github.com/masa-su/pixyz/blob/develop/v0.1.0/docs/source/losses.rst
We may want an analytical version of entropy and cross-entropy.
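If it helps, for distributions with closed-form entropy and KL in torch.distributions, an analytical cross-entropy can be assembled as H(p, q) = H(p) + KL(p || q). A minimal sketch with plain torch.distributions (the names p and q_ here are only for illustration, not pixyz API):
>>> import torch
>>> from torch.distributions import Normal, kl_divergence
>>> p  = Normal(torch.zeros(64), torch.ones(64))
>>> q_ = Normal(torch.full((64,), 0.5), torch.full((64,), 2.0))
>>> (p.entropy() + kl_divergence(p, q_)).sum()   # analytical H(p, q), summed over the 64 dims
tensor(113.1735)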