raghakot / keras-vis

Neural network visualization toolkit for keras
https://raghakot.github.io/keras-vis
MIT License

Performance issue with visualize_cam #71

Open vijayg78 opened 6 years ago

vijayg78 commented 6 years ago

Hi, we are using visualize_cam to visualize what a neural net (Xception) has learned. We used the following code to generate heatmaps and overlays for about 5000 cross-validation images. The time taken by visualize_cam increases steadily: it starts at approximately 2 seconds per iteration and reaches 9 seconds per iteration by around 1000 samples. Any idea how to improve this and get steady performance? It looks like there is a memory leak or something similar.

Memory consumption also goes up steadily in this case.

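import os
import warnings

import keras
import matplotlib.pyplot as plt
from keras import activations
from vis.utils import utils
from vis.visualization import visualize_cam
# imsave is not imported in the posted snippet; scipy.misc.imsave is assumed here.
from scipy.misc import imsave

def load_model():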
    # Load a state-of-the-art model (Xception for now).
    model = keras.models.load_model('top_model_weights-00-1.00.h5')
    #model.summary()
    # Utility to search for layer index by name. 
    # Alternatively we can specify this as -1 since it corresponds to the last layer.
    layer_idx = utils.find_layer_idx(model, 'dense_1')

    # Swap softmax with linear
    model.layers[layer_idx].activation = activations.linear
    model = utils.apply_modifications(model)
    return model,layer_idx

def generate_heatmaps(logs,image_dir,heatmap_dir):
    model,layer_idx = load_model()
    keys = logs.keys()
    warnings.filterwarnings('ignore')
    plt.rcParams['figure.figsize'] = (18, 6)
    for key in keys:
        img_name = key
        original = utils.load_img(os.path.join(image_dir,key),target_size=(299,299))
        # NOTE: `classified` (the class index to visualize) is not defined in the posted snippet.
        heatmap = visualize_cam(model, layer_idx,
                                filter_indices=int(classified),
                                seed_input=original, backprop_modifier=None)
        imsave(os.path.join(heatmap_dir, key), heatmap)
    print('{} images there'.format(len(keys)))

raghakot commented 6 years ago

I was planning on releasing a batch API for the visualizations that would be super fast for use-cases like this. I need to investigate the issue with the leak. Thanks for pointing this out.

vobject commented 6 years ago

I ran into the same leak-like situation while working on a real-time CAM visualization. Calls to visualize_cam() get slower over time and consume increasing amounts of memory.

The issue seems to be inside Optimizer.__init__() when building the losses, but I'm not familiar enough with the code to track it down. Tensor ops might be chained together indefinitely.
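
One way to test this suspicion is to sample a size metric between calls and check whether it keeps growing. The following is a minimal, library-agnostic sketch; `track_growth`, `grows_every_call`, `step`, and `measure` are hypothetical names, not part of keras-vis. With the TensorFlow 1.x backend, `measure` could presumably be something like `lambda: len(K.get_session().graph.get_operations())` and `step` a single `visualize_cam()` call; the demo below uses a plain list as a stand-in for the graph.

```python
def track_growth(step, measure, iterations):
    """Call `step` repeatedly and record `measure()` after each call."""
    samples = [measure()]  # baseline before any call
    for _ in range(iterations):
        step()
        samples.append(measure())
    return samples


def grows_every_call(samples):
    """Return True if each sample is strictly larger than the previous one."""
    return all(later > earlier for earlier, later in zip(samples, samples[1:]))


# Stand-in for a leaking call: every invocation appends "ops" to a fake graph.
fake_graph = []
samples = track_growth(
    step=lambda: fake_graph.extend(["grad_op", "loss_op"]),
    measure=lambda: len(fake_graph),
    iterations=3,
)
print(samples)                    # [0, 2, 4, 6]
print(grows_every_call(samples))  # True
```

If the op count rises on every call, new graph nodes are being created per call rather than reused, which would match the slowdown described in this thread.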

RickardSjogren commented 6 years ago

I ran into the same situation when running visualize_activation repeatedly. The time per image grew to more than 10x its initial value before I interrupted the run, and memory consumption grew to around 3x as well.

LennartPiro commented 6 years ago

I have the same problem with visualize_saliency. I profiled my code and the culprit seems to be <built-in method tensorflow.python._pywrap_tensorflow_internal.TF_ExtendGraph>, which is called every time a saliency map is calculated (from Optimizer.minimize). Each call to TF_ExtendGraph takes longer than the previous one and uses up some additional memory. Alas, I can't figure out how to fix this. My stupid workaround is to run my code on about 200 pictures at a time and completely reset TensorFlow in between.

[Image: saliency_profiling]
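
The chunked workaround described in this comment could be sketched roughly as follows. This is only an illustration of the pattern, not code from the thread: `process_in_chunks`, `setup`, `process`, and `teardown` are hypothetical names. With Keras, `setup` might reload the model and `teardown` might call `keras.backend.clear_session()`, on the assumption that clearing the session is enough to drop the accumulated graph. The demo uses stand-in functions so it runs without TensorFlow.

```python
def process_in_chunks(items, chunk_size, setup, process, teardown):
    """Process `items` in chunks, rebuilding state between chunks.

    setup()              -> state (e.g. a freshly loaded model)
    process(state, item) -> result for one item
    teardown(state)      -> release the state (e.g. clear the backend session)
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    results = []
    for start in range(0, len(items), chunk_size):
        state = setup()
        try:
            for item in items[start:start + chunk_size]:
                results.append(process(state, item))
        finally:
            teardown(state)  # runs even if processing fails
    return results


# Stand-in demo: the "state" counts how many calls it has served.
teardown_log = []

def fake_setup():
    return {"calls": 0}

def fake_process(state, item):
    state["calls"] += 1
    return item * 10

def fake_teardown(state):
    teardown_log.append(state["calls"])

results = process_in_chunks(list(range(5)), 2, fake_setup, fake_process, fake_teardown)
print(results)       # [0, 10, 20, 30, 40]
print(teardown_log)  # [2, 2, 1]
```

Each state object serves at most `chunk_size` calls, so any per-call growth is bounded, at the cost of paying the setup time once per chunk.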

JeremBlain commented 5 years ago

Any news ?? It is pretty annoying when you want to compute the saliency in a loop... :/

Thys3Potgieter commented 4 years ago

We use this function to generate heatmaps as part of a REST API in a production setup; latency increases until requests time out. We really want to stick with the library, so even a workaround would be appreciated. Will this be fixed?

Hommoner commented 4 years ago

I found that the leak occurs because every call to the line "opt = Optimizer(input_tensor, losses, wrt_tensor=penultimate_output, norm_grads=False)" adds new tensors to the TensorFlow graph. My workaround is to create "opt" only once and keep it in memory.

Hommoner commented 4 years ago

The code looks like this: one function that initializes the optimizer once, and one function that generates the map.

def visualize_cam_init(model, layer_idx, filter_indices):

    penultimate_layer = _find_penultimate_layer(model, layer_idx, None)
    losses = [
        (ActivationMaximization(model.layers[layer_idx], filter_indices), -1)
    ]
    penultimate_output = penultimate_layer.output
    opt = Optimizer(model.input, losses, wrt_tensor=penultimate_output, norm_grads=False)
    return opt

def visualize_cam_run(model, layer_idx, opt,seed_input):
    input_tensor = model.input
    penultimate_layer = _find_penultimate_layer(model, layer_idx, None)
    penultimate_output = penultimate_layer.output

    _, grads, penultimate_output_value = opt.minimize(seed_input, max_iter=1, grad_modifier=None, verbose=False)
    #opt.minimize(seed_input, max_iter=1, grad_modifier=grad_modifier, verbose=False)

    # For numerical stability. Very small grad values along with small penultimate_output_value can cause
    # w * penultimate_output_value to zero out, even for reasonable fp precision of float32.
    grads = grads / (np.max(grads) + K.epsilon())
    # Average pooling across all feature maps.
    # This captures the importance of feature map (channel) idx to the output.
    channel_idx = 1 if K.image_data_format() == 'channels_first' else -1
    other_axis = np.delete(np.arange(len(grads.shape)), channel_idx)
    weights = np.mean(grads, axis=tuple(other_axis))

    # Generate heatmap by computing weight * output over feature maps
    output_dims = utils.get_img_shape(penultimate_output)[2:]
    heatmap = np.zeros(shape=output_dims, dtype=K.floatx())

    for i, w in enumerate(weights):
        if channel_idx == -1:
            heatmap += w * penultimate_output_value[0, ..., i]
        else:
            heatmap += w * penultimate_output_value[0, i, ...]

    # ReLU thresholding to exclude pattern mismatch information (negative gradients).
    heatmap = np.maximum(heatmap, 0)

    # The penultimate feature map size is definitely smaller than input image.
    input_dims = utils.get_img_shape(input_tensor)[2:]
    heatmap = imresize(heatmap, input_dims, interp='bicubic', mode='F')

    # Normalize and create heatmap.
    heatmap = utils.normalize(heatmap)
    return np.uint8(cm.jet(heatmap)[..., :3] * 255)

My calling routines look like this:

def init_cam(self):
    if not self.isCamInit:
        self.isCamInit = True
        # Build one optimizer per class label, once, and keep them for reuse.
        self.CamOpt = []
        for cnt, _ in enumerate(self.labels):
            opt = visualize_cam_init(self.model, -1, filter_indices=cnt)
            self.CamOpt.append(opt)

def Get_Activation_map(self, class_id):
    if not self.isCamInit:
        self.init_cam()

    # The original post wrapped the next two lines in a `try:` whose handler was cut off.
    opt = self.CamOpt[class_id]
    grads = visualize_cam_run(self.model, -1, opt, seed_input=self.image_array)
    # Returning the map is assumed; the rest of the original method was cut off.
    return grads

If we reuse opt, the calculation is very fast, almost real-time. Have fun! :)
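
The build-once idea above can also be expressed as a small lazy cache. This is a sketch under assumptions, not the code from this comment: `OptimizerCache` and `fake_factory` are hypothetical names, and the factory stands in for a call such as `visualize_cam_init(model, layer_idx, filter_indices=class_id)`. Unlike `init_cam`, which builds an optimizer for every label up front, this variant builds one only when a class is first requested.

```python
class OptimizerCache:
    """Build one expensive object per key on first use, then reuse it."""

    def __init__(self, factory):
        self._factory = factory  # callable: key -> expensive object
        self._cache = {}

    def get(self, key):
        if key not in self._cache:
            self._cache[key] = self._factory(key)
        return self._cache[key]


build_calls = []

def fake_factory(class_id):
    # Stand-in for building an Optimizer; records how often it is called.
    build_calls.append(class_id)
    return "opt-for-class-{}".format(class_id)

cache = OptimizerCache(fake_factory)
for class_id in [0, 1, 0, 0, 1]:
    cache.get(class_id)
print(build_calls)  # [0, 1]
```

Five lookups trigger only two builds, so the graph would be extended once per class instead of once per image.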

betterze commented 4 years ago

@Hommoner Thanks for sharing the code.

Why do you define several opt objects rather than just one? opt is the same for every class_id, right? Or do you want to keep opt in its initial state?

Is class_id the index of the image, or the class in the network's final layer output?

mdew192837 commented 4 years ago

@raghakot This would be immensely helpful. Any updates on this? Thanks!!

mdew192837 commented 4 years ago

@Hommoner Could you provide a more detailed description of your solution? It seems like some of your code is written in a class, but it's a bit hard to see how to use it in practice.

Thanks!!

Hommoner commented 4 years ago

> @Hommoner Thanks for sharing the code.
>
> Why do you define several opt objects rather than just one? opt is the same for every class_id, right? Or do you want to keep opt in its initial state?
>
> Is class_id the index of the image, or the class in the network's final layer output?

Several opt objects should work, but I only need to see the activation map for one layer. opt is initialized once because the original code, "opt = Optimizer(input_tensor, losses, wrt_tensor=penultimate_output, norm_grads=False)", adds new tensors to the TensorFlow graph on every call. We only need to initialize it once and can reuse it every time, as long as we don't need to change the layer.

Once you have initialized a given opt, you can use class_id to see the activation map for that class in this layer. For example, take a model that classifies cats [class_id=0] and dogs [class_id=1]: if the image is really a dog and I want to see the activation map for cats, I can specify class_id = 0.