Closed: corentinlger closed this 3 months ago
It looks interesting indeed. I think the first thing to check at this stage is whether they are compatible with jax-md. Maybe you can try a quick test: initialize a State with exactly the same structure we currently have, but as a flax dataclass, and check that the simulation still runs correctly.
Description
Hi, I think you already mentioned this @clement-moulin-frier, but using flax dataclasses could potentially help us a lot.
They are widely used in RL projects built on jax (mostly to store elements of agents, states ...) and already support some interesting features, notably `to_state_dict` and `from_state_dict`.
Flax also comes with a serialization module that provides functions such as `to_bytes` and `from_bytes`, which can be used directly on objects such as flax dataclasses (this could be interesting if we want to transfer states over a network).