chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

Canonizers #63

Closed rachtibat closed 2 years ago

rachtibat commented 2 years ago

Hi Chris,

I hope you're doing well. I have noticed two small problems:

  1. When applying, for instance, SequentialMergeBatchNorm, the code iterates through all modules of a model and checks whether two subsequent modules are of Conv and BatchNorm type. Unfortunately, this does not work if the modules are not ordered linearly in the model. For instance, I defined a model using torch's ModuleList and applied the layers in a different order in the forward pass. As a consequence, Zennit was not able to identify which module follows which. I see that this is really difficult to do in torch; some time ago, I drew a model graph using torch's jit. Do you think there is a way to find out which module follows which one in a forward pass, and to adapt Zennit to respect this correct ordering? For now, I just rewrote my code, but someone else might not be able to; at the very least, we should document this behavior if it is not documented already. (A minimal sketch of such a model follows after this list.)

  2. Unfortunately, SequentialMergeBatchNorm does not work for Conv1d layers. The problem is the following line in def merge_batch_norm(modules, batch_norm):

module.weight.data = (original_weight * scale[:, None, None, None])

There, you broadcast the scale against weights with a fixed number of 4 dimensions. But Conv1d layers have weights with 3 dimensions, i.e. we would need module.weight.data = (original_weight * scale[:, None, None])

Do you know an elegant way to achieve this for all dimensions?
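To illustrate point 1, here is a minimal sketch (with hypothetical module names and shapes) of a model whose definition order in a ModuleList does not match the call order in the forward pass, so simply iterating over the modules cannot reveal that the convolution feeds the batch norm:

```python
import torch
from torch import nn


class ShuffledModel(nn.Module):
    """Hypothetical model whose definition order differs from its call order."""
    def __init__(self):
        super().__init__()
        # the batch norm is listed before the convolution it normalizes
        self.layers = nn.ModuleList([
            nn.BatchNorm2d(8),
            nn.Conv2d(3, 8, 3, padding=1),
        ])

    def forward(self, x):
        # the forward pass applies the convolution first, then the batch norm,
        # which iterating over the module list in definition order cannot detect
        return self.layers[0](self.layers[1](x))
```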

Best

chr5tphr commented 2 years ago

Hey Reduan,

thank you for keeping the issues coming!

  1. I implemented SequentialMergeBatchNorm this way for its simplicity, since it is trivial to adapt models such that they conform to this assumption. While I was always very aware of this and thought I had documented the behavior properly, I apparently did not. Originally, I thought of many ways to do it, including parsing the graph representations given by torch-jit; even just looking at the gradient graph can give good clues, but it can really blow up if you try to cover all edge cases. Thus, for simplicity's sake, and because of the aforementioned easy adaptability, I opted for the iteration approach.
  2. This can be fixed very easily, since we know the needed shape from the weight (a sketch of the idea follows below). Keep an eye out for the following MR.
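A minimal sketch of a dimension-agnostic version of that line, assuming the original_weight, scale and module variables from merge_batch_norm: the index expression appends one singleton dimension per remaining weight axis, so the same code covers Linear, Conv1d, Conv2d and Conv3d alike.

```python
# append as many singleton dimensions as the weight has beyond its first axis,
# so the scale broadcasts against weights of any dimensionality
index = (slice(None),) + (None,) * (module.weight.dim() - 1)
module.weight.data = (original_weight * scale[index])
```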
chr5tphr commented 2 years ago

One more comment on 1.: If you cannot easily adapt your model, or do not want to, there is always NamedMergeBatchNorm, to which you can supply a list of 2-tuples, where the first element is a list of linear layer module names and the second element is the name of the batch norm module to merge into those linear layers.
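For example, a minimal sketch of how this could look for the hypothetical ShuffledModel above, assuming the convolution and batch norm are registered under the module names 'layers.1' and 'layers.0' respectively:

```python
from zennit.canonizers import NamedMergeBatchNorm
from zennit.composites import EpsilonPlusFlat

# explicitly name which batch norm should be merged into which linear/conv layer
canonizer = NamedMergeBatchNorm([
    (['layers.1'], 'layers.0'),
])
# pass the canonizer to a composite as usual
composite = EpsilonPlusFlat(canonizers=[canonizer])
```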