keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Using keras model with torch.func.jacrev gives zero gradients #20279

Open smartArancina opened 6 days ago

smartArancina commented 6 days ago

Issue description

Computing the Jacobian of a Keras model with torch.func.jacrev returns zeros.

Code example

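import os
os.environ["KERAS_BACKEND"] = "torch"  # select the PyTorch backend before importing keras

import keras
import torch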
model = keras.Sequential(
    [keras.layers.Dense(1,activation='linear')]
)
model.build(input_shape=(1,1))
x = torch.tensor([1.0])
params = dict(model.named_parameters())
def f(params, x):
    prediction = torch.func.functional_call(
        model, params, x, strict=True,
    )
    print("prediction", prediction)
    print("prediction has grad", prediction.requires_grad)
    return prediction

J = torch.func.jacrev(f)(
    params,
    x,
)

Output:

prediction GradTrackingTensor(lvl=1, value=
    tensor([[0.7901]], grad_fn=<AddBackward0>)
)
prediction has grad False
{'_functional.torch_params.sequential_12/dense_97/bias': tensor([[[0.]]]),
 '_functional.torch_params.sequential_12/dense_97/kernel': tensor([[[[0.]]]])}
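For reference, a Dense(1) layer with linear activation computes an affine map, so the expected Jacobian at the input used here is not zero. A short worked derivation (standard calculus, writing W for the kernel and b for the bias):

```latex
y = x\,W + b, \qquad W \in \mathbb{R}^{1 \times 1},\ b \in \mathbb{R}^{1}

\frac{\partial y}{\partial W} = x, \qquad \frac{\partial y}{\partial b} = 1
\quad\Longrightarrow\quad
\left.\frac{\partial y}{\partial W}\right|_{x = 1} = 1, \qquad
\left.\frac{\partial y}{\partial b}\right|_{x = 1} = 1
```

jacrev places the output shape, here (1, 1), in front of each parameter's shape, which is why the zeros above appear as [[[0.]]] for the bias and [[[[0.]]]] for the kernel. The values 1 and 1 match what the pure PyTorch version reports further down.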

I am new to PyTorch, so maybe the problem is related to the fact that the prediction's requires_grad is False? I don't understand why, when I call the model outside torch.func.jacrev, the requires_grad attribute of the prediction is True (code below):

model(x).requires_grad
True

Maybe there is some compatibility issue between torch.func and Keras models?

With pure PyTorch code, it seems to work correctly:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1, 1)
)

x = torch.tensor([1.0])
params = dict(model.named_parameters())
def f(params, x):
    prediction = torch.func.functional_call(
        model, params, x, strict=True,
    )
    print("prediction", prediction)
    print("prediction has grad", prediction.requires_grad)
    return prediction

J = torch.func.jacrev(f)(
    params,
    x,
)
J
prediction GradTrackingTensor(lvl=1, value=
    tensor([0.5290], grad_fn=<AddBackward0>)
)
prediction has grad True
{'0.weight': tensor([[[1.]]]), '0.bias': tensor([[1.]])}
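One hypothesis consistent with these two outputs (not confirmed anywhere in this thread): functional_call replaces the tensors registered under the module's parameter names, and torch.nn.Linear reads self.weight and self.bias from exactly those names, so the swapped tensors flow into the output. If a Keras layer instead reads its weights through its own variable objects, the swapped tensors never enter the computation, the output does not depend on them, and the Jacobian is zero. The framework-free toy below illustrates only that mechanism; every class and function in it is hypothetical, and central finite differences stand in for autograd.

```python
# Framework-free toy of two ways a layer can locate its weights.
# Every name here is hypothetical; finite differences stand in for autograd.


class Variable:
    """Stands in for a framework variable object that owns its value."""

    def __init__(self, value):
        self.value = value


class LayerReadsVariable:
    """Forward reads weights from its own Variable objects (assumed Keras-like)."""

    def __init__(self, w, b):
        self.kernel = Variable(w)
        self.bias = Variable(b)
        # Registry that a functional_call-style helper swaps; forward ignores it.
        self.params = {"kernel": w, "bias": b}

    def forward(self, x):
        return self.kernel.value * x + self.bias.value


class LayerReadsParams:
    """Forward reads weights from the same registry that gets swapped (Linear-like)."""

    def __init__(self, w, b):
        self.params = {"kernel": w, "bias": b}

    def forward(self, x):
        return self.params["kernel"] * x + self.params["bias"]


def functional_call(layer, params, x):
    """Temporarily replace layer.params, run forward, then restore the original."""
    saved = layer.params
    layer.params = dict(params)
    try:
        return layer.forward(x)
    finally:
        layer.params = saved


def jacobian_fd(layer, params, x, eps=1e-6):
    """Central finite-difference derivative of the output w.r.t. each swapped param."""
    jac = {}
    for name in params:
        up, down = dict(params), dict(params)
        up[name] += eps
        down[name] -= eps
        diff = functional_call(layer, up, x) - functional_call(layer, down, x)
        jac[name] = diff / (2 * eps)
    return jac


params = {"kernel": 0.7, "bias": 0.1}
print(jacobian_fd(LayerReadsVariable(0.7, 0.1), params, 1.0))  # {'kernel': 0.0, 'bias': 0.0}
print(jacobian_fd(LayerReadsParams(0.7, 0.1), params, 1.0))  # both values close to 1.0
```

In the first layer the swapped values are accepted but never read, so perturbing them changes nothing; that is the same symptom as the all-zero Jacobian in the Keras output above.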

System Info

Python 3.9.5

keras                     3.5.0
torch                     2.4.0
torchaudio                2.4.0
torchmetrics              1.4.1
torchvision               0.19.0
VarunS1997 commented 3 days ago

Have you taken a look at `Model.stateless_call()`?

Theoretically, PyTorch should work directly, but we have not specifically tested `functional_call()`. Whether it works would depend on how the torch function is written.
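For context, Keras documents Model.stateless_call(trainable_variables, non_trainable_variables, *args) as taking the variable values explicitly and returning a tuple (outputs, non_trainable_variables), which is why the follow-up code below indexes the result with [0]. Here is a toy, framework-free sketch of that contract; toy_stateless_call and the call-counter state are hypothetical, and finite differences stand in for autograd. Because the weights are arguments, the output is a function of them and its derivative is well defined.

```python
# Toy mimic of the stateless_call contract: explicit weights and state in,
# outputs and *updated* state out. Names and the counter state are hypothetical.


def toy_stateless_call(trainable, non_trainable, x):
    """y = w * x + b; returns a new call counter instead of mutating one."""
    w, b = trainable
    (count,) = non_trainable
    return w * x + b, [count + 1]


def derivative(f, args, i, eps=1e-6):
    """Central finite difference of f with respect to args[i]."""
    up, down = list(args), list(args)
    up[i] += eps
    down[i] -= eps
    return (f(up) - f(down)) / (2 * eps)


trainable = [0.7, 0.1]
non_trainable = [0]
x = 2.0

outputs, new_state = toy_stateless_call(trainable, non_trainable, x)


# The output is a pure function of the weights passed in, so it can be differentiated.
def predict(params):
    return toy_stateless_call(params, non_trainable, x)[0]


dw = derivative(predict, trainable, 0)  # close to x = 2.0
db = derivative(predict, trainable, 1)  # close to 1.0
print(outputs, new_state, non_trainable, dw, db)
```

Note that the original non_trainable list is left untouched; the updated state comes back as a second return value, mirroring how stateless_call returns non-trainable variables rather than assigning them.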

smartArancina commented 2 days ago

Hi @VarunS1997, thanks for the answer. Yes, I tried what you suggested, and it seems to work on this dummy example:

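import os
os.environ["KERAS_BACKEND"] = "torch"  # select the PyTorch backend before importing keras

import keras
import torch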
model = keras.Sequential(
    [keras.layers.Dense(1,activation='linear')]
)
model.build(input_shape=(1,1))
x = torch.tensor([1.0])
params = [p.value for p in model.trainable_weights]
def f(params, x):
    # pytorch functional_call
    # prediction = torch.func.functional_call(
    #     model, params, x, strict=True,
    # )
    prediction = model.stateless_call(params, [], x)[0]
    print("prediction", prediction)
    print("prediction has grad", prediction.requires_grad)
    return prediction
#
J = torch.func.jacrev(f)(
    params,
    x,
)
J
prediction GradTrackingTensor(lvl=1, value=
    tensor([[-0.5455]], grad_fn=<AddBackward0>)
)
prediction has grad True
[tensor([[[[1.]]]]), tensor([[[1.]]])]

The problem is that I am using another PyTorch package that computes model derivatives with torch.func.functional_call, and I would rather not modify its source code to make it compatible with Keras. So I have two questions:

1) Is there a workaround that makes a Keras + PyTorch model fully transparent to pure PyTorch code (torch.func.functional_call)? For example, unwrapping the PyTorch model from Keras, running the computation with the external library, and then transparently re-wrapping it would be an acceptable workaround. I am asking because I am moving from Keras + TensorFlow to Keras + PyTorch (I need some external PyTorch libraries), and this transparency is a crucial aspect.

2) Just to understand better, out of curiosity: why does the Keras model called with torch.func.functional_call produce a prediction without gradient tracking (prediction has grad False)? Is this why it doesn't work?

Thanks!
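Regarding question 1, one possible shape for the "unwrap and re-wrap" idea is an adapter whose registered parameters are exactly the values it passes to a stateless forward function on every call, so a functional_call-style swap reaches the computation without changing the caller's code. In PyTorch terms this would presumably be a torch.nn.Module that registers the Keras weight tensors as its own parameters and calls stateless_call inside forward; that variant is not tested anywhere in this thread. The framework-free toy below shows only the pattern; StatelessAdapter, dense_stateless and functional_call here are hypothetical stand-ins, not Keras or PyTorch APIs.

```python
# Toy "re-wrap" adapter: the parameters callers can swap are exactly the
# values the forward pass reads. All names are hypothetical stand-ins.


def dense_stateless(trainable, x):
    """Stand-in for a stateless forward pass, e.g. a model's stateless_call."""
    w, b = trainable
    return w * x + b


class StatelessAdapter:
    def __init__(self, fn, names, values):
        self.fn = fn
        self.names = list(names)
        self.params = dict(zip(self.names, values))  # registry that callers swap

    def forward(self, x):
        # Re-read the (possibly swapped) registry on every call.
        return self.fn([self.params[n] for n in self.names], x)


def functional_call(module, params, x):
    """Mimics a functional_call-style helper: swap params, run, restore."""
    saved = module.params
    module.params = {**saved, **params}
    try:
        return module.forward(x)
    finally:
        module.params = saved


adapter = StatelessAdapter(dense_stateless, ["kernel", "bias"], [0.7, 0.1])
eps, x = 1e-6, 3.0
up = functional_call(adapter, {"kernel": 0.7 + eps}, x)
down = functional_call(adapter, {"kernel": 0.7 - eps}, x)
dk = (up - down) / (2 * eps)
print(dk)  # close to x = 3.0
```

The design choice is that the adapter owns no copy of the weights beyond its registry, so whatever the caller swaps in is what the stateless function sees, and the original registry is restored afterwards.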