jrmaddison opened 1 day ago
The simplest way to do this might be to register Keras `Layer`s as pytree nodes, repurposing Keras serialization for the flatten and unflatten, and then swap the nesting of the `vmap` and `fori_loop` used to evaluate the `Dynamics` layer so that the embedded network operates on the whole batch. It might alternatively be possible to have `Dynamics` inherit from `keras.layers.RNN`.
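A minimal JAX sketch of what swapping the nesting might look like, assuming `Dynamics` advances its state with a fixed number of explicit Euler steps. The names `network`, `step`, `evolve_per_sample` and `evolve_whole_batch` are hypothetical and not taken from the actual `Dynamics` implementation. In the first version the network only ever sees one sample; in the second, each loop iteration passes the full `(batch, n)` array through it.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the embedded network: one dense layer with tanh.
# Works on a single state (n,) or a batch (batch, n), since x @ w broadcasts.
def network(params, x):
    w, b = params
    return jnp.tanh(x @ w + b)

def step(params, dt, x):
    # One explicit Euler step of dx/dt = network(x) (assumed integrator).
    return x + dt * network(params, x)

def evolve_per_sample(params, x0, dt, n_steps):
    # vmap outside, fori_loop inside: the network sees one sample at a time.
    def single(x):
        return jax.lax.fori_loop(0, n_steps, lambda i, x: step(params, dt, x), x)
    return jax.vmap(single)(x0)

def evolve_whole_batch(params, x0, dt, n_steps):
    # Swapped: fori_loop outside, each step applies the network to the batch.
    return jax.lax.fori_loop(0, n_steps, lambda i, x: step(params, dt, x), x0)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
n = 3
params = (0.1 * jax.random.normal(k1, (n, n)), jnp.zeros(n))
x0 = jax.random.normal(k2, (4, n))  # batch of 4 initial states

a = evolve_per_sample(params, x0, 0.1, 10)
b = evolve_whole_batch(params, x0, 0.1, 10)
```

For a network with no cross-sample interaction the two results agree up to floating-point error; the difference only matters for layers that mix information across the batch.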
Keras cannot deserialize a layer when `Traced` variables appear, which I think would be needed for the pytree node approach, since under JAX transformations the unflatten step is called with tracers as the leaves.
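A minimal sketch of where the tracers come from, using `jax.tree_util.register_pytree_node` on a hypothetical `ToyDense` class rather than a real Keras `Layer`, so it does not reproduce the Keras error itself. `ToyDense`, `flatten`, `unflatten` and `apply` are illustrative names only. The sketch records the type of the leaves each time `unflatten` runs.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

class ToyDense:
    """Hypothetical stand-in for a Keras Layer (not Keras code)."""

    def __init__(self, units, w, b):
        self.units = units  # static configuration
        self.w = w          # weights: these become pytree leaves
        self.b = b

    def __call__(self, x):
        return x @ self.w + self.b

def flatten(layer):
    # (children, aux_data): weights as leaves, config as hashable static data.
    return (layer.w, layer.b), layer.units

seen_leaf_types = []  # leaf types observed by unflatten, for illustration

def unflatten(units, weights):
    # In the approach above, Keras deserialization would run here; this
    # sketch just rebuilds the object directly.
    w, b = weights
    seen_leaf_types.append(type(w))
    return ToyDense(units, w, b)

register_pytree_node(ToyDense, flatten, unflatten)

@jax.jit
def apply(layer, x):
    # jit flattens `layer`, swaps its leaves for tracers, and calls
    # unflatten on those tracers before running this body.
    return layer(x)

layer = ToyDense(2, jnp.ones((3, 2)), jnp.zeros(2))
y = apply(layer, jnp.ones((1, 3)))
print(y)  # [[3. 3.]]
```

After the call, `seen_leaf_types` contains a JAX tracer class, which is what a Keras-based `unflatten` would be handed as layer weights.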
This is required, e.g., for batch normalization, which normalizes using statistics computed over the whole batch.
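A small JAX sketch of why per-sample evaluation breaks batch normalization. `batch_norm` here is a hypothetical training-mode version with no learned scale/shift and no running averages, not Keras's `BatchNormalization`. Under a per-sample `vmap`, each call sees a batch of one, the variance is zero, and every output collapses to zero.

```python
import jax
import jax.numpy as jnp

def batch_norm(x, eps=1e-5):
    # Normalize each feature with the mean and variance over the batch axis.
    mean = jnp.mean(x, axis=0)
    var = jnp.var(x, axis=0)
    return (x - mean) / jnp.sqrt(var + eps)

x = jnp.array([[1.0, 2.0],
               [3.0, 6.0],
               [5.0, 10.0]])

# Whole batch: statistics are taken over all three samples.
whole_batch = batch_norm(x)

# Per sample: each call normalizes a batch of one, so x - mean is zero.
per_sample = jax.vmap(lambda xi: batch_norm(xi[None, :])[0])(x)
```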