cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License
215 stars, 17 forks

Recommended way to save/load tx.Modules? #76

Open bhoov opened 2 years ago

bhoov commented 2 years ago

First off, I love this library. It is so much more elegant and intuitive than flax while being more fully featured than equinox (I guess it helps that I use dataclasses regularly).

What would be the recommended way to save trained tx.Modules? I often find myself making a very lightweight tx.Module that mimics the functionality of flax...TrainState for my training runs and it would be nice to know a standard way to capture all the static fields and nodes in a single file. I know pickling is an option, but I have always found it safer to save a simple Python dict of my model and find a way to load that dict back in, much like PyTorch's state_dict interface.
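A state_dict-style workflow can be approximated for any JAX pytree, and tx.Modules are pytrees. The sketch below is an assumption, not a treex API: `to_state_dict`, `from_state_dict`, `save` and `load` are hypothetical helpers built on `jax.tree_util.tree_flatten_with_path` (JAX 0.4.6 or newer) and `np.savez`. Only array leaves go into the file; static fields are not saved, so loading needs a freshly constructed template with the same structure.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_state_dict(tree):
    """Hypothetical helper: flatten a pytree into {path_string: np.ndarray}."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    return {
        jax.tree_util.keystr(path): np.asarray(leaf)
        for path, leaf in leaves_with_paths
    }


def from_state_dict(template, state):
    """Hypothetical helper: rebuild a pytree shaped like `template` from `state`.

    Static fields come from `template` (i.e. from code), not from the file.
    """
    leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(template)
    new_leaves = []
    for path, leaf in leaves_with_paths:
        key = jax.tree_util.keystr(path)
        value = jnp.asarray(state[key])  # KeyError if the checkpoint lacks this leaf
        if value.shape != jnp.shape(leaf):
            raise ValueError(
                f"shape mismatch for {key}: {value.shape} vs {jnp.shape(leaf)}"
            )
        new_leaves.append(value)
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


def save(file, tree):
    # One .npy entry per leaf inside a single .npz archive (no pickling).
    np.savez(file, **to_state_dict(tree))


def load(file, template):
    with np.load(file) as data:
        return from_state_dict(template, dict(data))
```

In practice the template would be a newly initialized model. The key strings depend on how a class registers itself with JAX: dicts give readable keys such as `['linear']['w']`, while classes registered without key information get generic flattened-index keys.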

cgarciae commented 1 year ago

Hey @bhoov, sorry this skipped my inbox.

> First off, I love this library. It is so much more elegant and intuitive than flax while being more fully featured than equinox (I guess it helps that I use dataclasses regularly).

Thanks!

> What would be the recommended way to save trained tx.Modules?

I've just been using cloudpickle, but I wonder whether using flax.serialization somehow would be better.
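For comparison, the whole-object cloudpickle approach amounts to roughly the following. `save_pickle` and `load_pickle` are hypothetical helper names. Unlike a dict of arrays, the pickle captures static fields too, but reading it back requires compatible code on the loading side, and unpickling data from an untrusted source can execute arbitrary code.

```python
import pickle

import jax.numpy as jnp

try:
    # cloudpickle can also serialize lambdas and classes defined in __main__ by value.
    import cloudpickle
except ImportError:
    # Assumption for this sketch: fall back to stdlib pickle, which is enough
    # for objects whose classes are importable from a module.
    cloudpickle = pickle


def save_pickle(path, obj):
    with open(path, "wb") as f:
        cloudpickle.dump(obj, f)


def load_pickle(path):
    # cloudpickle writes ordinary pickle data, so the stdlib loader reads it
    # (cloudpickle must still be installed if it pickled classes by value).
    with open(path, "rb") as f:
        return pickle.load(f)
```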

> I often find myself making a very lightweight tx.Module that mimics the functionality of flax...TrainState for my training runs and it would be nice to know a standard way to capture all the static fields and nodes in a single file.

I do this often too; it's just that in Treex they are normally treeo.Trees, which are more lightweight. Since this feels so natural, I've never felt the urge to create an abstraction for it.
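A lightweight TrainState-like container can be written as a plain pytree. The sketch below is hypothetical: it uses a dataclass registered with `jax.tree_util.register_pytree_node_class` rather than treeo.Tree, and `MiniTrainState`, `apply_gradients` and `train_step` are illustrative names. It shows where static fields end up: `learning_rate` travels in the treedef (`aux_data`), not among the leaves, so a dict-of-arrays checkpoint would only hold `step` and `params`.

```python
from dataclasses import dataclass
from typing import Any

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class MiniTrainState:
    """Hypothetical TrainState-like pytree (not a treeo/treex or flax API)."""

    step: jax.Array  # dynamic: becomes a leaf
    params: Any  # dynamic: a pytree whose arrays become leaves
    learning_rate: float = 1e-3  # static: stored in the treedef's aux_data

    def tree_flatten(self):
        children = (self.step, self.params)
        aux_data = (self.learning_rate,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        step, params = children
        (learning_rate,) = aux_data
        return cls(step=step, params=params, learning_rate=learning_rate)

    def apply_gradients(self, grads):
        # Plain SGD update, used only to illustrate a functional state update.
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - self.learning_rate * g, self.params, grads
        )
        return MiniTrainState(
            step=self.step + 1, params=new_params, learning_rate=self.learning_rate
        )


@jax.jit
def train_step(state, grads):
    # learning_rate is part of the treedef, so jit treats it as a constant.
    return state.apply_gradients(grads)
```

Because `learning_rate` lives in the treedef, changing it forces `train_step` to retrace; a value that changes during training would belong among the children instead.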