daniel-j-h / nedem

Neural implicit digital elevation model
MIT License

Save / load weights in jax format #10

Open daniel-j-h opened 2 years ago

daniel-j-h commented 2 years ago

At the moment we don't save the weights; we simply write out predictions.

We should write out the best weights (based on loss), and also allow users to load them back in.

daniel-j-h commented 2 years ago

https://flax.readthedocs.io/en/latest/flax.serialization.html

daniel-j-h commented 2 years ago

Flax checkpoint serialization currently depends on TensorFlow for I/O: https://github.com/google/flax/issues/1924

This is an upstream issue; a workaround could be to use their second format (sigh), which is based on msgpack.

daniel-j-h commented 2 years ago

Raising the FrozenDict issue upstream: https://github.com/google/flax/issues/2005