clement-moulin-frier / vivarium

MIT License

Use Flax dataclasses to store state #52

Closed corentinlger closed 3 months ago

corentinlger commented 6 months ago

Description

Hi, I think you already mentioned it, @clement-moulin-frier, but using Flax dataclasses could help us a lot.

They are widely used in JAX-based RL projects (mostly to store agent components, states, etc.) and already support some useful features, notably to_state_dict and from_state_dict.

Flax also comes with a serialization module that provides functions such as to_bytes and from_bytes, which work directly on objects such as Flax dataclasses (this could be useful if we want to transfer states over a network).

clement-moulin-frier commented 6 months ago

It looks interesting indeed. I think the first thing to check at this stage is whether they are compatible with jax-md. Maybe you can try a quick test: initialize a State with the exact same structure we currently have, but as a Flax dataclass, and check that the simulation runs correctly.