In PR #125 I tried to fix the envelopes so that KFAC would register them properly, but a last-minute tweak just before merging broke them again.
For KFAC to register the envelope as scale_and_shift, it seems the kernel must be multiplied directly onto the inputs, without any indexing. This is fine for our application, since JAX broadcasting handles the edge case where the kernel becomes larger than the inputs after padding.
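For reviewers, here is a rough sketch of the broadcasting behavior being relied on. It is not the code from this PR: the function name `scale_direct` and the shapes are hypothetical, and it assumes "larger after padding" means the kernel carries extra leading axes of size 1. It only illustrates that a direct multiply still works in that case; it does not exercise KFAC's registration.

```python
import jax.numpy as jnp


def scale_direct(kernel, inputs):
    # Hypothetical helper: multiply the kernel straight onto the inputs,
    # with no indexing or slicing of the kernel beforehand.
    return kernel * inputs


# Hypothetical shapes: 4 inputs with 3 features each.
inputs = jnp.ones((4, 3))
# Assumed edge case: the kernel has an extra leading axis of size 1.
kernel = jnp.full((1, 1, 3), 2.0)

out = scale_direct(kernel, inputs)
print(out.shape)  # (1, 4, 3): broadcasting aligns trailing axes
```

Note that in this assumed case the output takes on the kernel's extra leading axis, so downstream code would need to tolerate that shape.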
PTAL @nilin