I realized my previous PR #114 introduced a bug where KFAC would throw an error when running vmc-molecule on a system with only one ion.
While debugging, I found what looks like a bug in kfac_jax: it ends up dividing by zero and throwing a floating point error when a layer of the network involves only 1x1 kernels, so that every block of the block-diagonal part of KFAC for that layer is simply a scalar. This is exactly what was happening in our exponential envelope code.
I didn't see an immediate way to fix this through the KFAC tagging, partly because I don't deeply understand how the kfac_jax library works. However, that particular layer doesn't need to be written that way: it can just as easily be written as an element-wise multiplication followed by a sum over the final axis. I fixed the bug by adding an ElementWiseMultiply module and using this approach instead.
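For reviewers, here is a minimal sketch of why the two formulations are interchangeable. It is not the code in this PR: the function names, shapes, and the use of plain `jax.numpy` instead of our modules are assumptions made for illustration, and it does not attempt to reproduce the KFAC tagging or the floating point error itself.

```python
import jax
import jax.numpy as jnp


def scalar_kernel_layer(x, kernels):
    """Hypothetical stand-in for the old formulation.

    Each of the n input features goes through its own 1x1 kernel,
    and the results are summed over the final axis.
    x: (..., n), kernels: (n, 1, 1)
    """
    # Give every feature a trailing length-1 axis, apply its 1x1 kernel,
    # then drop that axis again before summing.
    y = jnp.einsum("...ni,nio->...no", x[..., None], kernels)
    return jnp.sum(y[..., 0], axis=-1)


def elementwise_multiply_sum(x, weights):
    """Hypothetical stand-in for the new formulation.

    Element-wise multiplication followed by a sum over the final axis.
    x: (..., n), weights: (n,)
    """
    return jnp.sum(x * weights, axis=-1)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, 3))  # batch of 4, 3 features
w = jax.random.normal(key_w, (3,))    # one scalar weight per feature

out_old = scalar_kernel_layer(x, w[:, None, None])
out_new = elementwise_multiply_sum(x, w)
```

Both functions compute the same values on the same parameters, so the change should not affect the model output; the difference is only that the second form avoids expressing the layer through 1x1 kernels.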
PTAL @nilin or @JiahaoYao