jrmaddison / bt_ocean

Differentiable, GPU-capable barotropic vorticity solver using JAX, for rapid testing of online learning algorithms
https://jrmaddison.github.io/bt_ocean/
MIT License

Allow embedded neural network evaluation to have side effects #37

Status: Open. jrmaddison opened this issue 1 day ago

jrmaddison commented 1 day ago

This is required, e.g., for batch normalization, which updates its moving mean and variance during training.
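As an illustration of the kind of side effect meant here, the following is a minimal pure-JAX sketch of batch normalization. It is not the Keras implementation and not code from bt_ocean; the function batch_norm and its state dictionary are hypothetical, and only the default momentum and epsilon are taken from the Keras BatchNormalization defaults. The function returns its updated moving statistics alongside its output, and because those statistics are computed over the leading batch axis, it has to see the whole batch at once.

```python
import jax.numpy as jnp


def batch_norm(state, x, training, momentum=0.99, epsilon=1e-3):
    """Toy batch normalization over the leading (batch) axis of x.

    Hypothetical helper. Returns (outputs, new_state). In training mode the
    moving statistics are updated from the current batch: this state update
    is the side effect that has to be propagated out of the caller.
    """
    if training:
        # Statistics over the batch axis: the whole batch must be visible here.
        mean = jnp.mean(x, axis=0)
        var = jnp.var(x, axis=0)
        new_state = {
            "moving_mean": momentum * state["moving_mean"] + (1.0 - momentum) * mean,
            "moving_variance": momentum * state["moving_variance"]
            + (1.0 - momentum) * var,
        }
    else:
        # Inference uses the stored statistics and leaves the state unchanged.
        mean, var = state["moving_mean"], state["moving_variance"]
        new_state = state
    return (x - mean) / jnp.sqrt(var + epsilon), new_state


# Usage: the caller must thread the returned state into the next call.
state = {"moving_mean": jnp.zeros(3), "moving_variance": jnp.ones(3)}
x = jnp.arange(12.0).reshape(4, 3)
y, state = batch_norm(state, x, training=True)
```

Writing the layer as a pure function that returns its new state is the usual way to express such a side effect under JAX transformations; the difficulty discussed in this issue is getting that state in and out of the embedded network inside the solver.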

jrmaddison commented 1 day ago

The simplest way to do this might be to register Keras Layers as pytree nodes, repurposing Keras serialization for the flatten and unflatten functions, and then to swap the order of the vmap and fori_loop used to evaluate the Dynamics layer, so that the embedded network operates on the whole batch rather than on one sample at a time. Alternatively, it might be possible to have Dynamics inherit from keras.layers.RNN.
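A minimal sketch of the swapped loop order follows, assuming a toy forward Euler step and a hypothetical batch-dependent network. The names network, solver_step and run_batched are illustrative only and are not the bt_ocean Dynamics layer. With vmap outside and fori_loop inside, each sample would be integrated separately and the network would only ever see one sample. With fori_loop outside, each step advances the whole batch, so the network sees every sample and its state can be carried through the loop alongside the solution.

```python
import jax
import jax.numpy as jnp


def network(net_state, q, momentum=0.9):
    """Hypothetical embedded network acting on a batch q of shape (batch, n).

    Relaxes each sample towards the batch mean and updates a running mean of
    the batch mean: a stand-in for batch-normalization-style state.
    """
    batch_mean = jnp.mean(q, axis=0)  # needs the whole batch
    new_state = momentum * net_state + (1.0 - momentum) * batch_mean
    return -(q - batch_mean), new_state


def solver_step(q, net_state, dt):
    """Hypothetical forward Euler step: linear damping plus the network tendency."""
    tendency, net_state = network(net_state, q)
    return q + dt * (-0.1 * q + tendency), net_state


def run_batched(q0, net_state, dt, n_steps):
    """fori_loop outside, batch inside: every step advances all samples at once,
    and the network state is carried through the loop with the solution."""

    def body(i, carry):
        q, s = carry
        return solver_step(q, s, dt)

    return jax.lax.fori_loop(0, n_steps, body, (q0, net_state))


q, net_state = run_batched(jnp.ones((4, 3)), jnp.zeros(3), dt=0.1, n_steps=10)
```

In this usage every sample equals the batch mean, so the network tendency vanishes and each entry of q decays by a factor of 0.99 per step, while net_state is still updated at every step.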

jrmaddison commented 1 day ago

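Keras cannot deserialize a layer when traced variables appear, and I think these would be unavoidable in the pytree node approach, since JAX unflattens pytrees with traced leaves inside transformations such as jit, vmap and fori_loop.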
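A toy illustration of the problem follows. It deliberately does not use Keras: ToyLayer is a hypothetical stand-in registered with jax.tree_util.register_pytree_node, and its unflatten function calls np.asarray as a stand-in for a deserializer that needs concrete values. Whether Keras fails for exactly this reason is an assumption. Eager tree operations pass concrete arrays to unflatten and succeed, but under jax.jit unflatten receives tracers and fails.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyLayer:
    """Hypothetical stand-in for a layer registered as a pytree node."""

    def __init__(self, weights):
        self.weights = weights


def flatten_layer(layer):
    # Leaves are the weights; there is no static auxiliary data.
    return (layer.weights,), None


def unflatten_layer(aux_data, children):
    (weights,) = children
    # Stand-in for a deserializer that needs concrete values (an assumption):
    # np.asarray works on arrays but raises on JAX tracers.
    return ToyLayer(np.asarray(weights))


jax.tree_util.register_pytree_node(ToyLayer, flatten_layer, unflatten_layer)


def total(layer):
    return jnp.sum(layer.weights)


layer = ToyLayer(np.ones(3, dtype=np.float32))

# Eager tree operations pass concrete arrays to unflatten_layer, so this works.
doubled = jax.tree_util.tree_map(lambda w: 2.0 * w, layer)

# Under jit, unflatten_layer is called with tracers, and np.asarray fails.
try:
    jax.jit(total)(layer)
    traced_ok = True
except jax.errors.TracerArrayConversionError:
    traced_ok = False
```

The JAX pytree documentation warns for this reason that unflatten functions should not assume their children are concrete arrays.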