1202kbs / Understanding-NN

Tensorflow tutorial for various Deep Neural Network visualization techniques
MIT License

About how to use deconvolution guided-back on Batch normalization #4

Closed: GlowingHorse closed this issue 5 years ago

GlowingHorse commented 5 years ago

Hello, I read your amazing article about network interpretability methods. Recently I have been trying to visualize ResNet with Guided Backpropagation (Guided-Back), but I ran into trouble with the Batch Normalization layers. Do you have any information about how to handle these layers? I could not find any method for Batch Normalization layers in the deconvolution papers. I used automatic differentiation (auto-gradient) to visualize, but got worse visualization results.

1202kbs commented 5 years ago

Hi, thank you for the kind comment :)

Batch normalization layers are handled in the same way as in the vanilla gradient, so registering the Guided Backpropagation ReLU gradient and calling tf.gradients() should be fine. Can you describe specifically what kind of problem you had? A description of the visualization or the error message would be helpful.
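For reference, the guided ReLU rule differs from the ordinary ReLU gradient only in also zeroing negative upstream gradients. The following is a plain-Python sketch of that per-element rule, not the repository's code, and the function names are hypothetical. In TensorFlow 1.x such a rule is typically registered with `tf.RegisterGradient` and activated with `tf.Graph.gradient_override_map({'Relu': ...})`, while every other op, BN included, keeps its built-in gradient.

```python
# Guided Backpropagation rule for a ReLU layer (sketch on plain Python lists).
# Forward: y = max(x, 0).

def relu_backward_vanilla(x, grad_out):
    """Ordinary ReLU gradient: pass grad_out only where the forward input was positive."""
    return [g if xi > 0 else 0.0 for xi, g in zip(x, grad_out)]

def relu_backward_guided(x, grad_out):
    """Guided Backpropagation: additionally drop negative upstream gradients."""
    return [g if (xi > 0 and g > 0) else 0.0 for xi, g in zip(x, grad_out)]

x = [1.5, -0.5, 2.0, 0.3]
grad_out = [0.4, 0.9, -0.7, 0.2]
print(relu_backward_vanilla(x, grad_out))  # [0.4, 0.0, -0.7, 0.2]
print(relu_backward_guided(x, grad_out))   # [0.4, 0.0, 0.0, 0.2]
```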

GlowingHorse commented 5 years ago

Hello, @1202kbs

Thanks for your reply. I would like to describe this problem in detail.

As you know, there are four kinds of layers in ResNet:

  1. Convolution layers, which are handled by all deconvolution-based visualization methods.
  2. Pooling layers, which can be handled as in the deconvolutional network paper.
  3. ReLU layers, which can be handled as in the Guided Backpropagation paper.
  4. Batch Norm layers. I tried three ways to compute gradients through this layer, but none of them produces good visualization results.
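The pooling rule from the deconvolutional network paper records, in the forward pass, which input won each pooling window (the "switches") and routes the signal back only to that position; for max pooling this is the same as the ordinary gradient. Below is a 1-D sketch with non-overlapping windows and hypothetical function names. The max pooling in the standard ResNet stem uses overlapping 3x3 windows with stride 2; the `+=` in the backward step is what would accumulate contributions if windows overlapped.

```python
# Max-pooling backward via "switches" (sketch, 1-D, non-overlapping windows).

def maxpool_forward(x, size):
    """Return pooled values and the argmax index (switch) of each window."""
    out, switches = [], []
    for start in range(0, len(x) - size + 1, size):
        window = x[start:start + size]
        k = max(range(size), key=lambda i: window[i])
        out.append(window[k])
        switches.append(start + k)
    return out, switches

def maxpool_backward(grad_out, switches, n_in):
    """Scatter each upstream gradient to the input position recorded in its switch."""
    grad_in = [0.0] * n_in
    for g, idx in zip(grad_out, switches):
        grad_in[idx] += g
    return grad_in

x = [0.1, 0.7, 0.3, 0.2, 0.9, 0.4]
pooled, sw = maxpool_forward(x, size=2)
print(pooled, sw)                                    # [0.7, 0.3, 0.9] [1, 2, 4]
print(maxpool_backward([1.0, 2.0, 3.0], sw, len(x)))  # [0.0, 1.0, 2.0, 0.0, 3.0, 0.0]
```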

Let x[l-1] and x[l] be one BN layer's input and output, and dzdx[l-1] and dzdx[l] the corresponding gradients.

First, I tried the original gradient computation for the BN layer (the built-in function in most frameworks).
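For the built-in (vanilla) rule in method 1, and assuming the BN layer runs in inference mode with its stored mean and variance (in training mode the batch statistics depend on the input and the gradient has extra terms), the derivative of BN with respect to its input is the per-channel constant gamma / sqrt(var + eps). A plain-Python sketch for one channel, with hypothetical function names:

```python
import math

def bn_forward(x, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode batch norm for one channel (mean and var are stored constants)."""
    std = math.sqrt(var + eps)
    return [gamma * (xi - mean) / std + beta for xi in x]

def bn_backward(grad_out, gamma, var, eps=1e-5):
    """Vanilla gradient: with frozen statistics, dx[l]/dx[l-1] = gamma / std everywhere."""
    scale = gamma / math.sqrt(var + eps)
    return [g * scale for g in grad_out]

# dzdx[l-1] = dzdx[l] * gamma / sqrt(var + eps)
print(bn_backward([1.0, -2.0, 0.5], gamma=3.0, var=4.0, eps=0.0))  # [1.5, -3.0, 0.75]
```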

Second, I changed the backward computation to use the identity: a) copy one feature map of x[l] into an otherwise all-zero matrix, which serves as dzdx[l]; b) at the BN layer, simply set dzdx[l-1] equal to dzdx[l]; c) continue the backward propagation.
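Method 2 can be sketched as follows. The layout (a list of channels, each a flattened list of spatial values) and the function names are illustrative assumptions, not code from the thread. Compared with the vanilla rule, treating BN as the identity only drops the per-channel factor gamma / sqrt(var + eps).

```python
# Method 2 (sketch): seed the backward pass with one feature map of x[l]
# and let BN pass the gradient through unchanged.

def seed_gradient(activations, channel):
    """dzdx[l]: zeros everywhere except the chosen channel, which copies x[l]."""
    return [
        list(fmap) if c == channel else [0.0] * len(fmap)
        for c, fmap in enumerate(activations)
    ]

def bn_backward_identity(grad_out):
    """Treat BN as the identity in the backward pass: dzdx[l-1] = dzdx[l]."""
    return [list(fmap) for fmap in grad_out]

x_l = [[0.2, 0.5], [1.0, -0.3], [0.0, 0.8]]  # 3 channels, 2 positions each
dzdx_l = seed_gradient(x_l, channel=1)
print(dzdx_l)                        # [[0.0, 0.0], [1.0, -0.3], [0.0, 0.0]]
print(bn_backward_identity(dzdx_l))  # same values
```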

Third, BN computes x[l] = moments_1 * (x[l-1] - mean) / var + moments_2, so I used its inverse: dzdx[l-1] = (dzdx[l] - moments_2) / moments_1 * var + mean.
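Method 3 applies the inverse of BN's forward map to the gradient. The sketch below writes the normalizer as sqrt(var + eps), which is the standard batch-norm definition (the post writes it as var), and uses gamma and beta for moments_1 and moments_2; the function names are hypothetical.

```python
import math

def bn_forward(x, gamma, beta, mean, var, eps=1e-5):
    """x[l] = gamma * (x[l-1] - mean) / sqrt(var + eps) + beta, for one channel."""
    std = math.sqrt(var + eps)
    return [gamma * (xi - mean) / std + beta for xi in x]

def bn_inverse(y, gamma, beta, mean, var, eps=1e-5):
    """Inverse of bn_forward: (y - beta) / gamma * sqrt(var + eps) + mean."""
    std = math.sqrt(var + eps)
    return [(yi - beta) / gamma * std + mean for yi in y]

# Method 3 pushes the upstream gradient dzdx[l] through the inverse map.
params = dict(gamma=2.0, beta=0.5, mean=1.0, var=0.25, eps=0.0)
grad_prev = bn_inverse([1.0, 0.0, -1.0], **params)
print(grad_prev)  # [1.125, 0.875, 0.625]
```

A zero upstream gradient comes back as a nonzero value (0.875 here), because the inverse carries the -beta and +mean offsets, which never appear in the vanilla gradient.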

But none of these settings for the BN layers helps me generate visualization results as good as those from VGG-16. Is there a good setting for BN layers that gives visualizations as good as Guided Backpropagation on VGG-16? Thanks for your kind help.

1202kbs commented 5 years ago

@ShiRuiCV

There doesn't seem to be anything wrong with what you have tried. As far as I know, papers on DNN interpretability use the vanilla gradient computation for BN layers.

Do your GBP visualizations look like those in Figures 5 and 7 of this paper? If so, I'm afraid there is currently no way of solving that problem. You may want to try other methods such as SmoothGrad, Integrated Gradients or DeepLIFT. This repository has a nice implementation of several recent attribution methods.
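As a concrete example of one suggested alternative, Integrated Gradients attributes f(x) - f(baseline) by averaging the gradient along the straight path from a baseline to the input and multiplying by (x - baseline). This is a minimal sketch using a midpoint Riemann sum, an all-zero baseline, 50 steps and a hypothetical toy function `f` with a hand-written gradient; for a real network the gradient would come from automatic differentiation instead.

```python
def integrated_gradients(grad_fn, x, baseline, steps=50):
    """Midpoint Riemann-sum approximation of Integrated Gradients."""
    totals = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            totals[i] += g
    # attribution_i = (x_i - baseline_i) * average gradient along the path
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, totals)]

# Hypothetical toy model and its hand-written gradient.
def f(v):
    return v[0] ** 2 + 3.0 * v[0] * v[1]

def grad_f(v):
    return [2.0 * v[0] + 3.0 * v[1], 3.0 * v[0]]

attr = integrated_gradients(grad_f, [1.0, 2.0], [0.0, 0.0])
print(attr, sum(attr))
```

For this toy f the attributions are approximately [4.0, 3.0], and they sum to f(x) - f(baseline) = 7, the completeness property Integrated Gradients is designed to satisfy.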

GlowingHorse commented 5 years ago

Hi @1202kbs

Thanks for your warm help.

I found that these three papers mainly address a problem similar to Grad-CAM's, "finding which part of the image contributes most to the classification result". Their methods do not seem to solve the BN layer problem.

I have attached several images generated with methods 1 and 2 above. Method 3 cannot be applied to higher layers, so I gave it up. However, because method 3 eliminates the influence of the first BN layer, its result at the first BN layer is as good as Guided Backpropagation on VGG-16.

The images have been rescaled for visual quality. You will see grid patterns in the images, caused by the influence of the BN parameters. Images generated with Guided Backpropagation on VGG-16 don't have these artifacts.

[Images: pm5544_with_non-pal_signal44_dzdx, pm5544_with_non-pal_signal44_dzdx_2]

GlowingHorse commented 5 years ago

P.S.: The original image is PM5544. Here are some even worse results: [Images: pm5544_with_non-pal_signal55_dzdx_2, pm5544_with_non-pal_signal57_dzdx_2]

1202kbs commented 5 years ago

@ShiRuiCV,

I have done a bit of searching and found a recent paper that observes the same problem with visualizations from ResNet. This paper also noted grid-like patterns in GBP visualizations. To quote it directly: "however, we do observe some additional grid-like textures here and we conjecture that this deterioration of visual quality is due to the skip connections, as we have shown earlier that network structure has a significant impact on the visualizations. We leave the rigorous analysis of this phenomenon for future work". So they weren't able to solve it either. As far as I know, there is currently no solution to the problem you have.
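To see mechanically what a skip connection does during the backward pass, consider a toy residual block y = x + relu(w * x), a hypothetical element-wise stand-in for a real residual branch. The gradient reaching x is the sum of the identity path and the branch path, so positions where the branch ReLU blocks the signal still receive the skip-path gradient. This only illustrates the structural effect the quoted paper refers to; it is not an explanation of the grid-like textures.

```python
def residual_forward(x, w):
    """Toy residual block: y = x + relu(w * x), element-wise."""
    return [xi + max(w * xi, 0.0) for xi in x]

def residual_backward(x, w, grad_out):
    """Gradient w.r.t. x: sum of the skip (identity) path and the branch path."""
    grad_skip = list(grad_out)  # d(x)/dx = 1
    grad_branch = [g * w if w * xi > 0 else 0.0  # d relu(w*x)/dx = w where w*x > 0
                   for xi, g in zip(x, grad_out)]
    return [a + b for a, b in zip(grad_skip, grad_branch)]

x, w = [1.0, -1.0], 2.0
print(residual_forward(x, w))               # [3.0, -1.0]
print(residual_backward(x, w, [1.0, 1.0]))  # [3.0, 1.0]
```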

I have also done GBP visualizations with Inception V4, which also has BN layers, but I didn't run into this problem. Why do you think the BN layers are causing the grid-like patterns in the visualizations?

GlowingHorse commented 5 years ago

@1202kbs

Hello, thanks for your kind help. The reason I think the BN layers cause the problem is simple.

ResNet-152 structure: conv1 -> bn_conv1 (batch norm) -> conv1_relu -> conv1_pool. When I use the activation map of conv1, I get clear results without artifacts. But when I move to the next layer and use the activation map of bn_conv1, the result becomes the images I showed above.

I plan to try some other backward computation methods in the BN layers. All the parameters in the BN layer (11 channels) are fixed in advance, and the layer consists only of multiplications and additions. So maybe the grid-like pattern is caused by the framework's computation rules? Would you mind telling me which framework you are using?
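The point that this BN layer contains only multiplications and additions can be made explicit. Assuming inference mode (stored statistics, no dependence on the batch), one channel of BN folds into a single affine map y = scale * x + shift. The sketch below uses a hypothetical helper `fold_bn` and made-up values; it shows that the backward pass multiplies every position of a channel by the same `scale`, and that `shift` does not enter the backward pass at all.

```python
import math

def fold_bn(gamma, beta, mean, var, eps=1e-5):
    """Rewrite inference-mode BN for one channel as y = scale * x + shift."""
    scale = gamma / math.sqrt(var + eps)
    shift = beta - mean * scale
    return scale, shift

scale, shift = fold_bn(gamma=3.0, beta=1.0, mean=2.0, var=4.0, eps=0.0)

# Backward through the folded layer for one channel's 2x2 gradient map
# (hypothetical values): every position gets the same factor `scale`.
grad_out = [[0.5, -1.0], [2.0, 0.0]]
grad_in = [[g * scale for g in row] for row in grad_out]
print(scale, shift, grad_in)  # 1.5 -2.0 [[0.75, -1.5], [3.0, 0.0]]
```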

1202kbs commented 5 years ago

@ShiRuiCV

I'm using TensorFlow to generate the visualizations. Sorry I couldn't help you with your problem. I'm not familiar with backpropagation through batch normalization layers. If you manage to solve it, please tell me how you did it :) I think this is an interesting problem.

GlowingHorse commented 5 years ago

@1202kbs

OK, if I solve this problem, I will let you know. Thanks!