google-deepmind / ferminet

An implementation of the Fermionic Neural Network for ab-initio electronic structure calculations
Apache License 2.0

Something went wrong in RepeatedDenseBlock.update_curvature_matrix_estimate #67

Closed: JustusvLiebig closed this issue 1 year ago

JustusvLiebig commented 1 year ago

I think the current version of update_curvature_matrix_estimate has a problem because it ignores the pmap_axis_name argument, so I think the correct function should be as follows:

  def update_curvature_matrix_estimate(
      self,
      state: kfac_jax.TwoKroneckerFactored.State,
      estimation_data: Mapping[str, Sequence[Array]],
      ema_old: Numeric,
      ema_new: Numeric,
      batch_size: int,
      pmap_axis_name: Optional[str],
      sync: Array | bool = True,
  ) -> kfac_jax.TwoKroneckerFactored.State:
    estimation_data = dict(**estimation_data)
    x, = estimation_data["inputs"]
    dy, = estimation_data["outputs_tangent"]
    assert x.shape[0] == batch_size
    estimation_data["inputs"] = (x.reshape([-1, x.shape[-1]]),)
    estimation_data["outputs_tangent"] = (dy.reshape([-1, dy.shape[-1]]),)
    batch_size = x.size // x.shape[-1]
    return super().update_curvature_matrix_estimate(
        state=state,
        estimation_data=estimation_data,
        ema_old=ema_old,
        ema_new=ema_new,
        batch_size=batch_size,
        pmap_axis_name=pmap_axis_name,
    )

where I added pmap_axis_name.
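The reshapes in the method above fold every leading axis of the layer inputs and output tangents into a single batch axis, so the batch_size handed to the parent class is the number of rows after flattening. Below is a minimal sketch of that bookkeeping. It uses NumPy as a stand-in for jax.numpy, since both accept `reshape` with `-1`. The helper name `flatten_repeated_axes` and the array shapes are hypothetical. The input factor shown is the textbook K-FAC estimate and may not match kfac-jax's internal normalization.

```python
import numpy as np


def flatten_repeated_axes(x, dy, batch_size):
    """Fold all axes except the last into one effective batch axis.

    Hypothetical helper mirroring the reshapes in
    RepeatedDenseBlock.update_curvature_matrix_estimate: a dense layer applied
    to n_repeat inputs per sample contributes batch_size * n_repeat rows.
    """
    assert x.shape[0] == batch_size
    x_flat = x.reshape([-1, x.shape[-1]])
    dy_flat = dy.reshape([-1, dy.shape[-1]])
    effective_batch_size = x.size // x.shape[-1]
    return x_flat, dy_flat, effective_batch_size


# Hypothetical shapes: 4 samples, 3 repeats (e.g. electrons), 8 inputs, 5 outputs.
x = np.ones((4, 3, 8))
dy = np.ones((4, 3, 5))
x_flat, dy_flat, n = flatten_repeated_axes(x, dy, batch_size=4)

# Textbook K-FAC input factor E[x x^T] over the flattened rows
# (a sketch; kfac-jax's own normalization may differ).
a_factor = x_flat.T @ x_flat / n
print(x_flat.shape, dy_flat.shape, n)  # (12, 8) (12, 5) 12
```

Treating each repeat as an extra sample is what lets a single pair of Kronecker factors cover a layer that is applied many times per sample.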

jsspencer commented 1 year ago

We track the HEAD of kfac-jax. The pmap_axis_name argument to TwoKroneckerFactored.update_curvature_matrix_estimate was removed a few months ago (https://github.com/deepmind/kfac-jax/commit/4bf9bec1618035cb58586ea1b00d75a6172da45c).

If you are using an older version of kfac-jax, you can either update kfac-jax or roll back the ferminet commit that updated the kfac-jax API ferminet uses (https://github.com/deepmind/ferminet/commit/e29145a6a18ffa74f4d012b718526bada3ffcd18; the pre-commit code looks like the code in your comment).
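If the same code must run against both kfac-jax APIs (before and after the commit that removed pmap_axis_name), one option is to inspect the parent method's signature and forward pmap_axis_name only when that method still declares it. This is a minimal sketch under that assumption. `call_with_supported_kwargs`, `old_update` and `new_update` are hypothetical stand-ins, not kfac-jax functions.

```python
import inspect
from typing import Any, Callable


def call_with_supported_kwargs(fn: Callable[..., Any], **kwargs: Any) -> Any:
    """Call fn with only the keyword arguments its signature accepts.

    Hypothetical helper: forwards pmap_axis_name only if the installed
    parent-class method still declares it.
    """
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn(**kwargs)  # fn takes **kwargs, so pass everything through.
    supported = {k: v for k, v in kwargs.items() if k in params}
    return fn(**supported)


# Stand-ins for the two API versions (NOT the real kfac-jax methods).
def old_update(state, batch_size, pmap_axis_name):
    return ("old", pmap_axis_name)


def new_update(state, batch_size):
    return ("new", batch_size)


old_result = call_with_supported_kwargs(
    old_update, state=None, batch_size=8, pmap_axis_name="devices")
new_result = call_with_supported_kwargs(
    new_update, state=None, batch_size=8, pmap_axis_name="devices")
print(old_result, new_result)  # ('old', 'devices') ('new', 8)
```

Inside RepeatedDenseBlock, the final call would pass `super().update_curvature_matrix_estimate` as `fn`. The subclass parameter would presumably also need a default (`pmap_axis_name: Optional[str] = None`), because newer kfac-jax no longer passes it. Keeping the kfac-jax and ferminet versions in step is simpler and avoids silently dropping arguments.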

JustusvLiebig commented 1 year ago

Thank you for your answer. Sorry, I had installed kfac-jax from PyPI, whose latest release was published on May 17, while commit 4bf9bec was made on June 26. Thank you again for resolving my question.
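A PyPI release only contains commits made before it was published, so a release from May 17 still exposes the API that predates the June 26 commit. Tracking HEAD, as ferminet does, means installing kfac-jax from its GitHub repository, for example with `pip install git+https://github.com/deepmind/kfac-jax`, rather than from PyPI. To check which distribution is installed, the standard-library `importlib.metadata` can be used, as in the minimal sketch below. The helper name `installed_version` is hypothetical.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str = "kfac-jax") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


version = installed_version()
if version is None:
    print("kfac-jax is not installed")
else:
    # Compare this against the release that includes the API change you need.
    print(f"kfac-jax {version} is installed")
```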