albermax / innvestigate

A toolbox to iNNvestigate neural networks' predictions!

[BUG] Mutation warning being printed at every layer of EfficientNet when running LRP analyzer #324

Status: Open. Opened by palatyle 7 months ago.

palatyle commented 7 months ago

Describe the bug

When I run any LRP analyzer, TensorFlow prints warnings like this one:

2023-11-22 10:49:33.032371: W tensorflow/c/c_api.cc:305] Operation '{name:'mul_1355/x' id:43796 op device:{requested: '', assigned: ''} def:{{{node mul_1355/x}} = Const[_has_manual_control_dependencies=true, dtype=DT_FLOAT, value=Tensor<type: float shape: [] values: 1e-07>]()}}' was changed by setting attribute after it was run by a session. This mutation will have no effect, and will trigger an error in the future. Either don't modify nodes after running them or create a new session.

The warning is printed for what looks like every layer of the EfficientNet model I'm using. The analysis is also very slow: analyzing a single image takes about 15 minutes, and longer with the sequential presets. Is this runtime normal for a network of this size? Can I safely ignore these warnings, or is something actually wrong?
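If the messages only need to be hidden, one option is TensorFlow's `TF_CPP_MIN_LOG_LEVEL` environment variable, which filters output from TensorFlow's C++ logger (the `W tensorflow/c/c_api.cc` prefix above marks a WARNING from that logger). This is a minimal sketch, assuming the variable is set before `tensorflow` is first imported in the process; it only silences the log output and does not change the graph mutation itself or the runtime.

```python
import os

# TensorFlow's C++ logger reads this variable when the library is loaded:
# "0" shows everything, "1" hides INFO, "2" also hides WARNING, "3" also hides ERROR.
# It has no effect if tensorflow was already imported earlier in this process.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Import tensorflow (and innvestigate) only after this line.
```

Because this hides every C++ warning, not just the mutation message, it is worth removing again while debugging.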

Steps to reproduce the bug

NOTE: I had to apply juliowissing-iis's fix for this code block to run without errors.

import tensorflow as tf
import tensorflow.keras.utils as ku
import numpy as np
import innvestigate as inn
from tensorflow.keras.applications import EfficientNetB5
import time

tf.compat.v1.disable_eager_execution()

# create EfficientNet B5 model with imagenet weights
model = EfficientNetB5(weights='imagenet')

# Read in data
img = ku.load_img('img.jpg', target_size=(456, 456))
# Convert to a NumPy array of shape (456, 456, 3)
img_arr = ku.img_to_array(img)
# Add a leading batch dimension -> shape (1, 456, 456, 3)
img_CNN = np.expand_dims(img_arr, axis=0)

# Remove the softmax activation from the output layer (required by iNNvestigate)
model_wo_softmax = inn.model_wo_softmax(model)

# Create LRP epsilon analyzer
analyzer = inn.create_analyzer("lrp.epsilon", model_wo_softmax, epsilon=1)

# Analyze input image
t0=time.time()
a = analyzer.analyze(img_CNN)
t1=time.time()

total_time = t1-t0
print(f'Analyzer took {total_time} seconds')
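To tell whether the ~15 minutes is a one-time cost or a per-image cost, the first `analyze()` call can be timed separately from later calls on the same analyzer. The hypothesis here is an assumption, not something confirmed in this issue: iNNvestigate appears to build its analysis graph lazily on the first `analyze()` call, so for a network as deep as EfficientNetB5 much of the time may go into that construction. The helper `time_first_and_repeat` is hypothetical and not part of iNNvestigate; `analyzer` and `img_CNN` are the objects from the script above.

```python
import time
from typing import Any, Callable, List, Tuple


def time_first_and_repeat(
    fn: Callable[[Any], Any], arg: Any, repeats: int = 2
) -> Tuple[float, List[float]]:
    """Time the first call fn(arg) separately from `repeats` further calls.

    A first duration much larger than the later ones points to a one-time
    setup cost rather than a cost paid for every input.
    """
    if repeats < 0:
        raise ValueError("repeats must be non-negative")

    # First call: includes any lazy setup done by fn.
    start = time.perf_counter()
    fn(arg)
    first = time.perf_counter() - start

    # Later calls: steady-state cost per call.
    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(arg)
        later.append(time.perf_counter() - start)
    return first, later


# Hypothetical usage with the script above:
# first, later = time_first_and_repeat(analyzer.analyze, img_CNN, repeats=2)
# print(f"first call: {first:.1f} s, later calls: {[round(t, 1) for t in later]}")
```

If the later calls are much faster, the slowness is mostly setup and reusing one analyzer for many images amortizes it.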

Expected behavior

Few or no mutation warnings, and an analysis that doesn't take 15 minutes to run.

Platform information