greentfrapp / lucent

Lucid library adapted for PyTorch
Apache License 2.0

Grayscale Optimal stimuli #29

Closed: arnaghosh closed this issue 3 years ago

arnaghosh commented 3 years ago

Hello, I am trying to better understand how Lucent works with one of my models. My model is trained on grayscale images (grayscale versions of natural images) but uses a VGG16 backbone for feature extraction, so it accepts 3-channel images just like the other torchvision model-zoo models. When I run Lucent on certain units, it returns RGB optimal stimuli. I believe this is because the pretrained VGG16 has its own implicit color-processing filters, which the Lucent optimization exploits to produce RGB optimal stimuli. However, since I fine-tuned my model on grayscale images, I want to find optimal stimuli in that same grayscale space. Is it possible to do this in Lucent?

I tried a naive solution: adding the following transform, which should convert an RGB image to grayscale and repeat it across 3 channels so that it is still a valid input to the network:

`rgb2gray_tfo = lambda x: torch.tensordot(x[...,:3],torch.Tensor([0.2989, 0.5870, 0.1140]).cuda(),dims=1).unsqueeze(-1).expand_as(x)`

However, the optimal stimuli generated are just blank (gray) images, so I am wondering whether there is a solution to my problem. Thanks in advance. I must add that using Lucent for my project has been amazing so far.😄
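One possible culprit, offered as an assumption rather than a confirmed diagnosis: Lucent's image tensors are channels-first, with shape (N, 3, H, W), whereas `x[...,:3]` and `tensordot(..., dims=1)` operate on the last axis, which is width in that layout. The minimal sketch below applies the same luma weights along the channel axis (dim 1) instead. The function name `rgb2gray_nchw` is hypothetical, and the last lines contrast it with the naive expression on a random NCHW tensor.

```python
import torch

def rgb2gray_nchw(x: torch.Tensor) -> torch.Tensor:
    """Convert an (N, 3, H, W) RGB batch to gray, repeated back to 3 channels."""
    # Same luma weights as the naive transform, shaped to broadcast over dim 1 (channels)
    w = torch.tensor([0.2989, 0.5870, 0.1140], dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
    gray = (x[:, :3] * w).sum(dim=1, keepdim=True)  # (N, 1, H, W)
    return gray.expand(-1, 3, -1, -1)  # (N, 3, H, W) view; gradients still flow back to x

x = torch.rand(2, 3, 8, 8)  # random channels-first batch

# Naive channels-last expression applied to an NCHW tensor: it mixes the first three
# columns along the width axis, and the result is constant along width
naive = torch.tensordot(x[..., :3], torch.tensor([0.2989, 0.5870, 0.1140]), dims=1).unsqueeze(-1).expand_as(x)

fixed = rgb2gray_nchw(x)  # three identical gray channels
```

Because `expand` returns a view rather than a copy, the three output channels share memory, and the gradient of each input channel accumulates the contributions from all three.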

arnaghosh commented 3 years ago

Adding a solution that seems to work for me.
I defined a new transform, shown below, and included it in the list of transforms.

```python
import torchvision

def grayscale_tfo(channels):
    assert channels == 1 or channels == 3, "Only 1 or 3 channels can be passed, currently {} passed".format(channels)
    def inner(image_t):
        # Luma-weighted sum of the R, G, B channels, repeated to `channels` output channels
        return torchvision.transforms.functional.rgb_to_grayscale(image_t, num_output_channels=channels)
    return inner
```

Maybe it's worth adding it to the list of available transforms, in case it is useful to others?