Hi @pquadri! This is a good question. We plan to test all algorithms with JIT models. Some algorithms that do not use hooks, such as integrated gradients or feature ablation/permutation, will work out of the box. The ones that do use hooks need to be tested and may require some fixes.
Let us know if you try any and run into issues.
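For example, a minimal sketch (the toy model, shapes, and target below are made up for illustration) of passing a traced model directly to a hook-free method like IntegratedGradients:

import torch
from captum.attr import IntegratedGradients

# A toy model traced with TorchScript; any traced or scripted model should
# work the same way for methods that don't rely on hooks.
model = torch.nn.Sequential(torch.nn.Linear(10, 2))
example_input = torch.randn(4, 10)
traced_model = torch.jit.trace(model, example_input)

# IntegratedGradients does not use hooks, so the traced model can be
# passed in like a regular nn.Module.
ig = IntegratedGradients(traced_model)
attributions = ig.attribute(example_input, target=0)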
cc: @vivekmig , @orionr
Actually, Saliency should have worked with a JIT model. Do you have an example of your model?
At first I tried using a wrapper like I mentioned:
import torch
from captum.attr import Saliency

class SaliencyJIT(torch.nn.Module):
    def __init__(self, model):
        super(SaliencyJIT, self).__init__()
        self.model = Saliency(model)

    def forward(self, inputs):
        # Attribute every sample with respect to class 0.
        labels = torch.zeros([inputs.size(0)]).to(torch.long).cuda()
        return self.model.attribute(inputs, labels)
Right now I am testing adding a method to the model itself:
def saliency(self, inputs):
    self.saliency_model = Saliency(self.forward)
    labels = torch.zeros([inputs.size(0)]).to(torch.long)
    return self.saliency_model.attribute(inputs, labels)
And then compiling it with torch.jit.script; this approach seems to be working, and I am still testing whether it yields the same results.
EDIT: I tried both tracing and scripting on the second approach.
With scripting, it fails because generator expressions are not supported:
UnsupportedNodeError: GeneratorExp aren't supported:
at /opt/conda/envs/metal/lib/python3.7/site-packages/captum/attr/_core/saliency.py:119:33
        inputs = format_input(inputs)
        gradient_mask = apply_gradient_requirements(inputs)
        # No need to format additional_forward_args here.
        # They are being formated in the `_run_forward` function in `common.py`
        gradients = self.gradient_func(
            self.forward_func, inputs, target, additional_forward_args
        )
        if abs:
            attributions = tuple(torch.abs(gradient) for gradient in gradients)
                                 ~ <--- HERE
        else:
            attributions = gradients
        undo_gradient_requirements(inputs, gradient_mask)
        return _format_attributions(is_inputs_tuple, attributions)
Hi @pquadri, that's an interesting question, thanks for posting the information of what you tried!
I think the main use case we expected was just passing a JIT scripted / traced model to a Captum attribution method, which should work directly for methods that do not use hooks, such as IntegratedGradients and Saliency, which is what @NarineK was referring to.
Your use case of creating a JIT-scripted module that directly outputs the model attribution is interesting. I tried a few experiments with this and, like you, ran into issues: scripting fails on Python constructs in the Captum implementation that TorchScript does not support (such as the generator expression above), and the autograd call used to compute the gradients is not handled properly by tracing.
One way I found to accomplish what you're looking for with Saliency is to just add the autograd call, e.g. grads = torch.autograd.grad(torch.unbind(outputs), inputs), at the end of the model's forward method, since Saliency is essentially just the model's input gradients. As explained above, this cannot be traced, but scripting this module appears to accomplish what you want.
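A minimal sketch of this idea (the wrapper class, target handling, and shapes below are illustrative, not from this thread): the wrapper's forward returns the absolute input gradients, and the whole module can then be compiled with torch.jit.script. How well torch.autograd.grad is supported inside TorchScript may depend on your PyTorch version.

import torch

class SaliencyForward(torch.nn.Module):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, inputs: torch.Tensor, target: int) -> torch.Tensor:
        # Gradients are taken with respect to the inputs, so they must
        # require grad before the forward pass.
        inputs.requires_grad_(True)
        outputs = self.model(inputs)   # shape: [batch, num_classes]
        selected = outputs[:, target]  # one scalar score per sample
        # Saliency is essentially the (absolute) gradient of the selected
        # score with respect to the input.
        grads = torch.autograd.grad(torch.unbind(selected), [inputs])
        return torch.abs(grads[0])

# Usage sketch:
# scripted = torch.jit.script(SaliencyForward(my_model))
# heatmap = scripted(batch, 0)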
This will likely be more complex to accomplish for other methods, but you could potentially do it for a particular method in the same way, by moving the necessary parts of the method's implementation in Captum into the module's forward method and removing / rearranging any portions which are not supported by scripting. This is likely somewhat messy, but it is unfortunately the best way to accomplish this right now. Otherwise, just JIT scripting the model and using Captum on that would work, but would still require Python. Let us know if you have any further questions on this!
Edit: See below for a potentially cleaner solution.
Actually, I experimented a little more with this, and I think there's a solution that seems to work by combining scripting for the autograd portions with tracing for the overall method. We will need to test further and come up with the best way to incorporate this generically before adding it to Captum, but you can try this change and see if it works for you:
Replace line 92 in attr/_utils/gradient.py, within the compute_gradients method, which is currently:
grads = torch.autograd.grad(torch.unbind(outputs), inputs)
with:
def grad_fn(outputs: Tensor, inputs: Tuple[Tensor]):
    return torch.autograd.grad(torch.unbind(outputs), inputs)

grad_fn_scripted = torch.jit.script(grad_fn)
grads = tuple(grad_fn_scripted(outputs, inputs))
and add the following imports to the header:
from typing import Tuple
from torch import Tensor
With these changes, you should be able to directly trace attribution methods, for example:
sal = Saliency(model)
scripted_saliency = torch.jit.trace(sal.attribute, inp, check_trace=False)
Note that check_trace needs to be False; there seem to be some issues with the check that need to be debugged further.
This should work for feature attribution methods such as Integrated Gradients and Saliency (layer and neuron methods will need further changes). The limitations of tracing still apply, though (e.g. data-dependent conditionals aren't captured), which may affect some functionality. I will need to investigate that further, but this can hopefully help you get started with scripted attribution methods.
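Since the original goal was to obtain the heatmap from a serialized file, here is a hedged end-to-end sketch of how the traced attribute call could be saved and reloaded (the file name is illustrative, and whether saving a traced function this way works may depend on your PyTorch version):

import torch
from captum.attr import Saliency

sal = Saliency(model)
# check_trace=False as noted above; the trace check currently fails.
traced_attribute = torch.jit.trace(sal.attribute, inp, check_trace=False)

# Serialize the traced attribution function and load it back later.
torch.jit.save(traced_attribute, "saliency_attribute.pt")
loaded_attribute = torch.jit.load("saliency_attribute.pt")
heatmap = loaded_attribute(inp)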
Thank you very much for the help. I did try to rewrite part of the Saliency code yesterday to avoid the unsupported functionality, but I got stuck at a point where I just didn't know enough about TorchScript to get it to work. As for the second part, I did something similar for my own implementation of GradCAM (which also required switching check_trace to False), and the only solution I found was to mix tracing and scripting, with the traced GradCAM computed on a single input and the scripted part handling the number of calls and the conditional statements.
I did try your suggestion: when using your code to trace Saliency.attribute directly, I still got
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a
parameter or input, or detaching the gradient
This error could be due to the model structure, but I managed to fix it by adding a "saliency" method to the model class (the same as in my comment above, except that the Saliency initialization is inside the model's __init__() method).
In this way I can get the results I was looking for.
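For reference, a rough sketch of what this workaround might look like (the toy model, layer sizes, and the fixed labels of 0 are illustrative):

import torch
from captum.attr import Saliency

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(10, 2)
        # Saliency is constructed once here rather than inside saliency().
        self.saliency_model = Saliency(self.forward)

    def forward(self, inputs):
        return self.net(inputs)

    def saliency(self, inputs):
        labels = torch.zeros([inputs.size(0)]).to(torch.long)
        return self.saliency_model.attribute(inputs, labels)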
That's great, I'm glad you were able to get the results you were looking for!
To resolve that error, I think another option, without adding it to the model, would be to set requires_grad to False on your model's parameters, which should allow tracing the function directly, e.g.:
for param in model.parameters():
    param.requires_grad = False
Thank you for the thorough investigation, @vivekmig, and for the question, @pquadri! We'll definitely need to look into this more and evaluate its benefits.
Hello, I was experimenting with Captum and I was wondering if there is any way to trace/script an attribution model so that the serialized file directly outputs the final heatmap.
I did not find any reference to this in the documentation or in the code, and I did not manage to integrate it myself by creating intermediate classes, for example to wrap the Saliency class in a torch.nn.Module.
Is there something I am missing, or is it in the future plans?