brentyi / jax_dataclasses

Pytrees + dataclasses ❤️
MIT License

Serialization of static fields? #1


erdmann commented 2 years ago

Thanks for the handy library!

I have a pytree_dataclass that contains a few static_fields that I would like to have serialized by the facilities in flax.serialization. I noticed that jax_dataclasses.asdict handles these, but that flax.serialization.to_state_dict and flax.serialization.to_bytes both ignore them. What is the correct way (if any) to have these fields included in flax's serialization? Should I be using another technique?

import jax_dataclasses as jdc
from jax import numpy as jnp
import flax.serialization as fs

@jdc.pytree_dataclass
class Demo:
    a: jnp.ndarray = jnp.ones(3)
    b: bool = jdc.static_field(default=False)

demo = Demo()
print(f'{jdc.asdict(demo) = }')
print(f'{fs.to_state_dict(demo) = }')
print(f'{fs.from_bytes(Demo, fs.to_bytes(demo)) = }')

# jdc.asdict(demo) = {'a': array([1., 1., 1.]), 'b': False}
# fs.to_state_dict(demo) = {'a': DeviceArray([1., 1., 1.], dtype=float64)}
# fs.from_bytes(Demo, fs.to_bytes(demo)) = {'a': array([1., 1., 1.])}

Thanks in advance!

brentyi commented 2 years ago

Appreciate the detailed example :)

To me the main reasons for excluding the static fields are:

- flax.serialization operates on the leaves of a pytree, and static fields aren't leaves: they're folded into the tree structure itself, which is why to_state_dict() skips them.
- Static fields can hold arbitrary Python objects (booleans, callables, types, ...), which don't have an obvious byte representation the way arrays do.

Does that make sense?

For serializing/deserializing dataclasses, the usual pattern I use is to save two files: the serialized state dictionary from Flax (for ML applications: a "checkpoint") and a configuration object that can be passed into a helper function for instantiating the dataclass with all of the right array shapes + static fields populated (for ML applications: a "model config"). Here's an example of a helper function like this.
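In rough pseudocode, that pattern looks something like the following. (make_model and config here are hypothetical stand-ins for whatever project-specific instantiation logic you have; flax.serialization.to_bytes/from_bytes are the real Flax APIs.)

import flax.serialization as fs

def save_checkpoint(model, path: str) -> None:
    # Serializes array leaves only; static fields are assumed to be
    # recoverable from the config.
    with open(path, "wb") as f:
        f.write(fs.to_bytes(model))

def load_checkpoint(make_model, config, path: str):
    # make_model(config) rebuilds the dataclass with the right array shapes +
    # static fields populated; from_bytes() then fills the array leaves back in.
    template = make_model(config)
    with open(path, "rb") as f:
        return fs.from_bytes(template, f.read())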

I think this works okay, but open to suggestions for improvements/new APIs. Maybe it's possible for a serialization helper to return a 2-element tuple consisting of both the state dictionary and the attributes needed to reproduce the tree structure + static fields?
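One concrete starting point: jax.tree_util's flatten/unflatten already splits a pytree into exactly those two pieces, the array leaves and a treedef that encodes the structure + static fields. A sketch using the Demo class from the first comment:

import jax

# The static field `b` is baked into the treedef, not the leaves:
leaves, treedef = jax.tree_util.tree_flatten(demo)
print(leaves)  # [the `a` array]; `b` lives inside treedef
restored = jax.tree_util.tree_unflatten(treedef, leaves)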

erdmann commented 2 years ago

Yes, that makes sense. Thanks. As a long-time user of PyTorch, I have always enjoyed the simplicity of torch.load and torch.save (pickle with some extra sauce, I believe), and have also found it convenient that, by default, torch.load automatically places tensors back on the GPU from which they were saved.

For my daily work in a Jupyter notebook in which I rewrite classes a lot, doing a straight torch.save/load of dataclasses is fragile because the definitions of my dataclasses are in flux. The usual recommendation to save the state dict then loses the convenience of having a single file and a single load statement that remembers the class. In working with simple (non-nested) jax_dataclasses.pytree_dataclass objects a lot lately, I have found the following to be convenient:

from typing import Any

import jax
import torch
import jax_dataclasses as jdc

def save(dc: Any, filename: str) -> None:
    # Store the class name alongside the field dict; jdc.asdict() includes
    # static fields, so nothing is lost.
    torch.save((dc.__class__.__name__, jdc.asdict(dc)), filename)

def load(filename: str) -> Any:
    cls_name, data = torch.load(filename)
    # Look the class up by name in the current namespace, rebuild it, and
    # move arrays back onto the device.
    return jax.device_put(eval(cls_name)(**data))

This way, if the given dataclass changes its definition but retains the same fields, loading still works. If the class name changes or the fields change, it's still trivial to get the state dict with a straight torch.load. And it's fast and avoids having a sidecar file.

From a user perspective, torch.load/save are super convenient, so something implementing a similarly fast and simple interface would be great. (Maybe a solution utilizing __getstate__ and __setstate__ somehow?)
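To sketch what I mean (assuming jdc.asdict() captures everything needed to rebuild an instance):

import pickle

import jax_dataclasses as jdc
from jax import numpy as jnp

@jdc.pytree_dataclass
class Demo:
    a: jnp.ndarray
    b: bool = jdc.static_field(default=False)

    def __getstate__(self):
        # Field dict, static fields included.
        return jdc.asdict(self)

    def __setstate__(self, state):
        # pytree_dataclass instances are frozen, so bypass __setattr__.
        for name, value in state.items():
            object.__setattr__(self, name, value)

demo = pickle.loads(pickle.dumps(Demo(a=jnp.ones(3), b=True)))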

brentyi commented 2 years ago

Thanks for clarifying!

I've run into similar issues, and something like the snippet you suggested sounds really useful. Main desired features on my end would be (1) better support for nested structures and (2) possibly avoiding the eval() call.

If your pytrees contain only objects that are serializable via PyYAML (this includes most Python objects, JAX/numpy arrays, flax FrozenDicts, etc., but not things like lambda functions), I have some dataclass serialization utilities in dcargs that get us partially there. The basic idea is to take a class reference when deserializing, and recursively traverse the type annotations to understand how to reconstruct dataclasses. It's also pretty easy to update the YAML in a text editor if anything gets renamed.

Example of serialization:

import jax_dataclasses as jdc
from jax import numpy as jnp

import dcargs

@jdc.pytree_dataclass
class Tree:
    number: float
    flag: bool = jdc.static_field()

yaml = dcargs.to_yaml(Tree(3.0, False))
print(yaml)

# Recovery
dcargs.from_yaml(Tree, yaml)

Output:

!dataclass:Tree
flag: false
number: 3.0

Or, to mimic the torch.save() and torch.load() syntax:

import dataclasses
import pathlib
from typing import Any, Type, TypeVar, Union

import dcargs

Dataclass = Any
DataclassT = TypeVar("DataclassT")

def save(instance: Dataclass, path: Union[str, pathlib.Path]) -> None:
    assert dataclasses.is_dataclass(instance)
    with open(path, "w") as file:
        file.write(dcargs.to_yaml(instance))

def load(cls: Type[DataclassT], path: Union[str, pathlib.Path]) -> DataclassT:
    assert dataclasses.is_dataclass(cls)
    with open(path, "r") as file:
        output = dcargs.from_yaml(cls, file.read())
    return output
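Usage then looks just like torch.save/torch.load (reusing the Tree dataclass from above):

save(Tree(3.0, False), "tree.yaml")
tree = load(Tree, "tree.yaml")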

Another possible source of inspiration is dacite, which should work out-of-the-box with our dataclass objects and could achieve a similar goal (see the sketch below), albeit with its own set of constraints and possible failure cases. Will continue to think about this; it seems like there's room for a more robust solution.
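For reference, the dacite version of recovery would look something like this. (dacite.from_dict is its actual entry point; array-typed or otherwise exotic fields may need extra dacite.Config hooks.)

import dacite

# Rebuild a (possibly nested) dataclass from a plain dict, e.g. the output
# of jdc.asdict():
tree = dacite.from_dict(data_class=Tree, data={"number": 3.0, "flag": False})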