marcoancona / DeepExplain

A unified framework of perturbation and gradient-based attribution methods for Deep Neural Networks interpretability. DeepExplain also includes support for Shapley Values sampling. (ICLR 2018)
https://arxiv.org/abs/1711.06104
MIT License
720 stars 133 forks

Avoiding graph recompilation from one batch to the next #31

Closed AvantiShri closed 5 years ago

AvantiShri commented 5 years ago

Hello,

I have a question about whether the current DeepExplain implementation recompiles the graph every time the run() method is called. Currently, every call to run() appears to include a call to get_symbolic_attribution(), which I suspect results in graph recompilation when session.run is called:

https://github.com/marcoancona/DeepExplain/blob/a4d6dcd849cd408d7f03615208ef94dab101cfeb/deepexplain/tensorflow/methods.py#L123-L125

Preliminary benchmarking conducted by @jsu27 seems to suggest that this is the case (i.e. refactoring the code to avoid repeatedly calling get_symbolic_attribution() considerably sped things up). Because de.explain() accepts only a batch of samples at a time, the recompilation on every batch could add considerable overhead. Are there any plans to refactor the code (at least for gradient based methods) to iterate over the input in batches within a single call to run(), so as to avoid graph recompilations?

marcoancona commented 5 years ago

Hi, thanks for pointing this out. Indeed, the run() method was designed to produce explanations for a small batch of samples and does not reuse the gradient ops. I should definitely refactor this. Some refactoring is also planned to fix a problem with recent TensorFlow versions (https://github.com/tensorflow/tensorflow/issues/23997), on which DeepExplain hangs.

linchundan88 commented 5 years ago

I have to execute K.clear_session() from time to time.

marcoancona commented 5 years ago

I still haven't had time to fix this, but yes, clearing the session can help to avoid clutter in memory.

linchundan88 commented 5 years ago

Anyway, this is a great library. I appreciate your work very much.

linchundan88 commented 5 years ago

I hope this can be fixed before mid-April, because I want to cite your website (or paper) in my paper, which is under second review at Nature Medicine.

marcoancona commented 5 years ago

I looked into this enhancement. It seems that avoiding the compilation of the graph for each batch is more complex than I initially thought. The problem is that the target tensor potentially changes for each batch. Consider the classification problem where the target tensor is set to something like output_tensor * ys in order to filter the outputs to the unit corresponding to the correct class: in this case, ys depends on the batch and therefore the gradient needs to be recomputed. @AvantiShri, you mentioned some experiments to avoid the problem. Do you have a solution that works in this case? Otherwise, I would rather go for automatically cleaning up the explanation ops once the explanation has been generated, so that the graph does not grow indefinitely and slow down evaluation, as pointed out by @linchundan88.

AvantiShri commented 5 years ago

@marcoancona Two ways to handle it. The first is one where the user compiles a different explanation function for each target class, and then only calls that function on the inputs pertaining to that particular class. That is, the user manually slices the target tensor to include only the class of interest, compiles the function, and then calls that function only on the inputs where they care about that particular class. This is the approach we used.

The other strategy would be to treat the labels as an input to the graph. That is, have a tensor ys_tensor that stores the labels, and compute the gradients with respect to output_tensor * ys_tensor. Then supply ys to ys_tensor via feed_dict.
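A minimal sketch of this second strategy, written in plain Python rather than TensorFlow so that it runs anywhere. It assumes a toy linear model output = W x, for which the gradient of sum(output * ys) with respect to x is W^T ys. The names build_masked_gradient and W are illustrative only and are not part of DeepExplain. The point is just that the function is built once and the labels arrive as an ordinary input on every call:

```python
def build_masked_gradient(W):
    """Build, once, a function computing d/dx sum(output * ys) for output = W x.

    W is a list of rows (n_classes x n_features). For this linear toy model
    the gradient is W^T ys, so ys is a fed input rather than a constant
    baked into the function.
    """
    n_classes = len(W)
    n_features = len(W[0])

    def gradient(xs, ys):
        # xs: batch of inputs (a linear model's gradient ignores their values;
        #     the argument is kept so the call mirrors a real attribution call)
        # ys: batch of one-hot labels or weights, one row per sample
        grads = []
        for _x, y in zip(xs, ys):
            grads.append([
                sum(W[k][j] * y[k] for k in range(n_classes))
                for j in range(n_features)
            ])
        return grads

    return gradient


W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 classes, 2 features
grad_fn = build_masked_gradient(W)         # built once

batch_a = grad_fn([[0.1, 0.2]], [[1, 0, 0]])  # target class 0
batch_b = grad_fn([[0.3, 0.4]], [[0, 0, 1]])  # target class 2, same function
```

Switching the target class between batch_a and batch_b changes only the fed ys, not the function itself.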

Does that make sense?

marcoancona commented 5 years ago

Solved in v0.2. Now the explain() method supports a batch_size parameter to enable batch processing. It also supports an optional ys parameter to pass weights (or binary masks) for the target tensor, as in the second approach suggested by @AvantiShri. For backward compatibility, manual masking of the target tensor is still allowed, but it fails when batch processing is enabled.
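Roughly, batch processing means evaluating the attribution on consecutive slices of xs together with the matching slices of ys, then joining the results. The sketch below only illustrates that idea in plain Python; run_in_batches and toy_attribution are hypothetical names and do not reflect DeepExplain internals.

```python
def run_in_batches(attribution_fn, xs, ys, batch_size):
    """Apply attribution_fn to consecutive slices of xs/ys and join the results."""
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same number of samples")
    results = []
    for start in range(0, len(xs), batch_size):
        end = start + batch_size
        # xs and ys are sliced together so each sample keeps its own mask
        results.extend(attribution_fn(xs[start:end], ys[start:end]))
    return results


# Stand-in attribution: weight each input feature by the sample's mask value.
def toy_attribution(xs_batch, ys_batch):
    return [[x_i * y for x_i in x] for x, y in zip(xs_batch, ys_batch)]


xs = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
ys = [1, 0, 1, 0, 1]
attributions = run_in_batches(toy_attribution, xs, ys, batch_size=2)
```

Seen this way, a mask multiplied into the target tensor by hand is presumably a problem because it covers all samples at once and is not sliced alongside xs.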

marcoancona commented 5 years ago

Why do you call explain() multiple times? I assumed it is because you need some batch processing; in that case you can use the new batch_size parameter. Every time you call explain(), the graph is recompiled, because you might have changed the target and input tensors.

linchundan88 commented 5 years ago

Because we have developed a cloud-based system, our service process needs to handle user requests dynamically.

linchundan88 commented 5 years ago

You improved this library very well in v0.2. Mini-batching can prevent running out of GPU memory. I appreciate it.

AvantiShri commented 5 years ago

@marcoancona in your next version, you might consider an API where you return an interpretation function to the user (rather than returning the attribution scores). The interpretation function would accept concrete input values and call sess.run without recompiling the graph. The DeepLIFT repo uses this API design, specifically to avoid the graph recompilation issue that @linchundan88 is facing: https://github.com/kundajelab/deeplift/blob/f1437900dfd427be64a8d8c1cfda302a3b2fa4a4/deeplift/util.py#L33-L51

linchundan88 commented 5 years ago

Even though the DeepLIFT library supports the RevealCancel rule (in addition to Rescale), it has a lot of limitations, for example with residual layers. And it only supports DeepLIFT. ^_^

AvantiShri commented 5 years ago

Oh no arguments there @linchundan88 :-) that’s why I like DeepExplain.

linchundan88 commented 5 years ago

@AvantiShri As for GradientBasedMethod:

symbolic_attribution = self.get_symbolic_attribution()
results = self.session_run(symbolic_attribution, self.xs, self.ys)

Do you mean the API should return symbolic_attribution rather than results?

AvantiShri commented 5 years ago

Close: I mean an API that returns a function that calls self.session_run on symbolic_attribution. The function takes xs and ys as inputs and returns results.
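A rough sketch of that shape, using a stand-in class so it runs without TensorFlow. ToyExplainer, get_explain_fn and build_count are made up for illustration; only the method names get_symbolic_attribution and session_run are borrowed from the snippet above, and their bodies here are placeholders:

```python
class ToyExplainer:
    """Stand-in for an attribution method; counts how often ops are built."""

    def __init__(self):
        self.build_count = 0

    def get_symbolic_attribution(self):
        # In TensorFlow this is where gradient ops would be added to the graph.
        self.build_count += 1
        return lambda x, y: x * y  # placeholder for the symbolic attribution

    def session_run(self, symbolic_attribution, xs, ys):
        # Placeholder for session.run with xs and ys supplied through feed_dict.
        return [symbolic_attribution(x, y) for x, y in zip(xs, ys)]

    def get_explain_fn(self):
        # Build once, then return a closure that only evaluates.
        symbolic_attribution = self.get_symbolic_attribution()

        def explain_fn(xs, ys):
            return self.session_run(symbolic_attribution, xs, ys)

        return explain_fn


explainer = ToyExplainer()
explain_fn = explainer.get_explain_fn()
first = explain_fn([1.0, 2.0], [1, 0])
second = explain_fn([3.0, 4.0], [0, 1])
```

However many times explain_fn is called, get_symbolic_attribution runs only once, which is the property that matters for a service handling requests one at a time.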

linchundan88 commented 5 years ago

@AvantiShri @marcoancona I am not familiar with TensorFlow's computation graph (define-and-run; I am waiting for 2.0 ^_^, which turns on eager mode by default). I understand that self.session.run(T, feed_dict) in session_run_batch actually executes the computation graph. Is that correct? Please help me.

AvantiShri commented 5 years ago

Yes, session.run(T, feed_dict) will execute the part of the computation graph that results in outputting the value that T has when the inputs have the values defined in feed_dict.
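To make that concrete, here is a toy dataflow graph in plain Python. It is not TensorFlow, and Node and run are invented for this illustration, but it mimics the idea: evaluating a target only visits the nodes the target depends on, with placeholder values taken from feed_dict.

```python
class Node:
    """A node in a toy dataflow graph: either a placeholder or an operation."""

    def __init__(self, name, op=None, inputs=()):
        self.name = name
        self.op = op              # None means "placeholder, must be fed"
        self.inputs = tuple(inputs)


def run(target, feed_dict, evaluated=None):
    """Evaluate target, visiting only the nodes it depends on.

    Returns the value and the list of operation names that were executed.
    (No caching of shared sub-results; this is only a toy.)
    """
    if evaluated is None:
        evaluated = []
    if target in feed_dict:
        return feed_dict[target], evaluated
    if target.op is None:
        raise KeyError("no value fed for placeholder " + target.name)
    args = [run(node, feed_dict, evaluated)[0] for node in target.inputs]
    evaluated.append(target.name)
    return target.op(*args), evaluated


x = Node("x")
y = Node("y")
total = Node("total", lambda a, b: a + b, [x, y])
unused = Node("unused", lambda a: a * 100, [x])  # defined but never requested

value, evaluated = run(total, {x: 2, y: 3})
```

Here value is 5 and only the "total" operation is executed; the "unused" node stays in the graph but costs nothing at evaluation time.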

linchundan88 commented 5 years ago

My very ugly code: ^_^

class GradientBasedMethod(AttributionMethod):
    def run(self):
        # Build the attribution ops once and keep them for later calls
        self.symbolic_attribution = self.get_symbolic_attribution()

    def run1(self, xs, ys):
        results = self.session_run(self.symbolic_attribution, xs, ys)
        return results[0] if not self.has_multiple_inputs else results


class DeepLIFTRescale(GradientBasedMethod):
    def run(self):
        # Check user baseline or set default one
        self._set_check_baseline()
        # Init references with a forward pass
        self._init_references()
        # Run the default run
        super(DeepLIFTRescale, self).run()

    def run1(self, xs, ys):
        return super(DeepLIFTRescale, self).run1(xs, ys)


class DeepExplain(object):
    def explain(self, method, T, X, xs, ys=None, batch_size=None, **kwargs):
        method = _ENABLED_METHOD_CLASS(T, X, xs, self.session, ys=ys, keras_learning_phase=self.keras_phase_placeholder, batch_size=batch_size, **kwargs)
        method.run()
        return method

    def explain1(self, method, xs, ys):
        result = method.run1(xs, ys)
        return result

linchundan88 commented 5 years ago

Call:

method = de.explain('deeplift', target_tensor, input_tensor, xs0, ys=ys0)
attributions = de.explain1(method, xs, ys)

linchundan88 commented 5 years ago

Moreover, in a dynamic environment, it runs much faster without compiling the computation graph every time.

linchundan88 commented 5 years ago

@marcoancona
You are great. My ugly API only worked for a short time, and now here comes your Explainer API. ^_^