JustusvLiebig closed this issue 1 year ago.
We track the HEAD of kfac-jax. The pmap_axis_name argument to TwoKroneckerFactored.update_curvature_matrix_estimate was removed a few months ago (https://github.com/deepmind/kfac-jax/commit/4bf9bec1618035cb58586ea1b00d75a6172da45c).
If you are using an older version of kfac-jax, you can either update kfac-jax or roll back the ferminet commit that switched ferminet to the new API (https://github.com/deepmind/ferminet/commit/e29145a6a18ffa74f4d012b718526bada3ffcd18; this looks like the code in your comment).
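If you are unsure which API your installed kfac-jax exposes, one quick check is to inspect the method signature directly. The following is a minimal sketch, not part of ferminet or kfac-jax. It assumes the class is reachable as kfac_jax.TwoKroneckerFactored, and it guards that lookup because names can change between versions. The helper has_parameter and the functions old_update and new_update are hypothetical, simplified stand-ins, not the real kfac-jax signatures.

```python
import inspect
from importlib import metadata
from typing import Any, Callable


def has_parameter(fn: Callable[..., Any], name: str) -> bool:
    """Returns True if `name` appears explicitly in the signature of `fn`."""
    return name in inspect.signature(fn).parameters


# Hypothetical, heavily simplified stand-ins for the method signature before
# and after the kfac-jax change; the real method takes more arguments.
def old_update(state, pmap_axis_name=None):
    return state


def new_update(state):
    return state


# Report the installed kfac-jax version (the PyPI distribution is "kfac-jax").
try:
    print("kfac-jax version:", metadata.version("kfac-jax"))
except metadata.PackageNotFoundError:
    print("kfac-jax is not installed")

# If kfac_jax is importable, inspect the real method. The attribute lookups
# are guarded because these names may differ between kfac-jax versions.
try:
    import kfac_jax
except ImportError:
    kfac_jax = None

block_cls = getattr(kfac_jax, "TwoKroneckerFactored", None)
method = getattr(block_cls, "update_curvature_matrix_estimate", None)
if method is not None:
    print("accepts pmap_axis_name:", has_parameter(method, "pmap_axis_name"))
```

If the check reports True, the installed kfac-jax most likely predates the commit linked above. In that case, ferminet at HEAD will not match it until kfac-jax is updated.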
Thank you for your answer. My mistake: I had installed kfac-jax via PyPI, whose release dates from May 17, while commit 4bf9bec was made on Jun 26. Thank you again for resolving my question.
I think the current version of update_curvature_matrix_estimate has a problem: it is missing the input argument pmap_axis_name. So I think the function should be defined as follows, with pmap_axis_name added.