smartArancina opened 6 days ago
Have you taken a look at `Model.stateless_call()`?
In theory, PyTorch's `torch.func.functional_call()` should work directly, but we have not specifically tested it. Whether it works depends on how the torch function is written.
Hi @VarunS1997, thanks for the answer. Yes, I tried what you suggested, and it seems to work on this dummy example:
```python
import keras
import torch

model = keras.Sequential(
    [keras.layers.Dense(1, activation="linear")]
)
model.build(input_shape=(1, 1))
x = torch.tensor([1.0])

# Trainable variables as raw torch tensors (kernel, then bias).
params = [p.value for p in model.trainable_weights]


def f(params, x):
    # PyTorch functional_call alternative:
    # prediction = torch.func.functional_call(
    #     model, params, x, strict=True,
    # )
    # stateless_call returns (outputs, non_trainable_variables).
    prediction = model.stateless_call(params, [], x)[0]
    print("prediction", prediction)
    print("prediction has grad", prediction.requires_grad)
    return prediction


J = torch.func.jacrev(f)(
    params,
    x,
)
print(J)
```
Output:
```text
prediction GradTrackingTensor(lvl=1, value=
tensor([[-0.5455]], grad_fn=<AddBackward0>)
)
prediction has grad True
[tensor([[[[1.]]]]), tensor([[[1.]]])]
```
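The Jacobian values and shapes above can be checked by hand. Assuming the prediction has shape (1, 1), as the printed tensor suggests, and that the linear activation is the identity, the one-unit `Dense` layer computes the following (with kernel $W$ of shape (1, 1), bias $b$ of shape (1,), and $x = 1$):

$$
y = xW + b, \qquad \frac{\partial y}{\partial W} = x = 1, \qquad \frac{\partial y}{\partial b} = 1
$$

`jacrev` returns, for each parameter, a tensor of shape `y.shape + param.shape`: $(1,1) + (1,1) = (1,1,1,1)$ for the kernel and $(1,1) + (1,) = (1,1,1)$ for the bias, which matches `[tensor([[[[1.]]]]), tensor([[[1.]]])]`.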
The problem is that I am using another PyTorch package in which the model differentiation is done with `torch.func.functional_call`, and I would rather not modify its source code to make it compatible with Keras. So I would like to ask:
1) Is there a workaround that makes a Keras + PyTorch model fully transparent to pure PyTorch code (`torch.func.functional_call`)? For example, unwrapping the PyTorch model from Keras, doing the computation with the external libraries, and then re-wrapping it transparently would be an acceptable workaround. I am asking because I am moving from Keras + TensorFlow to Keras + PyTorch (I need some external PyTorch libraries), and this transparency is crucial for me.
2) Just out of curiosity, to understand better: why does the Keras model called with `torch.func.functional_call` produce a prediction without gradient (`prediction has grad False`)? Is this why it doesn't work?
Thanks!
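For question 1, one possible workaround is a thin `torch.nn.Module` adapter. It is a sketch, not tested against Keras. The adapter registers the parameters under its own names and forwards them to a stateless function. `torch.func.functional_call` then swaps the adapter's parameters, and the adapter hands whatever it receives on to the stateless call. `StatelessAdapter` and `dense_stateless` are hypothetical names. `dense_stateless` is a plain-PyTorch stand-in so the sketch runs without Keras. With Keras 3 on the torch backend, the assumption is that `fn` would be `lambda p, x: model.stateless_call(p, [], x)[0]` and `params` would be `[p.value for p in model.trainable_weights]`, as in the snippet above.

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev


class StatelessAdapter(nn.Module):
    """Hypothetical adapter (not a Keras or PyTorch API).

    Registers a list of tensors as ordinary nn.Parameters so that
    torch.func.functional_call can substitute them, then forwards the
    tensors currently bound to those names to fn(params, x).
    """

    def __init__(self, fn, params):
        super().__init__()
        self.fn = fn
        self.num_params = len(params)
        for i, p in enumerate(params):
            # Stable names ("p0", "p1", ...) are what functional_call swaps.
            self.register_parameter(f"p{i}", nn.Parameter(p.detach().clone()))

    def forward(self, x):
        # Inside functional_call, getattr returns the substituted tensors.
        params = [getattr(self, f"p{i}") for i in range(self.num_params)]
        return self.fn(params, x)


def dense_stateless(params, x):
    # Stand-in (assumption) for: lambda p, x: model.stateless_call(p, [], x)[0]
    # Dense-like layer: kernel of shape (in, out), bias of shape (out,).
    kernel, bias = params
    return x @ kernel + bias


adapter = StatelessAdapter(dense_stateless, [torch.randn(1, 1), torch.zeros(1)])
x = torch.tensor([[1.0]])

# Code that only knows torch.func sees a regular nn.Module.
params = {name: p.detach() for name, p in adapter.named_parameters()}
J = jacrev(lambda p: functional_call(adapter, p, (x,)))(params)
print({name: j.shape for name, j in J.items()})
```

The adapter copies the initial values. After an external library updates the parameters, the caller would have to write them back into the Keras variables, for example with `Variable.assign`. A model with non-trainable variables would also need those passed to `stateless_call` instead of `[]`.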
Issue description
Computing the Jacobian of a Keras model with `torch.func.jacrev` returns zeros.
Code example
This gives the following output:
I am new to PyTorch. Maybe the problem is related to the fact that the prediction's `requires_grad` is `False`? I don't understand why the `requires_grad` attribute of the prediction becomes `True` when I call the model outside `torch.func.jacrev` (code below). Maybe there is some compatibility issue between `torch.func` and Keras models?
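One way to get exactly this symptom in plain PyTorch is shown in the toy module below. `torch.func.functional_call` only substitutes tensors that the module reads through its registered parameter and buffer names. If `forward` reaches the weights through some other reference, the substituted tensors are never used. In that case the prediction does not depend on them, its `requires_grad` is `False`, and `jacrev` returns zeros. `HiddenWeights` and `_frozen` are hypothetical. That Keras's torch backend reads its variables in a comparable way is an assumption suggested by the symptom, not something verified here.

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev


class HiddenWeights(nn.Module):
    """Toy module whose forward ignores its registered parameter.

    Assumption for illustration only: this mimics a layer that reads its
    weights from storage that functional_call does not patch.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1, 1))
        # Detached alias kept outside _parameters: never substituted.
        self._frozen = self.weight.detach()

    def forward(self, x):
        return x @ self._frozen


model = HiddenWeights()
x = torch.tensor([[1.0]])
seen = {}


def f(params):
    prediction = functional_call(model, params, (x,))
    # The output does not depend on params, so it is not tracked.
    seen["requires_grad"] = prediction.requires_grad
    return prediction


J = jacrev(f)({"weight": model.weight.detach()})
print(seen["requires_grad"], J["weight"])
```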
With pure PyTorch code, it seems to work correctly.
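The pure-PyTorch code is not shown in the issue. Below is a minimal sketch of that pattern, assuming `torch.nn.Linear(1, 1)` as a stand-in for the one-unit `Dense` layer. Parameters are passed to `functional_call` as a name-to-tensor dict, and `jacrev` differentiates with respect to that dict.

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev

# Plain PyTorch stand-in for keras.layers.Dense(1, activation="linear").
model = nn.Linear(1, 1)
x = torch.tensor([[1.0]])

# functional_call expects a mapping from parameter names to tensors.
params = {name: p.detach() for name, p in model.named_parameters()}


def f(params, x):
    # nn.Linear reads self.weight / self.bias, which functional_call swaps.
    prediction = functional_call(model, params, (x,))
    print("prediction has grad", prediction.requires_grad)
    return prediction


# argnums defaults to 0, so the Jacobian is taken w.r.t. the params dict.
J = jacrev(f)(params, x)
print(J)
```

Here the prediction is tracked inside `jacrev`, and both Jacobian entries equal 1 because `x` is 1.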
System Info
Python 3.9.5