masa-su / pixyz

A library for developing deep generative models in a more concise, intuitive and extendable way
https://pixyz.io
MIT License

Add Analytical Entropy #73

Closed TMats closed 5 years ago

TMats commented 5 years ago

We may want an analytical version of entropy and cross-entropy.
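
For context, here is a minimal sketch of the difference in plain PyTorch (torch.distributions, not the pixyz API): the existing Entropy loss appears to be a sampled estimate of E[-log p(z)], while an analytical version would use the closed form, e.g. 0.5 * log(2*pi*e*sigma^2) for a Gaussian.

import torch
from torch.distributions import Normal as TorchNormal

d = TorchNormal(torch.tensor(0.), torch.tensor(1.))

# Monte Carlo estimate of the entropy: average of -log p(z) over samples.
z = d.sample((1000,))
mc_entropy = -d.log_prob(z).mean()

# Closed-form entropy of a univariate Gaussian: 0.5 * log(2*pi*e*sigma^2), about 1.4189 for N(0, 1).
analytical_entropy = d.entropy()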

Question: Is it better to use SetLoss instead of Loss? (I don't know whether it is suitable for this case.)

TMats commented 5 years ago

Something is wrong with the batch axis handling:

In [33]: loc = torch.tensor(0.)
    ...: scale = torch.tensor(1.)
    ...: prior = Normal(loc=loc, scale=scale, var=["z"], dim=64, name="p_prior")

In [34]: Entropy(prior).eval()
Out[34]: tensor([87.5932])

In [35]: AnalyticalEntropy(prior).eval()
Out[35]: tensor(1.4189)
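
For reference, tensor(1.4189) is exactly the closed-form entropy of a single standard normal, 0.5 * log(2*pi*e), so the dim=64 event axis seems to be dropped here; the sampled Entropy above is a one-sample estimate of the full 64-dimensional entropy, whose true value is 64 * 1.4189 ≈ 90.81. A quick check in plain torch (not pixyz):

import torch
from torch.distributions import Normal as TorchNormal

# Entropy of a single standard normal; matches the AnalyticalEntropy output above.
print(TorchNormal(torch.tensor(0.), torch.tensor(1.)).entropy())  # tensor(1.4189)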

If there is a batch axis, there seems to be no problem:

In [47]: from pixyz.distributions import Bernoulli, Normal
    ...: # inference model (encoder) q(z|x)
    ...: class Inference(Normal):
    ...:     def __init__(self):
    ...:         super(Inference, self).__init__(cond_var=["x"], var=["z"], name="q")  # var: variables of this distribution, cond_var: conditional variables.
    ...:
    ...:         self.fc1 = nn.Linear(784, 512)
    ...:         self.fc2 = nn.Linear(512, 512)
    ...:         self.fc31 = nn.Linear(512, 64)
    ...:         self.fc32 = nn.Linear(512, 64)
    ...:
    ...:     def forward(self, x):  # the name of this argument should be the same as cond_var.
    ...:         h = F.relu(self.fc1(x))
    ...:         h = F.relu(self.fc2(h))
    ...:         return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}  # return parameters of the normal distribution
    ...:

In [48]: q = Inference()

In [49]: x = torch.zeros(10,784)

In [50]: Entropy(q).eval({"x":x})
Out[50]:
tensor([58.0688, 61.0638, 70.3537, 56.1003, 62.3366, 60.3948, 59.5125, 81.9610,
        69.0726, 65.0433], grad_fn=<NegBackward>)

In [51]: AnalyticalEntropy(q).eval({"x":x})
Out[51]:
tensor([67.4644, 67.4644, 67.4644, 67.4644, 67.4644, 67.4644, 67.4644, 67.4644,
        67.4644, 67.4644], grad_fn=<SumBackward2>)
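
The analytical values are identical across the batch because x is all zeros, so the encoder produces the same loc and scale for every row and there is no sampling noise, whereas the sampled Entropy varies row by row. For reference, a rough sketch of the per-sample computation in plain torch, with placeholder loc and scale rather than the actual encoder outputs (so the values will not be 67.4644): the per-dimension entropies are summed over the event axis while the batch axis is kept.

import torch
from torch.distributions import Normal as TorchNormal

# Placeholder parameters of shape [batch, 64]; the real ones come from the encoder.
loc = torch.zeros(10, 64)
scale = torch.ones(10, 64)

# Analytical entropy per batch element: sum over the event axis, keep the batch axis.
h = TorchNormal(loc, scale).entropy().sum(-1)
print(h.shape)  # torch.Size([10]), one deterministic value per batch element
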
TMats commented 5 years ago

If the prior has a batch axis, this problem does not happen (so it is not specific to this implementation):

>>> loc = torch.zeros([1,64])
>>> scale = torch.ones([1,64])
>>> prior = Normal(loc=loc, scale=scale, var=["z"], dim=64, name="p_prior")
>>> AnalyticalEntropy(prior).eval()
tensor([90.8121])
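
This matches the closed form: with the batch axis present, the 64 event dimensions are summed, and 64 * 0.5 * log(2*pi*e) ≈ 90.81. A quick arithmetic check:

import math
print(64 * 0.5 * math.log(2 * math.pi * math.e))  # 90.812...
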
masa-su commented 5 years ago

Thank you for your pull request!

Question: Is it better to use SetLoss instead of Loss? (I don't know whether it is suitable for this case.)

I think you don't need to inherit SetLoss in this case, because the entropy method is not implemented in the Distribution API for now.

masa-su commented 5 years ago

@TMats Also, I would be glad if you could add the class you implemented to the following docs: https://github.com/masa-su/pixyz/blob/develop/v0.1.0/docs/source/losses.rst