brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

Strange behavior of bm.hessian() #661

Closed Dr-Chen-Xiaoyu closed 2 months ago

Dr-Chen-Xiaoyu commented 2 months ago

Hi, Chaoming,

I am trying to use bm.hessian() to compute the Hessian matrix of a loss with respect to a model's parameters, just as I use bm.grad() to compute the gradients.

import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
print('bp version:', bp.__version__)
bm.set_mode(bm.training_mode)
bm.random.seed(321)
# bp version: 2.4.6.post5

class RNN(bp.DynamicalSystem):
    def __init__(self, num_in, num_hidden):
        super(RNN, self).__init__()
        self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
        self.out = bp.dnn.Dense(num_hidden, 1)
    def update(self, x):
        return self.out(self.rnn(x))

# define the loss function
def lossfunc(inputs, targets):
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs)
    loss = bp.losses.mean_squared_error(predicts, targets)
    return loss

model = RNN(1, 2)
data_x = bm.random.rand(1, 1000, 1)
data_y = data_x + bm.random.randn(1, 1000, 1)
print(lossfunc(data_x,data_y))
#1.0623081
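For readers who want to reproduce the setup without BrainPy, here is a minimal pure-JAX sketch of an equivalent functional loss. It assumes a plain tanh (Elman) recurrence followed by a linear readout, drops the leading batch axis, and uses hypothetical parameter names (`W_ih`, `W_hh`, `b`, `h0`, `W_out`, `b_out`); it is not BrainPy's `RNNCell` implementation, and its loss value will differ from the one printed above.

```python
import jax
import jax.numpy as jnp

def init_params(key, num_in=1, num_hidden=2):
    # Hypothetical parameter layout: tanh RNN cell plus a dense readout.
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "W_ih": 0.5 * jax.random.normal(k1, (num_in, num_hidden)),
        "W_hh": 0.5 * jax.random.normal(k2, (num_hidden, num_hidden)),
        "b": jnp.zeros(num_hidden),
        "h0": jnp.zeros(num_hidden),  # trainable initial state (cf. train_state=True)
        "W_out": 0.5 * jax.random.normal(k3, (num_hidden, 1)),
        "b_out": jnp.zeros(1),
    }

def predict(params, xs):
    # xs has shape (time, num_in); unroll the recurrence with lax.scan.
    def step(h, x):
        h = jnp.tanh(x @ params["W_ih"] + h @ params["W_hh"] + params["b"])
        return h, h @ params["W_out"] + params["b_out"]
    _, ys = jax.lax.scan(step, params["h0"], xs)
    return ys  # shape (time, 1)

def loss_fn(params, xs, ys):
    # Mean squared error as a pure function of (params, data).
    return jnp.mean((predict(params, xs) - ys) ** 2)

key = jax.random.PRNGKey(321)
kp, kx, kn = jax.random.split(key, 3)
params = init_params(kp)
data_x = jax.random.uniform(kx, (1000, 1))
data_y = data_x + jax.random.normal(kn, (1000, 1))
print(loss_fn(params, data_x, data_y))
```

Because every parameter is an explicit argument, `jax.grad(loss_fn)` and `jax.hessian(loss_fn)` can be applied to this function directly.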

It works well with bm.grad:

lossgrad = bm.grad(lossfunc, grad_vars=model.train_vars(), return_value=True)
grad_vector=lossgrad(data_x,data_y)
print(grad_vector[0]['Dense0.W'])
#[[-0.03799846]
# [-0.38051015]]
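The `return_value=True` option corresponds roughly to `jax.value_and_grad` in plain JAX, sketched below on a hypothetical linear readout over fixed features. One difference to keep in mind: `jax.value_and_grad` returns `(value, grads)`, whereas the BrainPy call above yields the gradients first (`grad_vector[0]`). In both cases the gradients mirror the dict structure of the parameters, which is what makes indexing by a parameter name possible.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, h, y):
    # Hypothetical linear readout on fixed hidden features h of shape (N, 2).
    pred = h @ params["W"] + params["b"]
    return jnp.mean((pred - y) ** 2)

key = jax.random.PRNGKey(0)
kh, ky = jax.random.split(key)
h = jax.random.normal(kh, (1000, 2))
y = jax.random.normal(ky, (1000, 1))
params = {"W": jnp.zeros((2, 1)), "b": jnp.zeros(1)}

# value_and_grad returns (loss, grads); grads has the same dict structure as params.
loss, grads = jax.value_and_grad(loss_fn)(params, h, y)
print(loss, grads["W"].shape)  # grads["W"] has the shape of params["W"]
```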

I expected it to return a nested Hessian matrix, like the second example in https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html, but bm.hessian() behaves strangely:

losshess = bm.hessian(lossfunc, grad_vars=model.train_vars(), return_value=True)
hess_matrix=losshess(data_x,data_y)
print(hess_matrix[0]['Dense0.W'])
print(hess_matrix[1]['Dense0.W'])
#Dense0.W
#[[-0.03799846]
# [-0.38051015]]
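For comparison, this is the nested structure `jax.hessian()` returns when its argument is a dict, as in the second example of the JAX documentation linked above. The sketch uses a hypothetical linear readout; each block `H[k1][k2]` has shape `params[k1].shape + params[k2].shape`, so the `W`-`W` block of a `(2, 1)` weight is a `(2, 1, 2, 1)` array that reshapes to an ordinary 2x2 matrix.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, h, y):
    # Hypothetical linear readout on fixed hidden features h of shape (N, 2).
    pred = h @ params["W"] + params["b"]
    return jnp.mean((pred - y) ** 2)

key = jax.random.PRNGKey(0)
kh, ky = jax.random.split(key)
h = jax.random.normal(kh, (1000, 2))
y = jax.random.normal(ky, (1000, 1))
params = {"W": jnp.zeros((2, 1)), "b": jnp.zeros(1)}

# jax.hessian on a dict argument returns a dict of dicts:
# H[k1][k2] = d^2 loss / (d params[k1] d params[k2]).
H = jax.hessian(loss_fn)(params, h, y)
print(H["W"]["W"].shape)  # (2, 1, 2, 1)
print(H["W"]["b"].shape)  # (2, 1, 1)
# Flatten the W-W block into an ordinary 2x2 matrix.
print(H["W"]["W"].reshape(2, 2))
```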

By the way, I would really appreciate it if some examples could be added to the documentation at https://brainpy.readthedocs.io/en/latest/apis/generated/brainpy.math.hessian.html 😊

Best, Xiaoyu Chen, SJTU

Dr-Chen-Xiaoyu commented 2 months ago

Inspired by jax.hessian(), I tried modifying the loss function so that the parameters of interest are passed in as an input:

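def lossfunc_2(w, inputs, targets):
    # Overwrite the readout weights with w so that the loss is an explicit
    # function of w (note: this mutates the global model as a side effect)
    model.out.W = w
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs)
    loss = bp.losses.mean_squared_error(predicts, targets)
    return loss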

It seems to work:

losshess = bm.hessian(lossfunc_2, argnums=0)
hess_matrix = losshess(bm.zeros((2, 1)), data_x, data_y)
print(hess_matrix.squeeze())
#[[0.03443691 0.02228835]
# [0.02228835 0.2377306 ]]
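Why evaluating at `bm.zeros((2, 1))` is enough: the hidden states fed into the readout do not depend on the readout weights, so, assuming a mean-reduced squared error, the loss is quadratic in `W` and its Hessian is the constant matrix (2/N) * sum_t h_t h_t^T. That matrix is symmetric and positive semidefinite, consistent with the output above. The pure-JAX sketch below checks this with random stand-in hidden states (hypothetical data, not the RNN's actual states).

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
kh, ky = jax.random.split(key)
hidden = jnp.tanh(jax.random.normal(kh, (1000, 2)))  # stand-in for RNN hidden states
targets = jax.random.normal(ky, (1000, 1))

def loss_w(w, hidden, targets):
    # Loss as an explicit function of the readout weights w, shape (2, 1).
    return jnp.mean((hidden @ w - targets) ** 2)

hess = jax.hessian(loss_w, argnums=0)
H0 = hess(jnp.zeros((2, 1)), hidden, targets).squeeze()  # (2, 1, 2, 1) -> (2, 2)
H1 = hess(jnp.ones((2, 1)), hidden, targets).squeeze()

# The loss is quadratic in w, so the Hessian is (2/N) * hidden^T @ hidden at every w.
closed_form = 2.0 / hidden.shape[0] * hidden.T @ hidden
print(H0)
print(jnp.allclose(H0, closed_form, atol=1e-4), jnp.allclose(H0, H1, atol=1e-4))
```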

I guess BrainPy transforms runner.predict() into a pure function, which can then be differentiated in the original jax.hessian() style?
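As general JAX behavior (not a description of BrainPy's internals), transformations such as `jax.grad` and `jax.hessian` differentiate only with respect to explicit function arguments; arrays read from an enclosing scope are treated as constants. Threading the weights through as an argument, as `lossfunc_2` does, is therefore what exposes them to the Hessian. A minimal sketch with made-up values:

```python
import jax
import jax.numpy as jnp

W = jnp.array([[0.5], [-1.0]])  # "model state" living outside the function

def loss_closure(x):
    # W is captured from the enclosing scope, so JAX treats it as a constant.
    return jnp.sum((x @ W) ** 2)

def loss_pure(w, x):
    # The same computation with the weights threaded through as an argument.
    return jnp.sum((x @ w) ** 2)

x = jnp.array([[1.0, 2.0]])
# Differentiating the closure only yields derivatives with respect to x ...
print(jax.grad(loss_closure)(x).shape)     # (1, 2)
# ... while the pure version exposes w to jax.hessian.
print(jax.hessian(loss_pure)(W, x).shape)  # (2, 1, 2, 1)
```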

chaoming0625 commented 2 months ago

Thanks for the report. I have submitted a PR to fix the error. The new API now behaves the same as the functional jax.hessian(). Please try it after PR #662 has been merged.
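For reference, JAX defines `jax.hessian(f)` as forward-over-reverse differentiation, i.e. `jax.jacfwd(jax.jacrev(f))`. The toy check below illustrates that equivalence in plain JAX; it says nothing about how PR #662 implements the BrainPy API.

```python
import jax
import jax.numpy as jnp

def f(w):
    # A small non-quadratic scalar function of a vector argument.
    return jnp.sum(jnp.sin(w) * w ** 2)

w = jnp.array([0.3, -1.2, 2.0])
H_direct = jax.hessian(f)(w)
# Forward-mode Jacobian of the reverse-mode gradient.
H_composed = jax.jacfwd(jax.jacrev(f))(w)
print(jnp.allclose(H_direct, H_composed, atol=1e-5))
```

Since `f` is a sum of independent per-element terms, both results are diagonal, with entries 2 sin(w) + 4 w cos(w) - w^2 sin(w).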

Dr-Chen-Xiaoyu commented 2 months ago


Thanks~😊

chaoming0625 commented 2 months ago

PR #662 has been merged, so I am closing this issue. Feel free to reopen it at any time if you have further questions.