albermax / innvestigate

A toolbox to iNNvestigate neural networks' predictions!

keras 2.2.0 cannot load pretrained plos* network model weights for mnist #88

Closed: sebastian-lapuschkin closed this issue 6 years ago

sebastian-lapuschkin commented 6 years ago

The models load, but their weights are random. This worked with keras 2.1.6.

maxkohlbrenner commented 6 years ago

This seems related to keras.models.clone_model: the method does not appear to initialize the model weights correctly.

After calling model_w_sm = clone_model(model) (as done in _load_pretrained_net in mnist.py), model_w_sm.get_weights() returns an empty list.
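A minimal sketch of the symptom, assuming (as the empty get_weights() result suggests) that the cloned model only creates its weights the first time it sees input. LazyDense and LazyModel are hypothetical toy classes, not Keras APIs; their set_weights loop hands each layer only as many arrays as that layer already owns, so an unbuilt clone silently keeps nothing:

```python
class LazyDense:
    """Toy layer that creates its weights on its first call (deferred build)."""

    def __init__(self, units):
        self.units = units
        self.weights = []  # stays empty until build() runs

    def build(self, input_dim):
        # one kernel (input_dim x units) and one bias vector, zero-initialised
        self.weights = [[[0.0] * self.units for _ in range(input_dim)],
                        [0.0] * self.units]

    def __call__(self, x):
        if not self.weights:
            self.build(len(x))
        kernel, bias = self.weights
        return [sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
                for j in range(self.units)]


class LazyModel:
    """Toy sequential model exposing predict/get_weights/set_weights."""

    def __init__(self, layers):
        self.layers = layers

    def predict(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def get_weights(self):
        return [w for layer in self.layers for w in layer.weights]

    def set_weights(self, weights):
        # each layer takes as many arrays as it currently holds; an unbuilt
        # layer holds none, so nothing is assigned and no error is raised
        for layer in self.layers:
            n = len(layer.weights)
            layer.weights = list(weights[:n])
            weights = weights[n:]


original = LazyModel([LazyDense(2)])
original.predict([1.0, 1.0, 1.0])          # builds the original's weights
clone = LazyModel([LazyDense(2)])          # stands in for clone_model(original)
print(len(original.get_weights()))         # 2
print(clone.get_weights())                 # [] (the symptom from the thread)
clone.set_weights(original.get_weights())
print(clone.get_weights())                 # still []: the copy was silently ignored
```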

maxkohlbrenner commented 6 years ago

Workaround (in the style of mnist_example_neuron_select.ipynb):

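model, modelp = mnistutils.create_model(channels_first, modelname, **create_model_kwargs)
# dummy forward pass so that modelp actually creates its weights
modelp.predict(data[0][0:1])
# only now does set_weights have something to copy the weights into
modelp.set_weights(model.get_weights())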

For some reason, the weights of modelp are not initialized correctly when clone_model is used in the current version. As a result, the line modelp.set_weights(model.get_weights()) does not transfer the weights when it is called in innvestigate/applications/mnist.py.

Running a dummy forward pass first and then calling the same line initializes both models correctly.

sebastian-lapuschkin commented 6 years ago

Fixed in _load_pretrained_net in mnist.py. Thank you for the workaround, @maxkohlbrenner!

pankessel commented 6 years ago

We still encounter the same error with the current version; the fix does not seem to work.

sebastian-lapuschkin commented 6 years ago

Hi @pankessel,

I cannot reproduce the issue with a fresh clone of innvestigate (master branch, commit 43db77ea708b4fb1545c400b4dc5baca355ad5ad) using keras 2.2.0, tensorflow(-gpu) 1.8, and CUDA 9.0.

Can you give us more information?

Cheers,

pankessel commented 6 years ago

Thanks, Sebastian. After re-running

python3 setup.py develop --user

as you suggested, it seems to work. Sorry about that.