emilemathieu / pvae

code for "Continuous Hierarchical Representations with Poincaré Variational Auto-Encoders".
https://arxiv.org/abs/1901.06033
MIT License

Why is the decoder's output range required to be from 0 to 1 in MNIST? #18

Open wygjy opened 1 year ago

wygjy commented 1 year ago

While inspecting the VAE structure, I noticed that the model applies an additional operation to the encoder's output and another to the decoder's output in the final step, which requires the decoder's output to lie in the range 0 to 1. However, I encounter an error when training the model. Could you please explain the reasons and purposes behind these two operations, and how to fix this error? The training command I am using is: 'python3 pvae/main.py --model mnist --manifold PoincareBall --c 0.7 --latent-dim 2 --hidden-dim 600 --prior WrappedNormal --posterior WrappedNormal --dec Geo --enc Wrapped --lr 5e-4 --epochs 80 --save-freq 80 --batch-size 128 --iwae-samples 5000'.

grisaitis commented 1 year ago

I came across your issue after the same error, I presume: ValueError: Expected parameter probs (Tensor of shape (1, 128, 1, 28, 28)) of distribution LogitRelaxedBernoulli(probs: torch.Size([1, 128, 1, 28, 28])) to satisfy the constraint Interval(lower_bound=0.0, upper_bound=1.0), but found invalid values

i'm not sure, but my hunch is this is due to a change in pytorch. i'm using pytorch 1.13 and had that error. this library assumes earlier versions of both geoopt and pytorch, i think.

if you resolve the error, please share! hope my reply helps.

grisaitis commented 1 year ago

I think the reason for this is that the constructor for torch.distributions.RelaxedBernoulli changed at some point. The decoder here outputs logits, and to initialize a RelaxedBernoulli with logits, one must use the logits= keyword argument.

So, I modified the following in pvae/models/vae.py and training works:

class VAE(nn.Module):
...
    def forward(self, x, K=1):
...
-        px_z = self.px_z(*self.dec(zs))
+        temperature, decoder_output = self.dec(zs)
+        px_z = self.px_z(temperature, logits=decoder_output)

note, perhaps obviously, this assumes your px_z (likelihood distribution) is parameterized by temperature and logits.
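For anyone who wants to verify the distinction in isolation before patching vae.py, here's a minimal standalone sketch (the tensor values are made up for illustration): a positional second argument to RelaxedBernoulli is bound to probs, which must lie in [0, 1], while the logits= keyword accepts any real-valued decoder output.

```python
import torch
from torch.distributions import RelaxedBernoulli

temperature = torch.tensor(0.5)
# Real-valued decoder outputs (logits), not probabilities:
decoder_output = torch.tensor([-2.0, 0.3, 1.5, 4.0])

# Correct: bind the values to `logits`, which accepts any real number.
px_z = RelaxedBernoulli(temperature, logits=decoder_output)
sample = px_z.rsample()  # relaxed samples lie in (0, 1)

# Incorrect: a positional second argument is interpreted as `probs`,
# which must satisfy Interval(0, 1) -- with argument validation on
# (the default in recent PyTorch), this raises the ValueError above.
try:
    RelaxedBernoulli(temperature, decoder_output)
except ValueError as e:
    print(type(e).__name__)
```

On older PyTorch versions argument validation was laxer, which may be why the positional call used to slip through.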