vsitzmann / siren

Official implementation of "Implicit Neural Representations with Periodic Activation Functions"
MIT License

HyperNetwork (CNN+SIREN) on larger images? #15

Closed: trougnouf closed this issue 4 years ago

trougnouf commented 4 years ago

Does it seem feasible to train the CNN+SIREN scheme on larger images? The provided example (train_img_neural_process.py) uses 32x32 pixel images, and I haven't managed to generate visually pleasing results at a more useful size (e.g. 256x256 px).

Either the image quality is very poor, or the memory blows up when I try increasing the number of parameters.

Another example (train_img.py) uses the same SIREN network to generate larger images (512x512), so the SIREN itself shouldn't be the problem; its weight-generation process likely is. The CNN's output appears to be reduced to only 256 features, first in ConvImgEncoder, which does something like:

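```python
import torch
batch_size = 4
feats = torch.rand([batch_size, 256, 32, 32]).view(batch_size, 256, -1)  # (B, 256, 1024)
torch.nn.Linear(1024, 1)(feats).squeeze(-1).shape  # torch.Size([4, 256]), i.e. [batch_size, 256]
```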

Then the hypernetwork (FCBlock) does the same, squeezing its input through hidden layers whose width is set by the (hyper) hidden-features setting.

That seems like very little when generating weights for a network with 256*256 weights per layer. I guess it works at 32x32 px, but it makes sense that it would not scale up.
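To put rough numbers on this, here is a small sketch that counts how many scalars a hypernetwork must emit for one SIREN and how large the final Linear layer of each output head becomes. The configuration (2-D input, RGB output, 256 hidden features, 3 hidden-to-hidden layers, a 256-wide last hypernetwork layer, one output head per SIREN weight or bias tensor) is an assumption for illustration, not read from the repository's scripts.

```python
# Assumed configuration (illustrative only): SIREN 2 -> 256 -> 256 -> 256 -> 256 -> 3,
# and a hypernetwork whose last hidden layer has 256 units feeding one head per tensor.

def siren_param_counts(in_features, hidden, num_hidden_layers, out_features):
    """Return the number of scalars in each SIREN weight matrix and bias vector."""
    dims = [in_features] + [hidden] * (num_hidden_layers + 1) + [out_features]
    counts = []
    for fan_in, fan_out in zip(dims[:-1], dims[1:]):
        counts.append(fan_out * fan_in)  # weight matrix
        counts.append(fan_out)           # bias vector
    return counts


def hyper_head_params(counts, hyper_hidden):
    """Parameters of a final Linear(hyper_hidden, n) per SIREN tensor (weights + biases)."""
    return [hyper_hidden * n + n for n in counts]


counts = siren_param_counts(2, 256, 3, 3)
heads = hyper_head_params(counts, 256)
print(sum(counts))  # 198915 SIREN parameters to generate
print(max(counts))  # 65536: the largest tensor is a 256x256 weight matrix
print(sum(heads))   # 51121155 parameters in the hypernetwork's output layers alone
```

Under these assumptions, each generated tensor is an affine function of a single 256-dimensional hidden vector, so every generated 256x256 weight matrix lies in an affine subspace of dimension at most 256. Widening that last hidden layer adds about 199k output-layer parameters per extra unit here, which is consistent with memory growing quickly as capacity is added.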

Do you have any insight into how to generate the SIREN weights from the output of a CNN without blowing up memory?

alexanderbergman7 commented 4 years ago

Hi,

Thanks for the question; this is a good observation and something we've thought about a bit. As you mention, the capacity of the SIREN network itself (the hypo-network) is likely not the problem, as similarly sized SIRENs can reconstruct high-resolution images perfectly.

We previously ran experiments using an auto-decoder architecture for generalization: instead of an encoder that maps input images to a latent space, each image in the training set is assigned a latent code that we optimize jointly with the hypernetwork. In this case, the generalization scheme reconstructed high-resolution images with higher fidelity, but the learned latent space was not especially useful as a prior. This suggests that the encoder is likely the bottleneck: the encoder trades off some training reconstruction quality for a latent space that can generalize to unseen images.
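A minimal auto-decoder sketch in PyTorch, assuming small made-up sizes and random placeholder images. TinyHyperSiren, fit_code, and all dimensions are hypothetical simplifications (a single MLP emits the weights of a one-hidden-layer SIREN with omega_0 = 30), not the repository's HyperNetwork. It shows the two phases described above: each training image owns a row of an nn.Embedding that is optimized jointly with the hypernetwork, and an unseen image needs its own optimization of a fresh code.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, kept small so the sketch runs quickly on CPU.
NUM_IMAGES, LATENT_DIM, HIDDEN, SIDE = 8, 64, 32, 16


class TinyHyperSiren(nn.Module):
    """Hypothetical stand-in for a hypernetwork: one MLP maps a latent code to
    all parameters of a one-hidden-layer SIREN (2 -> hidden -> 1)."""

    def __init__(self, latent_dim, hidden):
        super().__init__()
        self.shapes = [(hidden, 2), (hidden,), (1, hidden), (1,)]  # w1, b1, w2, b2
        n_out = sum(torch.Size(s).numel() for s in self.shapes)
        self.net = nn.Sequential(nn.Linear(latent_dim, 128), nn.ReLU(), nn.Linear(128, n_out))

    def forward(self, z, coords):
        # z: (B, latent_dim), coords: (N, 2) -> predicted pixels (B, N, 1)
        flat, params, i = self.net(z), [], 0
        for s in self.shapes:
            n = torch.Size(s).numel()
            params.append(flat[:, i:i + n].reshape(z.shape[0], *s))
            i += n
        w1, b1, w2, b2 = params
        h = torch.sin(30.0 * (coords @ w1.transpose(1, 2) + b1.unsqueeze(1)))  # sine layer
        return h @ w2.transpose(1, 2) + b2.unsqueeze(1)  # linear output layer


torch.manual_seed(0)
images = torch.rand(NUM_IMAGES, SIDE * SIDE, 1)  # placeholder grayscale training images
lin = torch.linspace(-1, 1, SIDE)
coords = torch.stack(torch.meshgrid(lin, lin, indexing="ij"), dim=-1).reshape(-1, 2)

hyper = TinyHyperSiren(LATENT_DIM, HIDDEN)
codes = nn.Embedding(NUM_IMAGES, LATENT_DIM)  # one learnable latent code per training image
nn.init.normal_(codes.weight, std=0.01)
opt = torch.optim.Adam(list(hyper.parameters()) + list(codes.parameters()), lr=1e-3)

for step in range(200):  # codes and hypernetwork are optimized jointly
    idx = torch.randint(0, NUM_IMAGES, (4,))
    loss = ((hyper(codes(idx), coords) - images[idx]) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()


def fit_code(image, steps=100):
    """Unseen image: only the fresh code z is given to the optimizer,
    so the hypernetwork's weights are left unchanged."""
    z = torch.zeros(1, LATENT_DIM, requires_grad=True)
    opt_z = torch.optim.Adam([z], lr=1e-2)
    for _ in range(steps):
        err = ((hyper(z, coords) - image.unsqueeze(0)) ** 2).mean()
        opt_z.zero_grad()
        err.backward()
        opt_z.step()
    return z.detach()
```

Because nothing forces the codes to be predictable from pixels, the hypernetwork can fit training images closely, while an encoder-based setup must also learn a mapping that generalizes; this is one way to read the trade-off described above.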

I'm not entirely sure how to improve the convolutional encoder without blowing up the memory cost. Have you tried the set encoder instead of the convolutional encoder? The set encoder tended to "under-fit" and blur images, but perhaps its performance is less sensitive to the input size.
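For comparison with the convolutional encoder, here is a hypothetical mean-pooling set encoder (MeanPoolSetEncoder and grid are illustrative names, and the layer sizes are assumptions, not the repository's set encoder). A shared MLP embeds each (coordinate, pixel) pair and the embeddings are averaged, so the encoder's parameter count and latent size stay the same at any image resolution.

```python
import torch
import torch.nn as nn


class MeanPoolSetEncoder(nn.Module):
    """Hypothetical set encoder: a shared MLP embeds every (x, y, r, g, b) tuple,
    then the embeddings are averaged into one fixed-size latent code."""

    def __init__(self, coord_dim=2, value_dim=3, hidden=256, latent_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(coord_dim + value_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, latent_dim),
        )

    def forward(self, coords, values):
        # coords: (B, N, 2), values: (B, N, 3) -> latent: (B, latent_dim)
        return self.net(torch.cat([coords, values], dim=-1)).mean(dim=-2)


def grid(side):
    """Pixel-center coordinates in [-1, 1]^2, flattened to (side*side, 2)."""
    lin = torch.linspace(-1, 1, side)
    return torch.stack(torch.meshgrid(lin, lin, indexing="ij"), dim=-1).reshape(-1, 2)


enc = MeanPoolSetEncoder()
for side in (32, 256):
    coords = grid(side).unsqueeze(0)        # (1, side*side, 2)
    values = torch.rand(1, side * side, 3)  # placeholder pixel values
    with torch.no_grad():
        z = enc(coords, values)
    print(side, tuple(z.shape))  # same latent shape at both resolutions
```

Note that only the parameter count is resolution-independent: activation memory still grows linearly with the number of pixels passed in, and averaging into one vector is a plausible contributor to the blurring mentioned above.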

trougnouf commented 4 years ago

Thank you @alexanderbergman7! Sorry, I meant to respond earlier, but I haven't experimented with this as much as I had hoped before reporting results. I was unable to train a CNN to generate better SIREN weights just by increasing the number of parameters or changing the architecture of the CNN (or the hypernetwork). Thank you for the suggestion to use a fully connected (set) encoder instead; I have not tried that yet. It is especially interesting that this hypernetwork architecture can reconstruct high-quality images given an optimized latent code.