Open wygjy opened 1 year ago
I came across your issue after the same error, I presume: ValueError: Expected parameter probs (Tensor of shape (1, 128, 1, 28, 28)) of distribution LogitRelaxedBernoulli(probs: torch.Size([1, 128, 1, 28, 28])) to satisfy the constraint Interval(lower_bound=0.0, upper_bound=1.0), but found invalid values
I'm not sure, but my hunch is that this is due to a change in PyTorch. I'm using PyTorch v1.13 and got that error. I think this library assumes earlier versions of both geoopt and PyTorch.
If you resolve the error, please share! Hope my reply helps.
I think the reason for this is that the constructor for torch.distributions.RelaxedBernoulli changed at some point. The decoder here is outputting logits, and to initialize a RelaxedBernoulli with logits, one must use the logits= keyword argument.
So, I modified the following in pvae/models/vae.py and training works:
 class VAE(nn.Module):
     ...
     def forward(self, x, K=1):
         ...
-        px_z = self.px_z(*self.dec(zs))
+        temperature, decoder_output = self.dec(zs)
+        px_z = self.px_z(temperature, logits=decoder_output)
Note, perhaps obviously, that this assumes your px_z (likelihood distribution) is parameterized by temperature and logits.
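For anyone who wants to check this outside the repo, here is a minimal standalone sketch of the positional vs. keyword behaviour. The tensor values and the names temperature and decoder_output are made up for illustration and are not taken from pvae. As far as I know, argument validation is enabled by default in recent PyTorch releases, which would explain why the positional call now raises; the sketch turns validation on explicitly so it does not depend on that default.

```python
import torch
from torch.distributions import Distribution, RelaxedBernoulli

# Turn argument validation on explicitly so the behaviour below does not
# depend on the installed version's default.
Distribution.set_default_validate_args(True)

temperature = torch.tensor(0.5)
# Hypothetical stand-in for the decoder output: unbounded logits,
# not probabilities in [0, 1].
decoder_output = torch.tensor([-3.0, 0.0, 4.0])

# The second positional parameter of RelaxedBernoulli is `probs`, so the
# logits are checked against the interval [0, 1] and rejected.
try:
    RelaxedBernoulli(temperature, decoder_output)
    positional_ok = True
except ValueError:
    positional_ok = False

# Passing the same tensor by keyword tells the distribution it holds logits.
px_z = RelaxedBernoulli(temperature, logits=decoder_output)
sample = px_z.rsample()  # reparameterized sample with values in [0, 1]

print("positional call accepted:", positional_ok)
print("sample shape:", tuple(sample.shape))
```

An equivalent workaround would be to apply torch.sigmoid to the decoder output and keep passing it as probs, but passing logits= directly avoids the extra conversion.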
While inspecting the VAE structure, I noticed that the model applies an additional operation to the encoder's output and, furthermore, it also applies an operation to the decoder's output in the final step. This leads to a requirement for the decoder's output to be in the range of 0 to 1. However, I encounter an error when training the model. Could you please explain the reasons and purposes behind these two operations, and how to address this error? The training command I am using is: 'python3 pvae/main.py --model mnist --manifold PoincareBall --c 0.7 --latent-dim 2 --hidden-dim 600 --prior WrappedNormal --posterior WrappedNormal --dec Geo --enc Wrapped --lr 5e-4 --epochs 80 --save-freq 80 --batch-size 128 --iwae-samples 5000'.