pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Scripting/tracing Captum classes #228

Closed pquadri closed 4 years ago

pquadri commented 4 years ago

Hello, I was experimenting with Captum and I was wondering if there is any way to trace/script an attribution model in order to obtain the final heatmap directly as the output of the serialized model file.

I did not find any reference in the documentation or in the code, and I did not manage to implement it myself by creating intermediate classes, for example by wrapping the Saliency class in a torch.nn.Module.

Is there something I am missing, or is this in the future plans?

NarineK commented 4 years ago

Hi @pquadri! This is a good question. We plan to test all algorithms with JIT models. Some algorithms that do not use hooks, such as integrated gradients or feature ablation/permutation, will work out of the box. The ones that do use hooks need to be tested and will possibly need some fixes.
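For illustration, the out-of-the-box case looks like the following minimal sketch (the toy model and shapes here are illustrative assumptions, not from this thread):

import torch
from captum.attr import IntegratedGradients

# Hypothetical toy model; any nn.Module works the same way
model = torch.jit.script(torch.nn.Sequential(torch.nn.Linear(4, 2)))

ig = IntegratedGradients(model)  # pass the already-scripted model to Captum
inp = torch.rand(1, 4)
attributions = ig.attribute(inp, target=0)  # attribution for class index 0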

Let us know if you try any and run into issues.

cc: @vivekmig , @orionr

NarineK commented 4 years ago

Actually, Saliency should have worked with a JIT model. Do you have an example of your model?
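That expected usage would look roughly like this minimal sketch (the toy model and example input are illustrative assumptions):

import torch
from captum.attr import Saliency

net = torch.nn.Sequential(torch.nn.Linear(4, 2))
model = torch.jit.trace(net, torch.rand(1, 4))  # JIT-compile the model itself

sal = Saliency(model)  # the attribution method wraps the traced model
inp = torch.rand(1, 4, requires_grad=True)
attributions = sal.attribute(inp, target=0)  # gradient-based heatmap for class 0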

pquadri commented 4 years ago

At first I tried using a wrapper like I mentioned:

import torch
from captum.attr import Saliency

class SaliencyJIT(torch.nn.Module):
    def __init__(self, model):
        super(SaliencyJIT, self).__init__()
        self.model = Saliency(model)

    def forward(self, inputs):
        # Attribute with respect to class 0 for every sample in the batch
        labels = torch.zeros([inputs.size(0)]).to(torch.long).cuda()
        return self.model.attribute(inputs, labels)

Right now I am testing adding a method to the model itself:

def saliency(self, inputs):
    self.saliency_model = Saliency(self.forward)
    # Attribute with respect to class 0 for every sample in the batch
    labels = torch.zeros([inputs.size(0)]).to(torch.long)
    return self.saliency_model.attribute(inputs, labels)

And then compiling it with torch.jit.script; this approach seems to be working, and I am still testing whether it yields the same results.

EDIT: I tried both tracing and scripting on the second approach.

vivekmig commented 4 years ago

Hi @pquadri, that's an interesting question, thanks for posting the details of what you tried!

I think the main use case we expected was just passing a JIT scripted / traced model to a Captum attribution method, which should work directly for methods not using hooks such as IntegratedGradients and Saliency, which is what @NarineK was referencing.

Your use case is interesting: creating a JIT-scripted module that directly outputs the model attribution. I tried a few experiments with this, and similar to the issues you encountered, the main problems were:

  1. For scripting, there's quite a bit of Python functionality that is not supported by TorchScript, such as generators, which are used in the Saliency implementation and across the Captum codebase. It would take quite a lot of effort to update or reimplement all methods to use only TorchScript-supported functionality, so this is something we will have to think about and discuss further internally.
  2. For tracing, any method that takes gradients returns an error that "output of traced region did not have observable data dependence with trace inputs" (see the sketch after this list). This seems to be the result of a known limitation: the returned value of autograd is treated as a constant, since the autograd function isn't tracked.
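
As a small sketch of the second issue (a generic example, not Captum code), tracing any function whose output comes from torch.autograd.grad hits this error:

import torch

def input_grad(x):
    out = (x * x).sum()
    # The tracer does not track this autograd call, so its output looks constant
    return torch.autograd.grad(out, x)[0]

x = torch.rand(3, requires_grad=True)
# torch.jit.trace(input_grad, x)  # RuntimeError: output of traced region did not
#                                 # have observable data dependence with trace inputs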

One way I found to accomplish what you're looking for with Saliency is to just add the autograd call, e.g. grads = torch.autograd.grad(torch.unbind(outputs), inputs), at the end of the model's forward method, since Saliency is essentially just the model's input gradients. As explained above, this cannot be traced, but scripting this module appears to accomplish what you want to achieve.
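A minimal sketch of that idea (the toy two-class linear model and the max-logit reduction are illustrative assumptions, not from this thread):

import torch

class ModelWithSaliency(torch.nn.Module):
    def __init__(self):
        super(ModelWithSaliency, self).__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, inputs):
        # Reduce each sample's output to a scalar score (here, the max logit)
        outputs, _ = self.linear(inputs).max(dim=1)
        # Saliency is essentially the gradient of the score w.r.t. the input
        grads = torch.autograd.grad(torch.unbind(outputs), [inputs])
        return grads[0]

scripted = torch.jit.script(ModelWithSaliency())
inp = torch.rand(2, 4, requires_grad=True)
heatmap = scripted(inp)  # same shape as the input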

This will likely be more complex to accomplish for other methods, but you could potentially do this for a particular method in the same way, by moving the necessary parts of the method's implementation in Captum into the module's forward method and removing / rearranging any portions that are not supported by scripting. This is likely somewhat messy, but it is unfortunately the best solution available right now. Otherwise, just JIT-scripting the model and using Captum on that would work, but it would still require Python. Let us know if you have any further questions on this!

Edit: See below for a potential cleaner solution.

vivekmig commented 4 years ago

Actually, I experimented a little more with this, and I think there's a solution that seems to work: combining scripting for the autograd portions with tracing for the overall method. We will need to test further and come up with the best way to incorporate this generically before adding it to Captum, but you can try this change and see if it works for you:

Replace line 92 in attr/_utils/gradient.py, within the compute_gradients method, which is currently grads = torch.autograd.grad(torch.unbind(outputs), inputs), with:

def grad_fn(outputs: Tensor, inputs: Tuple[Tensor]):
    return torch.autograd.grad(torch.unbind(outputs), inputs)

# Script the autograd call so it survives tracing of the overall method
grad_fn_scripted = torch.jit.script(grad_fn)
grads = tuple(grad_fn_scripted(outputs, inputs))

and add the following imports to the header:

from typing import Tuple
from torch import Tensor

With these changes, you should be able to directly trace attribution methods, for example:

sal = Saliency(model)
scripted_saliency = torch.jit.trace(sal.attribute, inp, check_trace=False)

Note that check_trace needs to be False; there seem to be some issues with the check that need to be debugged further.

This should work for feature attribution methods such as Integrated Gradients and Saliency (layer and neuron methods will need further changes). The limitations of tracing still apply, though (e.g. data-dependent conditionals aren't scripted), which may affect some functionality. I will need to investigate that further, but this can hopefully help you get started with scripted attribution methods.
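To illustrate that tracing caveat (a generic TorchScript example, not Captum code), a branch on tensor values is frozen to whichever path the example input takes:

import torch

def f(x):
    if x.sum() > 0:  # data-dependent branch
        return x * 2
    return x - 1

traced = torch.jit.trace(f, torch.ones(3))  # only the x * 2 path is recorded
print(traced(-torch.ones(3)))  # still returns x * 2, despite the negative sum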

pquadri commented 4 years ago

Thank you very much for the help. I did try to rewrite part of the Saliency code yesterday to avoid unsupported functionality, but I got stuck at a point where I just didn't know enough about TorchScript to get it to work. As for the second part, I did something similar for my own implementation of GradCAM (which also required switching check_trace to False); the only solution I found was to mix tracing and scripting, with the traced GradCAM computed on a single input and the scripted part handling the number of calls and the conditional statements.

I did try your suggestion: when using your code to trace Saliency.attribute directly, I still got

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a 
parameter or input, or detaching the gradient

It is possible that this error is due to my model structure, but I managed to fix it by using the solution of adding a "saliency" method to the model class (the same as in my comment above, except that the Saliency initialization is inside the model's __init__() method).

In this way I can get the results I was looking for.

vivekmig commented 4 years ago

That's great, I'm glad you were able to get the results you were looking for!

To resolve that error, I think another option, without adding the method to the model, would be to set requires_grad to False on your model's parameters, which should allow tracing the function directly, e.g.:

for param in model.parameters():
    param.requires_grad = False  # frozen parameters can safely be traced as constants

NarineK commented 4 years ago

Thank you for the thorough investigation, @vivekmig, and for the question, @pquadri! We'll definitely need to look into this more and evaluate its benefits.