Lightning-Universe / lightning-bolts

Toolbox of models, callbacks, and datasets for AI/ML researchers.
https://lightning-bolts.readthedocs.io
Apache License 2.0

Enable ResNet decoders for autoencoders to be used with non-square, higher-resolution inputs #556

Closed. nmohr192 closed this issue 3 years ago.

nmohr192 commented 3 years ago

🚀 Feature

The ResNet18/50 decoders should support decoding inputs that were originally non-square. They also seem to be tuned for small inputs rather than higher resolutions; this could be changed by increasing the size of the initial dense layer.

Motivation

I have implemented a variational autoencoder on 224x512 images, and the default implementation of the resnet18 decoder does not allow non-square inputs at all. Even after fixing the interpolate right after the dense layer (line 295 of https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/autoencoders/components.py), my higher-resolution rectangular inputs produced blurry artifact lines exactly at the middle of the initial 4x4 pattern that appears in the first epochs. This is caused by the fixed reshape into (512, 4, 4) after the dense layer (line 321 of https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/autoencoders/components.py).
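A rough shape trace of the default head shows where the square assumption enters. This is plain Python arithmetic rather than the bolts code, and it assumes the defaults first_conv=False and maxpool1=False, under which the residual blocks after upscale1 upsample by upscale_factor = 8 in total and self.inplanes is 512 for ResNet18; trace_default_decoder is a hypothetical name.

```python
from typing import List, Tuple


def trace_default_decoder(input_height: int, input_width: int,
                          upscale_factor: int = 8) -> List[Tuple[int, int, int]]:
    """Hypothetical (C, H, W) trace of the default ResNet18 decoder head.

    input_width is accepted only to make the point that the default head ignores it.
    """
    shapes = [(512, 4, 4)]                    # fixed reshape after the dense layer
    side = input_height // upscale_factor     # upscale1 target, built from height only
    shapes.append((512, side, side))
    # Residual blocks upsample by upscale_factor; the final conv gives 3 channels.
    shapes.append((3, side * upscale_factor, side * upscale_factor))
    return shapes


print(trace_default_decoder(224, 512))
# [(512, 4, 4), (512, 28, 28), (3, 224, 224)]
```

The requested width never reaches the output, so under these assumptions a 224x512 image would be reconstructed as 224x224.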

Pitch

I want the decoder constructor to accept an optional argument input_width (set to input_height when it is None) so that the aspect ratio of the inputs is taken into account. The fixed initial reshape into (512, 4, 4) can then be dropped in favor of a larger initial dense layer with self.inplanes * (input_height // self.upscale_factor) * (input_width // self.upscale_factor) outputs (self.inplanes is 512 for ResNet18), followed by a reshape into the corresponding shape (self.inplanes, input_height // self.upscale_factor, input_width // self.upscale_factor).
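The size of the proposed dense layer can be checked with a little arithmetic. This is a minimal sketch, assuming self.inplanes = 512 (ResNet18) and upscale_factor = 8 as in the default decoder; pitch_head is a hypothetical helper, and input_width falls back to input_height as suggested above.

```python
from typing import Optional, Tuple


def pitch_head(latent_dim: int, input_height: int, input_width: Optional[int] = None,
               inplanes: int = 512,
               upscale_factor: int = 8) -> Tuple[int, Tuple[int, int, int], int]:
    """Dense-layer size, reshape target and parameter count for the proposed head."""
    if input_width is None:
        # No width given: treat the input as square, as the current decoder does.
        input_width = input_height
    h = input_height // upscale_factor
    w = input_width // upscale_factor
    out_features = inplanes * h * w
    # nn.Linear(latent_dim, out_features) with bias: weights plus one bias per output.
    n_params = latent_dim * out_features + out_features
    return out_features, (inplanes, h, w), n_params


print(pitch_head(256, 224, 512))
# (917504, (512, 28, 64), 235798528)
```

For 224x512 inputs and latent_dim=256 this single layer would hold about 236M parameters, versus 257 * 512 * 16 = 2,105,344 for the original (512, 4, 4) head, and the count grows linearly with latent_dim.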

Alternatives

None.

Additional context

See the latest two commits on the branch resnet18-decoder-for-non-squared-input of my fork of your repository; they show how I implemented this in my own local project, where it works perfectly on 224x512 inputs for the variational autoencoder: https://github.com/nmohr192/pytorch-lightning-bolts/commits/resnet18-decoder-for-non-squared-input

github-actions[bot] commented 3 years ago

Hi! Thanks for your contribution! Great first issue!

nmohr192 commented 3 years ago

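I found an issue with my implementation: with this approach the ResnetDecoder becomes extremely large (in number of parameters) for high latent dimensionality, because the dense layer maps latent_dim inputs to self.inplanes * (input_height // self.upscale_factor) * (input_width // self.upscale_factor) outputs. Now I'm using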

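import numpy as np  # at module level
# e.g. 224x512: gcd 32 -> (7, 16); a square would reduce to (1, 1), so it keeps (4, 4)
gcd = np.gcd(input_height, input_width)
self.latent_transform_shape = (input_height // gcd, input_width // gcd) if input_height != input_width else (4, 4)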

and then defining the dense layer and the upsampling like this:

self.linear = nn.Linear(latent_dim, self.inplanes * self.latent_transform_shape[0] * self.latent_transform_shape[1])
self.upscale1 = Interpolate(size=(input_height // self.upscale_factor, input_width // self.upscale_factor))

and in forward:

x = x.view(x.size(0), 512 * self.expansion, self.latent_transform_shape[0], self.latent_transform_shape[1])
x = self.upscale1(x)

I will commit this to my fork later.
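The gcd-based head described in this comment can be sanity-checked with shape arithmetic alone. The sketch below is plain Python, not the committed code; it assumes self.inplanes = 512 (ResNet18, so 512 * self.expansion = 512) and upscale_factor = 8 (defaults first_conv=False, maxpool1=False), and uses math.gcd in place of np.gcd, which gives the same result for positive integers. latent_grid, linear_params, trace_forward and dense_layer_params are hypothetical helper names.

```python
import math
from typing import Dict, List, Tuple


def latent_grid(input_height: int, input_width: int) -> Tuple[int, int]:
    """Smallest integer grid with the input's aspect ratio; 4x4 for square inputs."""
    if input_height == input_width:
        # gcd would reduce a square to (1, 1), so keep the original 4x4 grid.
        return (4, 4)
    g = math.gcd(input_height, input_width)
    return (input_height // g, input_width // g)


def linear_params(in_features: int, out_features: int) -> int:
    # torch.nn.Linear with bias=True: weight matrix plus one bias per output.
    return in_features * out_features + out_features


def trace_forward(batch: int, latent_dim: int, input_height: int, input_width: int,
                  inplanes: int = 512, upscale_factor: int = 8) -> List[Tuple[int, ...]]:
    """Tensor shapes through the modified head, ending at the reconstructed image."""
    gh, gw = latent_grid(input_height, input_width)
    th, tw = input_height // upscale_factor, input_width // upscale_factor
    return [
        (batch, latent_dim),              # latent code z
        (batch, inplanes * gh * gw),      # after self.linear
        (batch, inplanes, gh, gw),        # after x.view(...)
        (batch, inplanes, th, tw),        # after self.upscale1
        # residual blocks upsample by upscale_factor; final conv gives 3 channels
        (batch, 3, th * upscale_factor, tw * upscale_factor),
    ]


def dense_layer_params(latent_dim: int, input_height: int, input_width: int,
                       inplanes: int = 512, upscale_factor: int = 8) -> Dict[str, int]:
    """Dense-layer size of the gcd head vs. the full-grid head from the pitch."""
    gh, gw = latent_grid(input_height, input_width)
    th, tw = input_height // upscale_factor, input_width // upscale_factor
    return {
        "gcd_head": linear_params(latent_dim, inplanes * gh * gw),
        "full_grid_head": linear_params(latent_dim, inplanes * th * tw),
    }


for shape in trace_forward(2, 256, 224, 512):
    print(shape)
print(dense_layer_params(256, 224, 512))
```

For 224x512 inputs and latent_dim=256 the dense layer drops from about 235.8M to about 14.7M parameters (16x, since the 7x16 grid has 112 cells against 28x64 = 1,792). Because the grid depends only on the reduced aspect ratio, sides that share only a small common divisor (224x225, for instance, has gcd 1) would give a very large grid and no saving.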

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.