greentfrapp / lucent

Lucid library adapted for PyTorch
Apache License 2.0
597 stars 89 forks

render_vis becomes slow when used multiple times #45

Closed antoninogreco closed 3 months ago

antoninogreco commented 1 year ago

Hi, I noticed some strange behavior with the render_vis function when optimizing multiple images (for example, the frames of a video). The time spent optimizing the same image for the same number of iterations keeps increasing, on both CPU and GPU! What do you think could be the reason?

tshead2 commented 1 year ago

I've noticed this, too.

TomasWilson commented 3 months ago

This issue stems from the fact that the forward hooks registered by render_vis (they are used to grab feature maps from the hidden layers) are not removed before render_vis returns.

This is quite a serious bug, although it does not affect the correctness of the visualizations. After many calls to render_vis, the submodules accumulate hundreds of these hooks, all of which run on every forward() call on the model. This even slows down any other code that uses the model after calling render_vis.
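The accumulation can be reproduced without Lucent. The minimal sketch below uses a small stand-in model and a hypothetical `attach_feature_hooks` helper. The helper imitates, but is not, Lucent's internal hooking code: it registers a forward hook on every submodule and throws the returned handles away, so each call leaves one more hook per submodule. `count_forward_hooks` reads the private `_forward_hooks` attribute of `nn.Module` purely for inspection.

```python
import torch
import torch.nn as nn


def attach_feature_hooks(model: nn.Module) -> dict:
    """Hypothetical helper: hook every submodule to capture its output.

    The handles returned by register_forward_hook are discarded, so the
    hooks are never removed, mimicking the leak described above.
    """
    features = {}
    for name, module in model.named_modules():
        if name == "":  # skip the root module itself
            continue

        def hook(mod, inputs, output, name=name):
            features[name] = output

        module.register_forward_hook(hook)
    return features


def count_forward_hooks(model: nn.Module) -> int:
    """Total number of forward hooks attached to the model and its submodules."""
    return sum(len(m._forward_hooks) for m in model.modules())


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
for call in range(1, 4):
    attach_feature_hooks(model)  # stands in for one leaky render_vis call
    model(torch.randn(2, 8))
    print(call, count_forward_hooks(model))  # 1 3, then 2 6, then 3 9
```

Every forward pass runs all of the accumulated hooks, so the per-iteration cost grows with the number of earlier calls.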

A quick and dirty fix is to clear all forward hooks from the model after each call to render_vis, like so:

from collections import OrderedDict

import torch.nn as nn

def remove_all_forward_hooks(model: nn.Module) -> None:
    # Recursively visit every child module and drop all of its forward hooks.
    for _, child in model._modules.items():
        if child is not None:
            if hasattr(child, "_forward_hooks"):
                child._forward_hooks = OrderedDict()
            remove_all_forward_hooks(child)

# call this after each call to render_vis
remove_all_forward_hooks(your_torch_model)

(I took this code from here)

Of course, a proper fix belongs in the library itself: render_vis should clean up the model by removing only the hooks it actually created (the code above removes ALL forward hooks, even those unrelated to Lucent / render_vis).
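One way to remove only your own hooks is to keep the handle that `register_forward_hook` returns and call `.remove()` on it when you are done. Below is a minimal sketch of that idea as a context manager. `capture_features` is a hypothetical helper, not part of Lucent's API and not the actual change made in #52. It hooks every submodule, yields the captured features, and removes exactly those hooks in a `finally` block, so hooks registered by other code survive. `count_forward_hooks` reads the private `_forward_hooks` attribute only to check the result.

```python
from contextlib import contextmanager
from typing import Dict, Iterator

import torch
import torch.nn as nn


@contextmanager
def capture_features(model: nn.Module) -> Iterator[Dict[str, torch.Tensor]]:
    """Hypothetical helper: hook every submodule, then remove only those hooks."""
    features: Dict[str, torch.Tensor] = {}
    handles = []
    for name, module in model.named_modules():
        if name == "":  # skip the root module itself
            continue

        def hook(mod, inputs, output, name=name):
            features[name] = output

        # Keep the RemovableHandle so this specific hook can be removed later.
        handles.append(module.register_forward_hook(hook))
    try:
        yield features
    finally:
        # Runs even if the optimization loop raises an exception.
        for handle in handles:
            handle.remove()


def count_forward_hooks(model: nn.Module) -> int:
    """Total number of forward hooks attached to the model and its submodules."""
    return sum(len(m._forward_hooks) for m in model.modules())


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
# A hook owned by unrelated code, which must survive the cleanup.
model[0].register_forward_hook(lambda mod, inputs, output: None)

for _ in range(3):  # each iteration stands in for one render_vis call
    with capture_features(model) as features:
        model(torch.randn(2, 8))
        print(sorted(features))  # ['0', '1', '2']
print(count_forward_hooks(model))  # 1
```

Unlike clearing `_forward_hooks` wholesale, this leaves the unrelated hook on `model[0]` in place, and the hook count stays constant no matter how many times the block runs.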

greentfrapp commented 3 months ago

Thanks @antoninogreco and @tshead2 for reporting this! And thanks to @TomasWilson for the suggested fix! This should be fixed by #52, which removes the module hooks that render_vis creates after each run.