jeffminlin / vmcnet

Flexible, general-purpose VMC framework, built on JAX.
https://jeffminlin.github.io/vmcnet/
MIT License

Fix KFAC when running vmc-molecule with single ion #115

Closed: ggoldsh closed this 1 year ago

ggoldsh commented 1 year ago

I realized my previous PR #114 introduced a bug where KFAC would throw an error when running vmc-molecule on a system with only one ion.

After debugging, it seems that kfac_jax has a bug: it ends up dividing by zero and throwing a floating-point error when a layer of the network involves only 1x1 kernels, so that every block of the block-diagonal part of KFAC for that layer is simply a scalar. This is exactly what was happening in our exponential envelope code.
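For context, here is the general K-FAC approximation of a dense layer's Fisher block (Martens and Grosse's factorization, stated up to the vectorization convention). It shows why a 1x1 kernel collapses each block to a scalar. This is a sketch of the standard math, not kfac_jax's internal code, and it does not by itself pinpoint where the division by zero happens.

```latex
% Dense layer s = W a with W \in \mathbb{R}^{d_{out} \times d_{in}},
% input activations a and output gradients g = \nabla_s \log p.
F_W \approx A \otimes G, \qquad
A = \mathbb{E}\!\left[a a^{\top}\right] \in \mathbb{R}^{d_{in} \times d_{in}}, \qquad
G = \mathbb{E}\!\left[g g^{\top}\right] \in \mathbb{R}^{d_{out} \times d_{out}}

% With a 1x1 kernel (d_{in} = d_{out} = 1) both factors are scalars:
A = \mathbb{E}\!\left[a^{2}\right], \qquad
G = \mathbb{E}\!\left[g^{2}\right], \qquad
F_W \approx A\,G \in \mathbb{R}
```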

I didn't see an immediate way to fix this through the KFAC tagging, partly because I don't deeply understand how the kfac_jax library works. However, that particular layer doesn't need to be written that way: it can just as easily be written as an element-wise multiplication followed by a sum over the final axis. So I fixed the bug by adding an ElementWiseMultiply Module and using that approach instead.
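The rewrite described above can be sketched in NumPy. This sketch assumes that the envelope contracts per-ion terms of shape `(..., norbitals, nion)` against per-orbital weights of shape `(norbitals, nion)`. These shapes and both function names are hypothetical and are not vmcnet's actual ElementWiseMultiply code. The two functions compute the same values, including the single-ion case `nion == 1`, where the per-orbital contraction degenerates to a 1x1 kernel. `jax.numpy` provides the same `einsum` and `sum` calls.

```python
import numpy as np


# Hypothetical shapes (assumptions, not taken from vmcnet):
#   x: (..., norbitals, nion)  per-ion envelope terms
#   w: (norbitals, nion)       per-orbital, per-ion weights
def envelope_contract_einsum(x, w):
    # Contraction over the ion axis. For each orbital k this acts like a
    # dense layer with an (nion, 1) kernel, i.e. a 1x1 kernel when nion == 1.
    return np.einsum("...ki,ki->...k", x, w)


def envelope_contract_elementwise(x, w):
    # Same result as an element-wise multiply (w broadcasts over the leading
    # axes of x) followed by a sum over the final (ion) axis.
    return np.sum(x * w, axis=-1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    for nion in (1, 3):
        x = rng.normal(size=(5, 4, nion))
        w = rng.normal(size=(4, nion))
        a = envelope_contract_einsum(x, w)
        b = envelope_contract_elementwise(x, w)
        # Both forms agree in shape and value, single-ion case included.
        assert a.shape == b.shape == (5, 4)
        assert np.allclose(a, b)
```

Broadcasting `w` against `x` does the per-orbital pairing that the einsum subscripts spell out explicitly, so no matrix-product primitive appears in the element-wise form.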

PTAL @nilin or @JiahaoYao

JiahaoYao commented 1 year ago

Hi @ggoldsh, thanks for the fast fix on this; it looks good to me.

ggoldsh commented 1 year ago

Thanks @JiahaoYao for the review!