sicara / tf-explain

Interpretability Methods for tf.keras models with Tensorflow 2.x
https://tf-explain.readthedocs.io
MIT License
1.02k stars, 112 forks

Does tf-explain support pretrained TensorFlow Keras models with multiple inputs? #172

Open: lastproxy opened this issue 3 years ago

lastproxy commented 3 years ago

My input data is a list of arrays with different shapes. They share only the first axis, which is the number of samples. tf-explain does not seem to work with my pretrained TensorFlow Keras model when given this input. The error message is:

Traceback (most recent call last):
  File "deepST_generate_map.py", line 349, in <module>
    grid = explainer.explain(data, model, class_index=0)
  File "/home/zcx/anaconda3/envs/tfpy/lib/python3.7/site-packages/tf_explain/core/grad_cam.py", line 55, in explain
    model, images, layer_name, class_index, use_guided_grads
  File "/home/zcx/anaconda3/envs/tfpy/lib/python3.7/site-packages/tf_explain/core/grad_cam.py", line 115, in get_gradients_and_filters
    inputs = tf.cast(images, tf.float32)
  File "/home/zcx/anaconda3/envs/tfpy/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "/home/zcx/anaconda3/envs/tfpy/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py", line 964, in cast
    x = ops.convert_to_tensor(x, name="x")
  File "/home/zcx/anaconda3/envs/tfpy/lib/python3.7/site-packages/tensorflow/python/profiler/trace.py", line 163, in wrapped
    return func(*args, **kwargs)
  File "/home/zcx/anaconda3/envs/tfpy/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1540, in convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/zcx/anaconda3/envs/tfpy/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py", line 339, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/home/zcx/anaconda3/envs/tfpy/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py", line 265, in constant
    allow_broadcast=True)
  File "/home/zcx/anaconda3/envs/tfpy/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py", line 276, in _constant_impl
    return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
  File "/home/zcx/anaconda3/envs/tfpy/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py", line 301, in _constant_eager_impl
    t = convert_to_eager_tensor(value, ctx, dtype)
  File "/home/zcx/anaconda3/envs/tfpy/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py", line 98, in convert_to_eager_tensor
    return ops.EagerTensor(value, ctx.device_name, dtype)
ValueError: Can't convert non-rectangular Python sequence to Tensor.
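One possible workaround, sketched under stated assumptions. The traceback shows tf-explain's Grad-CAM passing the whole input to `tf.cast(images, tf.float32)`. That call needs a single rectangular array, so a list of arrays with different shapes fails. The hypothetical helper `grad_cam_multi_input` below is not part of tf-explain. It reimplements standard Grad-CAM directly with `tf.GradientTape`: spatially averaged gradients weight the feature maps of a chosen layer. Each input is cast separately, so the arrays only need to share the sample axis. The sketch assumes a functional Keras model with a single output and a target layer that produces 4-D `(batch, height, width, channels)` feature maps. The toy two-input model and the layer name `conv` are illustrative only.

```python
import numpy as np
import tensorflow as tf


def grad_cam_multi_input(model, inputs, layer_name, class_index):
    """Hypothetical Grad-CAM helper for a functional model with a list of inputs.

    Assumes a single-output model and a target layer with 4-D output
    (batch, height, width, channels). Not part of the tf-explain API.
    """
    # Sub-model that returns both the target feature maps and the predictions.
    grad_model = tf.keras.Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.outputs[0]],
    )
    # Cast each input on its own instead of stacking them into one tensor,
    # so inputs only need to agree on the first (sample) axis.
    tensors = [tf.cast(x, tf.float32) for x in inputs]
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(tensors)
        score = predictions[:, class_index]
    grads = tape.gradient(score, conv_outputs)
    # Channel weights: average the gradients over the spatial axes.
    weights = tf.reduce_mean(grads, axis=(1, 2))
    # Weighted sum of feature maps per sample, then keep positive evidence only.
    cam = tf.nn.relu(tf.einsum("bhwc,bc->bhw", conv_outputs, weights))
    # Scale each heatmap to [0, 1]; the epsilon avoids division by zero.
    cam = cam / (tf.reduce_max(cam, axis=(1, 2), keepdims=True) + 1e-8)
    return cam.numpy()


# Illustrative two-input model: an image branch and a plain feature vector.
img_in = tf.keras.Input(shape=(8, 8, 3))
vec_in = tf.keras.Input(shape=(5,))
x = tf.keras.layers.Conv2D(4, 3, padding="same", name="conv")(img_in)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Concatenate()([x, vec_in])
out = tf.keras.layers.Dense(2, activation="softmax")(x)
model = tf.keras.Model([img_in, vec_in], out)

# Inputs with different shapes that share only the sample axis (3 samples).
data = [np.random.rand(3, 8, 8, 3), np.random.rand(3, 5)]
heatmaps = grad_cam_multi_input(model, data, "conv", class_index=0)
print(heatmaps.shape)  # (3, 8, 8)
```

The heatmaps cover only the input that feeds the chosen convolutional layer. Non-spatial inputs, such as the feature vector here, are passed through but get no map of their own.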