Closed Dr-Chen-Xiaoyu closed 2 months ago
Inspired by jax.hessian(), I tried modifying the loss function so that the parameters of interest become its inputs:
```python
def lossfunc_2(w, inputs, targets):
    model.out.W = w
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs)
    loss = bp.losses.mean_squared_error(predicts, targets)
    return loss
```
It seems to work:
```python
losshess = bm.hessian(lossfunc_2, argnums=0)
hess_matrix = losshess(bm.zeros((2, 1)), data_x, data_y)
print(hess_matrix.squeeze())
# [[0.03443691 0.02228835]
#  [0.02228835 0.2377306 ]]
```
I guess runner.predict() could be transformed by brainpy into a pure function, which could then be run in the original jax.hessian() style?
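As a sanity check, the same pattern can be reproduced with plain jax.hessian() when the prediction step is a pure function. This is only a minimal sketch: the linear readout `inputs @ w` is a hypothetical stand-in for `runner.predict(inputs)`, and the random data below is made up for illustration.

```python
import jax
import jax.numpy as jnp

# Pure-function analogue of lossfunc_2 above: the weight vector w is an
# explicit argument, so jax.hessian can differentiate with respect to it.
def lossfunc(w, inputs, targets):
    predicts = inputs @ w                      # stand-in for runner.predict(inputs)
    return jnp.mean((predicts - targets) ** 2)

# Toy data, shapes matching the (2, 1) weight example above.
data_x = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
data_y = jax.random.normal(jax.random.PRNGKey(1), (10, 1))

hess = jax.hessian(lossfunc, argnums=0)(jnp.zeros((2, 1)), data_x, data_y)
print(hess.squeeze().shape)  # (2, 2): one entry per pair of weight components
```

Since the loss is quadratic in `w`, the squeezed Hessian is a symmetric 2x2 matrix, just like the printed result above.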
Thanks for the report. I have submitted a PR fixing the error. The new API now produces the same behavior as the functional jax.hessian(). Please try it after PR #662 has been merged.
Thanks~😊
PR #662 has been merged. So I close this issue. Reopen any time if there are additional questions.
Hi, Chaoming,

I am trying to use bm.hessian() to compute the Hessian matrix of a model's parameters with respect to a loss, just like using bm.grad() for gradients. It works well with bm.grad. I expected it to return a nested Hessian matrix, like the 2nd example in https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html, but I get strange behavior with bm.hessian().

By the way, I would appreciate it a lot if some examples could be added to the documentation at https://brainpy.readthedocs.io/en/latest/apis/generated/brainpy.math.hessian.html 😊
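For reference, the nested (pytree) output structure from that jax docs example can be reproduced with plain JAX. This is a hedged sketch: the loss function, parameter dict, and shapes below are hypothetical, chosen only to show the nested structure of the result.

```python
import jax
import jax.numpy as jnp

# A toy loss over a dict of parameters; jax.hessian returns a nested
# dict with one block per pair of leaves: H["W"]["W"], H["W"]["b"], etc.
def loss(params, x):
    y = x @ params["W"] + params["b"]
    return jnp.sum(y ** 2)

params = {"W": jnp.ones((3, 2)), "b": jnp.zeros(2)}
x = jnp.ones((4, 3))

H = jax.hessian(loss)(params, x)
print(jax.tree_util.tree_structure(H))
print(H["W"]["W"].shape)  # (3, 2, 3, 2): d^2 loss / dW dW
print(H["b"]["b"].shape)  # (2, 2):       d^2 loss / db db
```

Each block's shape is the concatenation of the two leaf shapes it differentiates against, which is the nested behavior I expected from bm.hessian() as well.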
Best, Xiaoyu Chen, SJTU