utkuozbulak / pytorch-cnn-visualizations

Pytorch implementation of convolutional neural network visualization techniques
MIT License

Image Reconstruction size is same as conv1 layer #108

Closed: kbkartik closed this issue 2 years ago

kbkartik commented 2 years ago

Hi,

Thanks for setting up this amazing repo! I am using the following custom network:

CNN(
  (seq_model): Sequential(
    (Conv2d_1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (BN_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_1): ReLU()
    (Maxpool2d_1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (BN_2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_2): ReLU()
    (Maxpool2d_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_3): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (BN_3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_3): ReLU()
    (Maxpool2d_3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_4): Conv2d(256, 512, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (BN_4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_4): ReLU()
    (Maxpool2d_4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_5): Conv2d(512, 1024, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (BN_5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu_5): ReLU()
    (Maxpool2d_5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Flatten): Flatten(start_dim=1, end_dim=-1)
    (fc1): Linear(in_features=16384, out_features=128, bias=True)
    (Dropout): Dropout(p=0, inplace=False)
    (fcReLU): ReLU()
    (head): Linear(in_features=128, out_features=10, bias=True)
  )
)
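
(In case it helps, this is roughly how the network is constructed; just a sketch reconstructed from the printout above, since the constructor itself isn't shown here. The 128x128 input size is inferred from fc1's 16384 = 1024 * 4 * 4 input features.)

from collections import OrderedDict
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        layers, in_ch = [], 3
        # five conv -> batchnorm -> ReLU -> maxpool blocks, as in the printout
        for i, out_ch in enumerate([64, 128, 256, 512, 1024], start=1):
            layers += [
                (f'Conv2d_{i}', nn.Conv2d(in_ch, out_ch, kernel_size=5, padding='same')),
                (f'BN_{i}', nn.BatchNorm2d(out_ch)),
                (f'relu_{i}', nn.ReLU()),
                (f'Maxpool2d_{i}', nn.MaxPool2d(kernel_size=2, stride=2)),
            ]
            in_ch = out_ch
        layers += [
            ('Flatten', nn.Flatten()),
            ('fc1', nn.Linear(1024 * 4 * 4, 128)),  # 128x128 input -> 4x4 after five 2x2 poolings
            ('Dropout', nn.Dropout(p=0)),
            ('fcReLU', nn.ReLU()),
            ('head', nn.Linear(128, num_classes)),
        ]
        self.seq_model = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.seq_model(x)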

I have adapted your code as follows:

import torch
import torch.nn as nn
import wandb

# (`device` and the wandb run are set up elsewhere in my code.)

class Guided_backprop:

    def __init__(self, model, utils_agent):
        self.model = model
        self.utils_agent = utils_agent
        self.image_reconstruction = None  # store R0
        self.activation_maps = []  # store f1, f2, ...
        for _, p in self.model.named_parameters():
            p.requires_grad = True
        self.model.eval()
        self.register_hooks()

    def register_hooks(self):
        def first_layer_hook_fn(module, grad_in, grad_out):
            # keep the gradient arriving at the first conv layer as the reconstruction
            self.image_reconstruction = grad_in[0]

        def forward_hook_fn(module, input, output):
            self.activation_maps.append(output)

        def backward_hook_fn(module, grad_in, grad_out):
            grad = self.activation_maps.pop()
            # From the forward pass: positions where the ReLU output is
            # positive become 1; all other positions stay 0.
            grad[grad > 0] = 1

            # grad_out[0] stores the gradients for each feature map;
            # we only retain the positive gradients
            new_grad_in = grad * torch.clamp(grad_out[0], min=0.0)
            return (new_grad_in,)

        modules = list(self.model.seq_model.named_children())

        # traverse the modules and register a forward hook & backward hook
        # for every ReLU
        for name, module in modules:
            if isinstance(module, nn.ReLU):
                module.register_forward_hook(forward_hook_fn)
                module.register_backward_hook(backward_hook_fn)

        # register a backward hook for the first conv layer
        first_layer = modules[0][1]
        first_layer.register_backward_hook(first_layer_hook_fn)

    def visualize(self, datapoint):
        def normalize(image):
            norm = (image - image.mean()) / image.std()
            norm = norm * 0.1
            norm = norm + 0.5
            norm = norm.clip(0, 1)
            return norm

        input_image, _ = datapoint
        target_class = None
        input_image = input_image.unsqueeze(0).requires_grad_().to(device)
        model_output = self.model(input_image)
        self.model.zero_grad()
        pred_class = model_output.argmax().item()

        grad_target_map = torch.zeros(model_output.shape, dtype=torch.float, device=device)

        if target_class is not None:
            grad_target_map[0][target_class.argmax(0).item()] = 1
        else:
            grad_target_map[0][pred_class] = 1

        model_output.backward(grad_target_map)
        input_image = input_image.squeeze(0)
        result = self.image_reconstruction.data[0].permute(1, 2, 0)
        print("Img reconst", self.image_reconstruction.shape)
        result = normalize(result)
        gbp_result = wandb.Image(result.cpu().numpy(), caption='Guided BP Image')
        orig_img = wandb.Image(self.utils_agent.invTransf(input_image).cpu(), caption='Original Image')

        wandb.log({'Orig_img': orig_img, 'GBP_Result': gbp_result})

My image reconstruction output size is torch.Size([1, 64, 128, 128]), which is the same as the output of the first conv layer rather than the input image. It seems the backprop stops before reaching the input image. Any idea where the bug could be?

utkuozbulak commented 2 years ago

No clue at first glance, but if the size is not right, you are probably looking at the gradients of the wrong layer or hooking the wrong layer.
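
One quick way to check is to hook the first conv layer with register_full_backward_hook (the newer replacement for register_backward_hook, which is deprecated and can report unexpected shapes in grad_input for some modules) and compare what it captures against the gradient PyTorch accumulates on the input tensor itself. A rough sketch, assuming `model` is the CNN from your post and using the layer name Conv2d_1 from its printout:

import torch

captured = {}

def first_layer_hook(module, grad_in, grad_out):
    # grad_in[0] is the gradient w.r.t. the conv's input, i.e. image-shaped
    captured['grad_in'] = grad_in[0]
    captured['grad_out'] = grad_out[0]

# 'Conv2d_1' is the first conv layer's name in the model printout above
first_conv = dict(model.seq_model.named_children())['Conv2d_1']
handle = first_conv.register_full_backward_hook(first_layer_hook)

model.eval()
# dummy input with the same size as in the issue (3 x 128 x 128)
input_image = torch.randn(1, 3, 128, 128, requires_grad=True)
output = model(input_image)
output[0, output.argmax()].backward()

print(captured['grad_in'].shape)   # expected: torch.Size([1, 3, 128, 128])
print(captured['grad_out'].shape)  # torch.Size([1, 64, 128, 128]) -- the size you are seeing
print(input_image.grad.shape)      # always input-sized; usable as the reconstruction directly
handle.remove()

If grad_in[0] from the hook does not come out image-sized, falling back to input_image.grad (or switching the first-layer hook to register_full_backward_hook in your class) is a reasonable thing to try.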