raghakot / keras-vis

Neural network visualization toolkit for keras
https://raghakot.github.io/keras-vis
MIT License

visualize_saliency breaks down if input_shape is (None, None, 3) #95

Closed lucasdavid closed 6 years ago

lucasdavid commented 6 years ago

My network accepts images of variable height and width, which can be represented by x = Input(shape=(None, None, 3)). Invoking vis.visualization.visualize_saliency on this network throws the following error:

ValueError: Cannot feed value of shape (1, 3, 235, 224) for Tensor 'input_1_2:0', which has shape '(?, ?, ?, 3)'

I believe the problem is in Optimizer._get_seed_input(seed_input) (optimizer.py L95, on master), where an assumption incorrectly swaps the channel dim from its correct position to an incorrect one:

# Only possible if channel idx is out of place.
if seed_input.shape != desired_shape:
    seed_input = np.moveaxis(seed_input, -1, 1)
return seed_input.astype(K.floatx())

For example, let's say I'm feeding an image of shape (423, 451, 3). The comparison seed_input.shape == (1, 423, 451, 3) != (1, None, None, 3) == desired_shape evaluates to True, so the axis is moved and seed_input.shape becomes (1, 3, 423, 451).
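To make the shape change concrete, here is a small NumPy-only check. It is just an illustration: the array is filled with zeros and desired_shape is written by hand rather than read from a real model.

```python
import numpy as np

# Stand-in for a single channels_last image batch; the values are irrelevant here.
seed_input = np.zeros((1, 423, 451, 3))
# Hand-written stand-in for an input shape with variable height and width.
desired_shape = (1, None, None, 3)

# A tuple containing None never equals a concrete shape, so the condition holds.
condition = seed_input.shape != desired_shape

# The same call as in the snippet above: move the last axis to position 1.
moved = np.moveaxis(seed_input, -1, 1)
print(condition, moved.shape)
```

The channel axis ends up at position 1 even though the model expects it last, which matches the layout reported in the ValueError.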

This can be corrected by removing this condition and assuming the user knows what to input. However, if the condition must be kept, then maybe we could make the check stricter. Something along the lines of:

channels = 1 if K.image_data_format() == 'channels_first' else -1
if seed_input.shape[channels] != desired_shape[channels]:
    seed_input = np.moveaxis(seed_input, -1, 1)
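
A slightly fuller sketch of the same idea, in case it helps the discussion. This is not keras-vis code: align_channels is a hypothetical helper, and data_format is passed in explicitly (instead of calling K.image_data_format()) so the snippet runs with NumPy alone. It compares only the channel dimension, ignores it when the model leaves it as None, and moves the axis in whichever direction the configured data format requires.

```python
import numpy as np


def align_channels(seed_input, desired_shape, data_format):
    """Hypothetical helper: move the channel axis only when it is provably misplaced."""
    if data_format == 'channels_first':
        channel_axis, other_axis = 1, -1
    else:
        channel_axis, other_axis = -1, 1

    expected = desired_shape[channel_axis]
    # Unknown channel count, or channels already in place: leave the input alone.
    if expected is None or seed_input.shape[channel_axis] == expected:
        return seed_input

    # The channels sit where the other data format would put them: move them over.
    if seed_input.shape[other_axis] == expected:
        return np.moveaxis(seed_input, other_axis, channel_axis)

    raise ValueError('Cannot match seed_input shape %s to %s'
                     % (seed_input.shape, desired_shape))


image = np.zeros((1, 423, 451, 3))
# Variable-size channels_last model: the input is returned unchanged.
print(align_channels(image, (None, None, None, 3), 'channels_last').shape)
# channels_first model fed a channels_last image: the axis is moved.
print(align_channels(image, (None, 3, None, None), 'channels_first').shape)
```

Raising on an unresolvable mismatch is a design choice in this sketch; silently returning the input, as suggested above, would also work if the user is trusted to pass the right layout.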

lucasdavid commented 6 years ago

duplicate of #90