greentfrapp / lucent

Lucid library adapted for PyTorch
Apache License 2.0
597 stars · 89 forks

Input with more than 3 channels #25

Open vincent341 opened 3 years ago

vincent341 commented 3 years ago

Thanks so much for this repo. Does Lucent support an input with more than 3 channels, for example a 4-channel input (3 RGB channels + 1 semantic channel)? If so, is it also possible to optimize the 3 RGB channels while keeping the semantic channel fixed?

urancon commented 3 years ago

Hi, thanks for the great repo! I have a similar question.

How is it possible to optimize inputs of arbitrary shape? For instance, I am working on a recurrent network whose inputs have two spatial dimensions, a channel dimension, and a time dimension as well, i.e., shape = [batchsize, T, 2, W, H]. Unsurprisingly, I get an IndexError when directly applying Lucent to my model as in the tutorial.

I feel like it should be possible to get such a tensor out of Lucent (instead of a regular 3-channel RGB image tensor) without too many changes to the code. Could you please give some insight on how to adapt Lucent's functions (e.g., I was thinking about render_vis or its param_f argument) to fit my problem?

I think this would help @vincent341 as well. Thanks for reading!

greentfrapp commented 3 years ago

Thanks for the interest! Let me take a look at this over the weekend and I'll have an update!

vincent341 commented 3 years ago

> Thanks for the interest! Let me take a look at this over the weekend and I'll have an update!

Thanks for your help! It seems this great repo has been getting more stars recently.

urancon commented 3 years ago

Thank you so much for your reply! Looking forward to seeing it!

greentfrapp commented 3 years ago

Hi @vincent341 @urancon

If your model only receives n-channel input where n != 3, you probably want to disable preprocess and explicitly set transforms=[] (or your desired transformations). You should also initialize your own param_f input with channels=n.

Here's a code snippet that should help. The model here is simply a conv layer that takes a 4-channel input.

import torch
from lucent.optvis import render, param

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy model: a single conv layer taking a 4-channel input
model = torch.nn.Sequential(torch.nn.Conv2d(4, 10, (3, 3))).to(device).eval()
param_f = lambda: param.image(128, channels=4)
# Alternatively, use pixel_image: param_f = lambda: param.pixel_image(shape=(1, 4, 128, 128))
_ = render.render_vis(model, "0:0", param_f, transforms=[], preprocess=False, show_inline=True)
urancon commented 3 years ago

Hi @greentfrapp, thank you for your help, that's great!

So that fixes our problem for inputs of shape [N (batch), C, H, W] with C != 3. Can you explain why transformations are used in Lucent? I mean, do you think there would be a loss in the quality of visualizations without them? I have the same question about the fft_image function versus pixel_image.

To optimize a 4-dimensional input image (e.g., with an added "time" dimension as in my problem, see my previous post) like [N (batch), T, C, H, W], I followed your advice and used param_f = lambda: param.pixel_image(shape=(1, 5, 2, 260, 346)). However, I had to skip the addition of an upsample transform at line 77 of render.py (https://github.com/greentfrapp/lucent/blob/31919072457f314f256755d11be8a87212ed2c69/lucent/optvis/render.py#L77-L80), because the preceding size checks (https://github.com/greentfrapp/lucent/blob/31919072457f314f256755d11be8a87212ed2c69/lucent/optvis/render.py#L69-L76) do not match the shape of my input.

To avoid this unnecessary upsampling, I suggest changing line 73 (https://github.com/greentfrapp/lucent/blob/31919072457f314f256755d11be8a87212ed2c69/lucent/optvis/render.py#L73) to elif image_shape[-2] < 224 or image_shape[-1] < 224: so that the last two dimensions of the input tensor are treated as its spatial dimensions (x and y).
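To make the suggestion concrete, here's a minimal standalone sketch of such a check (the helper name needs_upsample is hypothetical, not part of Lucent): it inspects only the last two dimensions, so it behaves the same for 4-D and 5-D tensors:

```python
def needs_upsample(image_shape, min_size=224):
    # Treat the last two dimensions as the spatial ones (H, W),
    # no matter how many leading dimensions (batch, time, channel, ...)
    # the tensor has.
    return image_shape[-2] < min_size or image_shape[-1] < min_size

print(needs_upsample((1, 3, 128, 128)))     # True: small image, upsample
print(needs_upsample((1, 5, 2, 260, 346)))  # False: 5-D input, large enough
```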

This is just a suggestion, not a request! I think it would be great to have Lucent be as general as possible; I'm very interested and I think it's a great tool! :) Thanks again for your help!

vincent341 commented 3 years ago

@greentfrapp Thanks so much for your great help. Let me give it a try!

greentfrapp commented 3 years ago

@urancon Lucent is heavily based on the Lucid library and was designed primarily for CNNs with regular inputs, i.e., shape = (batch, channel, height, width). The transforms do make a difference in some cases, and the defaults were there to help with visualizations without having to tune the transforms. Likewise, the FFT parameterization generally gave better results in the regular CNN setting compared to plain pixel parameterization. (Some details here!)
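For intuition, the difference between the two parameterizations can be sketched in plain PyTorch (a simplified illustration, not Lucent's actual implementation, which also applies color decorrelation and frequency scaling):

```python
import torch

size = 128

# Pixel parameterization: the optimized variable is the image itself
pixel_img = torch.randn(1, 3, size, size, requires_grad=True)

# FFT parameterization (sketch): the optimized variable lives in the
# frequency domain; the image is recomputed from the spectrum each step,
# so gradients flow back through the inverse FFT
spectrum = torch.randn(1, 3, size, size // 2 + 1, 2, requires_grad=True)
fft_img = torch.fft.irfft2(torch.view_as_complex(spectrum), s=(size, size))

print(fft_img.shape)  # torch.Size([1, 3, 128, 128])
```

Optimizing in frequency space changes the geometry of gradient descent (low frequencies get a fairer share of the updates), which is one reason the FFT parameterization tends to produce cleaner visualizations.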

That being said, I would love to help extend Lucent to more general use cases too! I have also received questions about other models previously (e.g., 3D convolutions). I'll take a closer look at your suggestions and perhaps refactor the code to support more cases.

urancon commented 3 years ago

Ok, it seems much clearer to me now, thank you! I definitely have to re-read this article.

Yes, I think that supporting another dimension (e.g., a third spatial dimension, as you mentioned, or a temporal one) would surely make Lucent stand out! Just to let you know, what I see myself doing with Lucent and my temporal model is using a [batch, time, channel, height, width] input tensor and visualizing it as a GIF or video of time frames, instead of a static image.
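As a sketch of that visualization idea (pure NumPy; the helper is an assumption for illustration, not Lucent API), one could split the time axis into per-frame images and hand them to any GIF writer, e.g. imageio.mimsave:

```python
import numpy as np

def to_frames(tensor):
    # tensor: (T, C, H, W) float array with values in [0, 1].
    # Returns a list of (H, W, C) uint8 frames, one per time step,
    # ready for a GIF/video writer. (A 2-channel input like the one
    # above would still need mapping to RGB/grayscale for an actual GIF.)
    clipped = np.clip(tensor, 0.0, 1.0)
    frames = (clipped.transpose(0, 2, 3, 1) * 255).astype(np.uint8)
    return list(frames)

frames = to_frames(np.random.rand(5, 2, 64, 64))
print(len(frames), frames[0].shape)  # 5 (64, 64, 2)
```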

Thanks anyway for the help and the interesting discussion; I might return with some more questions! :)