Open daniel-j-h opened 2 years ago
Flax checkpoint serialization currently depends on TensorFlow for I/O, see https://github.com/google/flax/issues/1924.
This is an upstream issue; a workaround could be to use their second format (sigh), which is based on msgpack.
Raising the FrozenDict issue upstream: https://github.com/google/flax/issues/2005
At the moment we don't save the weights; we simply write out predictions.
We should write out the best weights (based on loss) and also allow users to load them back in.
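
A rough sketch of what "keep the best weights by loss, and load them back" could look like. This is only an illustration, not what the project does today: `BestCheckpoint` and `load_weights` are hypothetical names, and the default `pickle` serializer is a stand-in so the snippet runs with the standard library alone. For the msgpack workaround mentioned above, `flax.serialization.to_bytes` could be passed as `serialize`, and something like `lambda b: flax.serialization.from_bytes(template_params, b)` as `deserialize` (where `template_params` is an assumed, already-initialized parameter tree).

```python
import os
import pickle
from pathlib import Path
from typing import Any, Callable


class BestCheckpoint:
    """Hypothetical helper: keeps only the weights with the lowest loss seen so far."""

    def __init__(self, path: str, serialize: Callable[[Any], bytes] = pickle.dumps):
        self.path = Path(path)
        self.serialize = serialize  # assumption: swap in a msgpack-based serializer here
        self.best_loss = float("inf")

    def update(self, loss: float, params: Any) -> bool:
        """Write `params` to disk if `loss` improves on the best so far.

        Returns True if a checkpoint was written. NaN losses never count as
        an improvement because `nan < x` is always False.
        """
        if not loss < self.best_loss:
            return False
        # Write to a temporary file first, then replace the target in one step,
        # so an interrupted run does not leave a truncated checkpoint behind.
        tmp = Path(str(self.path) + ".tmp")
        tmp.write_bytes(self.serialize(params))
        os.replace(tmp, self.path)
        self.best_loss = loss
        return True


def load_weights(path: str, deserialize: Callable[[bytes], Any] = pickle.loads) -> Any:
    """Read a checkpoint written by BestCheckpoint back into memory."""
    return deserialize(Path(path).read_bytes())
```

In a training loop this would be called once per evaluation, e.g. `ckpt.update(val_loss, params)`, and users would later call `load_weights(path)` with the matching deserializer. Keeping the serializer pluggable means the TensorFlow-dependent checkpoint code never has to be imported.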