Closed: nmohr192 closed this issue 3 years ago
Hi! Thanks for your contribution, great first issue!
Found an issue with my implementation: ResnetDecoder gets extremely large (in n_parameters) with high n_latent_dims using this approach. Now I'm using:
gcd = np.gcd(input_height, input_width)
self.latent_transform_shape = (input_height // gcd, input_width // gcd) if input_height != input_width else (4, 4)
and then having the dense layer and upsampling like this:
self.linear = nn.Linear(latent_dim, self.inplanes * self.latent_transform_shape[0] * self.latent_transform_shape[1])
self.upscale1 = Interpolate(size=(input_height // self.upscale_factor, input_width // self.upscale_factor))
and in forward:
x = x.view(x.size(0), 512 * self.expansion, self.latent_transform_shape[0], self.latent_transform_shape[1])
x = self.upscale1(x)
I will commit this to my own fork later.
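Putting the pieces above together, the approach can be sketched end-to-end as below. `Interpolate` here is a minimal stand-in for the pl_bolts helper, and `latent_dim`, `inplanes`, and `upscale_factor` are hypothetical values chosen only for illustration:

```python
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F


class Interpolate(nn.Module):
    """Minimal stand-in (assumption) for pl_bolts' Interpolate helper."""

    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, x):
        return F.interpolate(x, size=self.size)


# Hypothetical settings for illustration
input_height, input_width = 224, 512
latent_dim = 256
inplanes = 512        # 512 * expansion, with expansion = 1 for resnet18
upscale_factor = 8    # assumed overall upscale factor of the decoder

gcd = np.gcd(input_height, input_width)
# Aspect-ratio-preserving seed shape; falls back to (4, 4) for square inputs.
# For 224x512, gcd = 32, so the seed shape is (7, 16).
latent_shape = (input_height // gcd, input_width // gcd) if input_height != input_width else (4, 4)

linear = nn.Linear(latent_dim, inplanes * latent_shape[0] * latent_shape[1])
upscale1 = Interpolate(size=(input_height // upscale_factor, input_width // upscale_factor))

z = torch.randn(2, latent_dim)
x = linear(z)
# Reshape into the aspect-ratio-preserving seed instead of a fixed (512, 4, 4)
x = x.view(x.size(0), inplanes, latent_shape[0], latent_shape[1])
x = upscale1(x)
print(x.shape)  # torch.Size([2, 512, 28, 64])
```

Because the dense layer only has to produce `inplanes * 7 * 16` features here rather than a full `(h // upscale_factor) * (w // upscale_factor)` grid, the parameter count stays bounded even for large latent dimensions.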
🚀 Feature
Resnet18/50 decoders should allow decoding of originally non-square inputs. They also seem to be optimized for smaller input sizes rather than higher resolutions, which can be changed by increasing the initial dense layer size.
Motivation
I have implemented a variational autoencoder on 224x512 images, and the default implementation of the resnet18 decoder does not allow for non-square inputs at all. Even after fixing the interpolation right after the dense layer (line 295 of https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/autoencoders/components.py), my higher-resolution rectangular inputs produced blurry artifact lines exactly at the middle of the initial (4x4) pattern that appears in the first epochs. This is caused by the fixed reshape after the dense layer into (512, 4, 4) in line 321 of the same file.
Pitch
I want the decoder constructor to accept an optional argument input_width (set to input_height if it is None) to take the aspect ratio of the inputs into account. The fixed initial reshape into (512, 4, 4) can then be skipped in favour of a larger initial dense layer, e.g. of size 512 (self.inplanes) * (input_height // self.upscale_factor) * (input_width // self.upscale_factor), followed by a reshape into the corresponding shape (512, input_height // self.upscale_factor, input_width // self.upscale_factor).
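The pitched constructor change could look roughly like the sketch below. This is not the actual pl_bolts code; the class name `DecoderStem` and the default values for `inplanes` and `upscale_factor` are assumptions made for illustration:

```python
import torch
from torch import nn


class DecoderStem(nn.Module):
    """Hypothetical sketch of the pitch: the dense layer and the following
    reshape are sized from (input_height, input_width) instead of a fixed (4, 4)."""

    def __init__(self, latent_dim, input_height, input_width=None,
                 inplanes=512, upscale_factor=8):
        super().__init__()
        if input_width is None:
            # Fall back to square inputs, preserving current behaviour
            input_width = input_height
        self.inplanes = inplanes
        self.h0 = input_height // upscale_factor
        self.w0 = input_width // upscale_factor
        self.linear = nn.Linear(latent_dim, inplanes * self.h0 * self.w0)

    def forward(self, z):
        x = self.linear(z)
        # Reshape into (inplanes, h0, w0) instead of the fixed (512, 4, 4)
        return x.view(z.size(0), self.inplanes, self.h0, self.w0)


stem = DecoderStem(latent_dim=256, input_height=224, input_width=512)
out = stem(torch.randn(2, 256))
print(out.shape)  # torch.Size([2, 512, 28, 64])
```

Note the trade-off discussed in the comment above: for 224x512 inputs this dense layer already has 512 * 28 * 64 output features, which is what motivates the gcd-based seed shape as an alternative.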
Alternatives
None.
Additional context
See the branch resnet18-decoder-for-non-squared-input on my fork of your repository, in particular its latest two commits, where I show how I implemented this in my own local project; it works perfectly on 224x512 shaped inputs for the variational autoencoder. https://github.com/nmohr192/pytorch-lightning-bolts/commits/resnet18-decoder-for-non-squared-input