jeffminlin / vmcnet

Flexible, general-purpose VMC framework, built on JAX.
https://jeffminlin.github.io/vmcnet/
MIT License

Refix KFAC registration by removing indexing into kernel #126

Closed: ggoldsh closed this 1 year ago

ggoldsh commented 1 year ago

In the previous PR, #125, I attempted to fix the envelopes so that KFAC would register them properly, but just before merging I made a tweak that broke the registration again.

For KFAC to register the operation as `scale_and_shift`, it seems that the kernel must be multiplied directly onto the inputs, without any indexing. This is fine for our application, since JAX broadcasting handles the edge case where the kernel becomes larger than the inputs after padding.
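KFAC implementations such as KFAC-JAX detect layers by matching patterns in the traced computation graph. Assuming that is what drives the registration here, the sketch below shows the structural difference between the two forms. It does not call KFAC itself. It traces a direct `x * kernel + bias` and a hypothetical indexed variant with `jax.make_jaxpr` and compares the resulting graphs. The function names, shapes, and padding are illustrative assumptions, not vmcnet's actual envelope code.

```python
import jax
import jax.numpy as jnp


def scale_and_shift_direct(kernel, bias, x):
    # Hypothetical direct form: kernel multiplied straight onto the inputs,
    # with no indexing. JAX broadcasts the (features,) parameters over the
    # leading batch axis.
    return x * kernel + bias


def scale_and_shift_indexed(kernel, bias, x):
    # Hypothetical indexed form: the padded parameters are sliced down to the
    # input width before the multiply, which inserts an extra slice/gather op
    # between each parameter and the product.
    n = x.shape[-1]
    return x * kernel[:n] + bias[:n]


x = jnp.ones((3, 4))                                   # 3 samples, 4 features
kernel_padded = jnp.array([1.0, 2.0, 3.0, 4.0, 0.0])   # padded to length 5
bias_padded = jnp.full((5,), 0.5)
kernel, bias = kernel_padded[:4], bias_padded[:4]      # unpadded parameters

direct_jaxpr = str(jax.make_jaxpr(scale_and_shift_direct)(kernel, bias, x))
indexed_jaxpr = str(
    jax.make_jaxpr(scale_and_shift_indexed)(kernel_padded, bias_padded, x)
)

# The direct form traces to broadcast/mul/add only; the indexed form also
# contains a slice (or gather) applied to the parameters.
print(direct_jaxpr)
print(indexed_jaxpr)

# Both forms compute the same values when the unpadded entries agree.
same = jnp.allclose(
    scale_and_shift_direct(kernel, bias, x),
    scale_and_shift_indexed(kernel_padded, bias_padded, x),
)
print(bool(same))  # True
```

The numerics are identical; only the graph differs. A matcher that expects the parameter to feed the multiply directly would see the extra slice in the indexed form as a different pattern.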

PTAL @nilin