mit-ll-responsible-ai / hydra-zen

Create powerful Hydra applications without the yaml files and boilerplate code.
https://mit-ll-responsible-ai.github.io/hydra-zen/

Pydantic Parser causes issues when pickling configs #717

Open lebrice opened 1 month ago

lebrice commented 1 month ago

Hello there!

I'm running into a weird issue related to pickling of configs, but it only seems to happen when using the new pydantic parser feature. I'll try to put together a good reproduction example soon, but the message I'm getting is something like this:

_pickle.PicklingError: Can't pickle <function CosineAnnealingLR at 0x7c97d754dda0>: it's not the same object as torch.optim.lr_scheduler.CosineAnnealingLR

This only happens with hydra_zen.instantiate(config, _target_wrapper_=pydantic_parser); it does not happen with plain hydra_zen.instantiate(config).

Here's what my config looks like, roughly:

# project/configs/lr_scheduler/__init__.py
import hydra_zen
import torch.optim

CosineAnnealingLRConfig = hydra_zen.builds(
    torch.optim.lr_scheduler.CosineAnnealingLR,
    populate_full_signature=True,
    zen_partial=True,
    zen_exclude=["optimizer", "verbose"],
    T_max=85,
    eta_min=1e-5,
    zen_dataclass={"module": "project.configs.lr_scheduler", "cls_name": "CosineAnnealingLRConfig"},
)
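
To be concrete, here is roughly how the two calls compare (a sketch, not my exact code):

import pickle

from hydra_zen.third_party.pydantic import pydantic_parser

# Works: instantiate returns a plain functools.partial around CosineAnnealingLR.
partial_scheduler = hydra_zen.instantiate(CosineAnnealingLRConfig)
pickle.dumps(partial_scheduler)

# Fails with the PicklingError above: the partial now wraps a generated
# function rather than the class itself.
partial_scheduler = hydra_zen.instantiate(CosineAnnealingLRConfig, _target_wrapper_=pydantic_parser)
pickle.dumps(partial_scheduler)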

Does this ring a bell? Could it be because of the wrapping of functions that seems to happen when using a _target_wrapper_ but not otherwise?

Thanks a lot!

rsokl commented 1 month ago

A repro would be great. I can imagine ways in which _target_wrapper_ could mess with pickling, but it isn't clear what natural code patterns would result in this. Happy to take a look.

lebrice commented 1 month ago

Here's a reproduction script:

# hydra_zen_issue_debug.py
import pickle

import hydra_zen
import torch.optim
from hydra_zen.third_party.pydantic import pydantic_parser  # noqa

AdamConfig = hydra_zen.builds(
    torch.optim.Adam,
    populate_full_signature=True,
    zen_partial=True,
    zen_exclude=["params"],
    zen_dataclass={
        "cls_name": "AdamConfig",
        "module": "hydra_zen_issue_debug",
        "frozen": True,
    },
    zen_wrappers=[pydantic_parser],  # comment this out to make it work below.
)

# @hydra_zen.hydrated_dataclass(
#     torch.optim.Adam,
#     frozen=True,
#     zen_partial=True,
#     zen_wrappers=[pydantic_parser],
#     populate_full_signature=True,
# )
# class AdamConfig: ...

def main():
    obj = AdamConfig(lr=3e-4)

    # NOTE: perhaps unrelated, but the pickling of `obj` (the config dataclass instance) doesn't
    # work here, but works if we use `hydrated_dataclass` above.
    # restored_obj = pickle.loads(pickle.dumps(obj))
    # assert restored_obj == obj

    fn = hydra_zen.instantiate(obj)
    restored_fn = pickle.loads(pickle.dumps(fn))  # <-- this fails when using the pydantic parser
    assert restored_fn.func == fn.func
    assert restored_fn.args == fn.args
    assert restored_fn.keywords == fn.keywords

    print("All good.")

if __name__ == "__main__":
    main()

rsokl commented 1 month ago

Ah. I see what is going on. fn used to be partial(Adam, ...) but now it is partial(constructor_as_fn(Adam), ...). See here

Basically, I do not wrap class objects in-place, because this entails potentially sketchy mutation of the objects. Instead I make a function whose signature matches that of the class, whose arguments are validated by pydantic, and which, when called, returns an instance of the class. That function is what is getting pickled.
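
Roughly the shape of the problem, as a sketch (not hydra-zen's actual code): pickle serializes a function "by reference", i.e. it records __module__ and __qualname__ and checks that looking those up again yields the very same object. Because the generated function borrows the target class's identity, that lookup finds the class instead, and pickle refuses:

import pickle

import torch.optim

def constructor_as_fn(cls):
    def wrapper(*args, **kwargs):
        # (the real wrapper also runs pydantic validation on the arguments)
        return cls(*args, **kwargs)

    # copying the class's identity metadata onto the generated function is
    # what trips pickle up
    wrapper.__name__ = cls.__name__
    wrapper.__qualname__ = cls.__qualname__
    wrapper.__module__ = cls.__module__
    return wrapper

wrapped = constructor_as_fn(torch.optim.Adam)
pickle.dumps(wrapped)
# _pickle.PicklingError: Can't pickle <function Adam at 0x...>:
# it's not the same object as torch.optim.adam.Adam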

It is not immediately obvious how I should fix this, but perhaps I can change the function to a class that is designed to pickle to the underlying class.
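
Something along these lines, just to sketch the idea (ValidatedConstructor and _return_target are hypothetical names, not anything in hydra-zen): a callable object whose __reduce__ makes it unpickle as the underlying target class.

import pickle

import torch.optim

def _return_target(target):
    # module-level helper so that it is itself picklable by reference
    return target

class ValidatedConstructor:
    def __init__(self, target):
        self.target = target

    def __call__(self, *args, **kwargs):
        # (pydantic validation of the arguments would happen here)
        return self.target(*args, **kwargs)

    def __reduce__(self):
        # pickle this wrapper as "give back the plain target class"
        return (_return_target, (self.target,))

wrapped = ValidatedConstructor(torch.optim.Adam)
assert pickle.loads(pickle.dumps(wrapped)) is torch.optim.Adam

Of course, the restored object would then be the bare class, so the validation layer is dropped on unpickling; whether that is acceptable is part of the design question.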

Sorry about this. Always hard to anticipate all of these edge cases. Thanks for the repro!