Open jglaser opened 2 years ago
Thanks for the suggestion! Please check out the newly added https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.repeat.html#neural_tangents.stax.repeat
One caveat that makes this less elegant than we'd like is that kernel_fn sometimes makes non-jittable changes to the metadata of the Kernel object, and when this happens, lax.scan fails (see especially the second warning in the docs), so unfortunately for now it's less flexible than stax.serial.
awesome, thanks!
The LLVM compiler pass uses excessive amounts of memory for deep networks constructed like this; in fact, compilation may eventually OOM.
The reason is that the serial combinator internally relies on a Python for loop (with carry) to support mixed input sequences. It would be nice to have a specialization for the case in which the same layer is repeated n times, which could then use jax.lax.scan() to save compilation time by avoiding loop unrolling. Suggestion:
Use it like this:
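The original snippet is not shown in this export. The core idea can be sketched in plain JAX (function names here are hypothetical, not from the issue): a repeated layer unrolled in a Python loop makes the XLA graph grow with depth, while lax.scan keeps it constant-size:

```python
import jax
import jax.numpy as jnp

def layer(x):
    # Hypothetical tied-weight layer, purely for illustration.
    return jnp.tanh(x)

def repeated_unrolled(x, n):
    # Python for loop: XLA traces n copies of `layer`,
    # so compile time and memory grow with n.
    for _ in range(n):
        x = layer(x)
    return x

def repeated_scan(x, n):
    # lax.scan traces `layer` once and loops at runtime,
    # keeping the compiled graph size independent of n.
    def body(carry, _):
        return layer(carry), None
    out, _ = jax.lax.scan(body, x, None, length=n)
    return out
```

Both versions compute the same result; only the compiled program size differs.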