keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Gradient of the output w.r.t. its weights (backend-agnostic) #5721

Closed afcruzs closed 7 years ago

afcruzs commented 7 years ago

I need to compute the gradient of the output w.r.t. its weights. The Keras backend has a function called gradients which seems to do the job. Here is a question on this matter:

http://stackoverflow.com/questions/39561560/getting-gradient-of-model-output-w-r-t-weights-using-keras

This seems to work well with TensorFlow, but when I run the code from the accepted answer (which makes sense to me) with Theano as the backend, calling the tensor.grad method raises the exception "TypeError: cost must be a scalar."

To actually run the gradients they use TensorFlow directly. I wonder if there is a way of computing the gradients with Keras without relying on a specific backend.

This is the code that fails when using Theano as the backend:

from keras.models import Sequential
from keras.layers import Dense, Activation
from keras import backend as k

model = Sequential()
model.add(Dense(12, input_dim=8, init='uniform', activation='relu'))
model.add(Dense(8, init='uniform', activation='relu'))
model.add(Dense(1, init='uniform', activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
outputTensor = model.output
listOfVariableTensors = model.trainable_weights
gradients = k.gradients(outputTensor, listOfVariableTensors)
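
The "cost must be a scalar" error arises because Theano's grad only differentiates a scalar, whereas model.output here holds one value per sample. TensorFlow's tf.gradients implicitly sums the outputs before differentiating, which is why the same call succeeds there. A possible backend-agnostic workaround, offered as an assumption rather than something confirmed in this thread, is to reduce the output to a scalar first, for example k.gradients(k.sum(outputTensor), listOfVariableTensors). The sketch below does not use Keras at all: it is a plain-Python illustration, with hypothetical helper names (forward, summed_output, analytic_grad, numeric_grad) and made-up numbers, of what the gradient of a summed output looks like for a single sigmoid unit, checked against finite differences.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forward(weights, bias, batch):
    """Single sigmoid unit: returns one output per sample (non-scalar)."""
    return [
        sigmoid(sum(w * x for w, x in zip(weights, sample)) + bias)
        for sample in batch
    ]


def summed_output(weights, bias, batch):
    """Scalar 'cost': the per-sample outputs reduced by summation."""
    return sum(forward(weights, bias, batch))


def analytic_grad(weights, bias, batch):
    """d(sum of outputs)/d(weights) via the chain rule:
    sum over samples of y * (1 - y) * x."""
    grad = [0.0] * len(weights)
    for sample, y in zip(batch, forward(weights, bias, batch)):
        for i, x in enumerate(sample):
            grad[i] += y * (1.0 - y) * x
    return grad


def numeric_grad(weights, bias, batch, eps=1e-6):
    """Central finite differences on the scalar summed output."""
    grad = []
    for i in range(len(weights)):
        up = list(weights)
        down = list(weights)
        up[i] += eps
        down[i] -= eps
        diff = summed_output(up, bias, batch) - summed_output(down, bias, batch)
        grad.append(diff / (2.0 * eps))
    return grad


# Made-up illustrative values, not taken from the model above.
weights = [0.5, -0.25, 0.1]
bias = 0.05
batch = [[1.0, 2.0, 3.0], [0.0, -1.0, 0.5]]

analytic = analytic_grad(weights, bias, batch)
numeric = numeric_grad(weights, bias, batch)
print(analytic)
print(numeric)
```

The two printed lists should agree closely. Note that the result has the shape of the weights, not of the output: the per-sample gradients are added together. If a separate gradient per sample or per output unit is needed, each scalar element would have to be differentiated on its own.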

patrickmesana commented 7 years ago

Does anyone have an answer to this?

stale[bot] commented 7 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

AmirAlavi commented 6 years ago

I'm running into this as well. I'm trying to run an "improved" WGAN from here: https://github.com/eriklindernoren/Keras-GAN/tree/master/wgan_gp, and I only have access to the Theano backend on my cluster.