juliuskunze / cwvae-jax

Clockwork VAEs in JAX/Flax
MIT License

potential bug in the encoder #1

Closed xmax1 closed 3 years ago

xmax1 commented 3 years ago

```python
for level in range(1, self.c.levels):
    for _ in range(self.c.enc_dense_layers - 1):
        x = nn.relu(nn.Dense(self.c.enc_dense_embed_size)(x))
    if self.c.enc_dense_layers > 0:
        x = nn.Dense(feat_size)(x)
    layer = x
```

From line 39 onwards in `Encoder()` in cnn.py, the depth of these layers grows with the level, because the hidden variable `x` is overwritten rather than reset on each iteration. With large `levels` and `enc_dense_layers`, this produces a very deep network mapping from the observation embedding to the latent space. I'm not sure it's intentional, and it doesn't seem to serve a purpose: is there a reason the higher latent spaces need a deeper function to map from the embedding?
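To make the depth growth concrete, here is a minimal runnable sketch (hypothetical module and config values, not the repo's code) that mirrors the loop above: because `x` carries over between levels, the features for level k pass through roughly k × `enc_dense_layers` Dense layers.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class LadderEncoderSketch(nn.Module):
    """Sketch of the cumulative dense stack from the loop above (hypothetical)."""
    levels: int = 3               # assumed config values, not from the repo
    enc_dense_layers: int = 2
    enc_dense_embed_size: int = 64
    feat_size: int = 32

    @nn.compact
    def __call__(self, x):
        layers = [x]  # level 0: the raw observation embedding
        for _level in range(1, self.levels):
            for _ in range(self.enc_dense_layers - 1):
                x = nn.relu(nn.Dense(self.enc_dense_embed_size)(x))
            if self.enc_dense_layers > 0:
                x = nn.Dense(self.feat_size)(x)
            layers.append(x)  # level k deepens level k-1's features
        return layers

# Level 2's output has passed through twice as many Dense layers as level 1's.
params = LadderEncoderSketch().init(jax.random.PRNGKey(0), jnp.ones((1, 32)))
```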

The same issue is present in the original TensorFlow version: https://github.com/vaibhavsaxena11/cwvae/issues/2

xmax1 commented 3 years ago

Ah, it's intentional: it gives a ladder-VAE-style structure, where each level's encoder builds on the features of the level below. Closing.
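For contrast, a non-ladder alternative (hypothetical, not what the repo does) would reset to the shared embedding at every level, keeping each level's head equally shallow:

```python
import flax.linen as nn

class PerLevelEncoderSketch(nn.Module):
    """Hypothetical non-ladder alternative: every level maps directly from the
    shared embedding, so encoder depth does not grow with the level."""
    levels: int = 3               # assumed config values, not from the repo
    enc_dense_layers: int = 2
    enc_dense_embed_size: int = 64
    feat_size: int = 32

    @nn.compact
    def __call__(self, embedding):
        layers = [embedding]
        for _level in range(1, self.levels):
            h = embedding  # reset: no reuse of the previous level's features
            for _ in range(self.enc_dense_layers - 1):
                h = nn.relu(nn.Dense(self.enc_dense_embed_size)(h))
            if self.enc_dense_layers > 0:
                h = nn.Dense(self.feat_size)(h)
            layers.append(h)
        return layers
```

The repo's cumulative version instead shares computation across levels, which is exactly the ladder-VAE pattern described above.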