mit-ll-responsible-ai / hydra-zen

Create powerful Hydra applications without the yaml files and boilerplate code.
https://mit-ll-responsible-ai.github.io/hydra-zen/
MIT License
338 stars 15 forks source link

Nested Dataclasses can't be initialized #635

Closed sorenmc closed 9 months ago

sorenmc commented 9 months ago

First of all, thank you for this project.

Like many others, I am striving to move away from YAML files. However, I would still like to preserve the single-main-config approach that Hydra is known for. I realize that one of hydra-zen's main goals is to avoid this pattern in order to follow DRY principles, as mentioned in the docs. I have a fairly large codebase that relies on these hierarchically structured dicts, and I would prefer not to rewrite it all. A toy example showcasing what I am looking for can be seen below:

from dataclasses import dataclass, field
from typing import Literal

from hydra_zen import store, instantiate, make_custom_builds_fn, zen

# Project-wide `builds` with populate_full_signature=True preset, so the
# individual builds(...) calls below do not have to pass that option.
builds = make_custom_builds_fn(populate_full_signature=True)

@dataclass
class C:
    """Innermost settings group; ``B`` holds one of these in its ``c`` field."""

    c_1: int = 1
    # Overridden by the three `builds_c(c_2=...)` variants defined further down.
    c_2: str = "c text"
    # Annotated as a Literal of "a", "b" or "c".
    c_3: Literal["a", "b", "c"] = "a"

@dataclass
class B:
    """Settings group that carries two scalar options plus a nested ``C``."""

    b_1: float = 1.0
    b_2: str = "b text"
    # Each instance gets its own default-constructed C.
    c: C = field(default_factory=C)

@dataclass
class A:
    """Flat settings group; ``Config`` holds one of these in its ``a`` field."""

    # Overridden by the three `builds_a(a_1=...)` variants defined further down.
    a_1: float = 1.0
    a_2: int = 1

@dataclass
class Config:
    """Root of the hierarchy: one ``A`` group and one ``B`` group."""

    # Both fields default to a freshly constructed instance per Config.
    a: A = field(default_factory=A)
    b: B = field(default_factory=B)

class Model:
    """Minimal stand-in for a model that is constructed from a ``Config``."""

    def __init__(self, config: Config) -> None:
        # Kept on a private attribute; exposed read-only through `config`.
        self._config = config

    @property
    def config(self) -> Config:
        """Return the config object that was passed to the constructor."""
        return self._config

# c setup
# builds(C) produces a config for C; calling it with `c_2=...` creates three
# variants that differ only in that field.
builds_c = builds(C)
c_text_1 = builds_c(c_2="c text 1")
c_text_2 = builds_c(c_2="c text 2")
c_text_3 = builds_c(c_2="c text 3")

# Register the C variants under the "config/B/C" group.
# NOTE(review): the group path uses "B"/"C" while the dataclass fields are
# named `b`/`c`; the follow-up snippet in this thread uses "config/b/c" --
# confirm which casing is intended.
c_store = store(group="config/B/C")
c_store(c_text_1, name="c_text_1")
c_store(c_text_2, name="c_text_2")
c_store(c_text_3, name="c_text_3")

# b setup: each B variant sets `b_1` and nests the C variant with the same index.
builds_b = builds(B)
b_1_1 = builds_b(b_1=1, c=c_text_1)
b_1_2 = builds_b(b_1=2, c=c_text_2)
b_1_3 = builds_b(b_1=3, c=c_text_3)

# Register the B variants under the "config/B" group.
b_store = store(group="config/B")
b_store(b_1_1, name="b_1_1")
b_store(b_1_2, name="b_1_2")
b_store(b_1_3, name="b_1_3")

# a setup: the A variants differ only in `a_1`.
builds_a = builds(A)
a_1_1 = builds_a(a_1=1.2)
a_1_2 = builds_a(a_1=2.2)
a_1_3 = builds_a(a_1=3.2)

# Register the A variants under the "config/A" group.
a_store = store(group="config/A")
a_store(a_1_1, name="a_1_1")
a_store(a_1_2, name="a_1_2")
a_store(a_1_3, name="a_1_3")

# config setup: each Config variant pairs the A and B variants of the same index.
builds_config = builds(Config)
config_1 = builds_config(a=a_1_1, b=b_1_1)
config_2 = builds_config(a=a_1_2, b=b_1_2)
config_3 = builds_config(a=a_1_3, b=b_1_3)

# Register the full Config variants under the "config" group.
config_store = store(group="config")
config_store(config_1, name="config_1")
config_store(config_2, name="config_2")
config_store(config_3, name="config_3")

@store(name="config", hydra_defaults=["_self_", {"config": "config_1"}])
def task_function(config: Config):
    """Wrap ``config`` in a ``Model`` and print the config the model exposes.

    The function is stored under the name "config" with a defaults list that
    selects ``config_1`` from the "config" group.
    """
    print(Model(config).config)

if __name__ == "__main__":
    # Publish everything registered on `store` to Hydra's config store, then
    # wrap the task function with zen() and launch it with the "config" entry.
    store.add_to_hydra_store()
    zen(task_function).hydra_main(config_name="config", version_base="1.3", config_path=".")

With this fully self-contained example, I get the following error:

Exception has occurred: ConfigCompositionException
In 'config/config_1': ValidationError raised while composing config:
Invalid type assigned: Builds_Config is not a subclass of Config. value: {'_target_': '__main__.Config', 'a': {'_target_': '__main__.A', 'a_1': 1.2, 'a_2': 1}, 'b': {'_target_': '__main__.B', 'b_1': 1.0, 'b_2': 'b text', 'c': {'_target_': '__main__.C', 'c_1': 1, 'c_2': 'c text 1', 'c_3': 'a'}}}
    full_key: 
    object_type=Builds_task_function
omegaconf.errors.ValidationError: Invalid type assigned: Builds_Config is not a subclass of Config. value: {'_target_': '__main__.Config', 'a': {'_target_': '__main__.A', 'a_1': 1.2, 'a_2': 1}, 'b': {'_target_': '__main__.B', 'b_1': 1.0, 'b_2': 'b text', 'c': {'_target_': '__main__.C', 'c_1': 1, 'c_2': 'c text 1', 'c_3': 'a'}}}

The above exception was the direct cause of the following exception:

omegaconf.errors.ValidationError: Invalid type assigned: Builds_Config is not a subclass of Config. value: {'_target_': '__main__.Config', 'a': {'_target_': '__main__.A', 'a_1': 1.2, 'a_2': 1}, 'b': {'_target_': '__main__.B', 'b_1': 1.0, 'b_2': 'b text', 'c': {'_target_': '__main__.C', 'c_1': 1, 'c_2': 'c text 1', 'c_3': 'a'}}}
    full_key: 
    object_type=Builds_task_function

During handling of the above exception, another exception occurred:

  File "test/config.py", line 94, in <module>
    zen(task_function).hydra_main(config_name="config", version_base="1.3", config_path=".")
hydra.errors.ConfigCompositionException: In 'config/config_1': ValidationError raised while composing config:
Invalid type assigned: Builds_Config is not a subclass of Config. value: {'_target_': '__main__.Config', 'a': {'_target_': '__main__.A', 'a_1': 1.2, 'a_2': 1}, 'b': {'_target_': '__main__.B', 'b_1': 1.0, 'b_2': 'b text', 'c': {'_target_': '__main__.C', 'c_1': 1, 'c_2': 'c text 1', 'c_3': 'a'}}}
    full_key: 
    object_type=Builds_task_function
sorenmc commented 9 months ago

I made it work!

from dataclasses import dataclass, field
from typing import Literal

from hydra_zen import store, make_custom_builds_fn, zen

# Project-wide `builds` with populate_full_signature=True preset, so the
# individual builds(...) calls below do not have to pass that option.
builds = make_custom_builds_fn(populate_full_signature=True)

@dataclass
class C:
    """Innermost settings group; ``B`` holds one of these in its ``c`` field."""

    c_1: int = 1
    # Overridden by the three `builds_c(c_2=...)` variants defined further down.
    c_2: str = "c text"
    # Annotated as a Literal of "a", "b" or "c".
    c_3: Literal["a", "b", "c"] = "a"

@dataclass
class B:
    """Settings group that carries two scalar options plus a nested ``C``."""

    b_1: float = 1.0
    b_2: str = "b text"
    # Each instance gets its own default-constructed C.
    c: C = field(default_factory=C)

@dataclass
class A:
    """Flat settings group; ``Config`` holds one of these in its ``a`` field."""

    # Overridden by the three `builds_a(a_1=...)` variants defined further down.
    a_1: float = 1.0
    a_2: int = 1

@dataclass
class Config:
    """Root of the hierarchy: one ``A`` group and one ``B`` group."""

    # Both fields default to a freshly constructed instance per Config.
    a: A = field(default_factory=A)
    b: B = field(default_factory=B)

@dataclass
class WrapsConfig:
    """Wrapper whose single field, ``config``, holds the whole ``Config`` tree."""

    # Defaults to a freshly constructed Config per instance.
    config: Config = field(default_factory=Config)

class Model:
    """Minimal stand-in for a model that is constructed from a ``Config``."""

    def __init__(self, config: Config) -> None:
        # Kept on a private attribute; exposed read-only through `config`.
        self._config = config

    @property
    def config(self) -> Config:
        """Return the config object that was passed to the constructor."""
        return self._config

# c setup: builds(C) produces a config for C; calling it with `c_2=...`
# creates three variants that differ only in that field.
builds_c = builds(C)
c_text_1 = builds_c(c_2="c text 1")
c_text_2 = builds_c(c_2="c text 2")
c_text_3 = builds_c(c_2="c text 3")

# Register the C variants under "config/b/c"; the path segments match the
# field names `config` (WrapsConfig), `b` (Config) and `c` (B).
c_store = store(group="config/b/c")
c_store(c_text_1, name="c_text_1")
c_store(c_text_2, name="c_text_2")
c_store(c_text_3, name="c_text_3")

# b setup: each B variant sets `b_1` and nests the C variant with the same index.
builds_b = builds(B)
b_1_1 = builds_b(b_1=1, c=c_text_1)
b_1_2 = builds_b(b_1=2, c=c_text_2)
b_1_3 = builds_b(b_1=3, c=c_text_3)

# Register the B variants under the "config/b" group.
b_store = store(group="config/b")
b_store(b_1_1, name="b_1_1")
b_store(b_1_2, name="b_1_2")
b_store(b_1_3, name="b_1_3")

# a setup: the A variants differ only in `a_1`.
builds_a = builds(A)
a_1_1 = builds_a(a_1=1.2)
a_1_2 = builds_a(a_1=2.2)
a_1_3 = builds_a(a_1=3.2)

# Register the A variants under the "config/a" group.
a_store = store(group="config/a")
a_store(a_1_1, name="a_1_1")
a_store(a_1_2, name="a_1_2")
a_store(a_1_3, name="a_1_3")

# config setup: each Config variant pairs the A and B variants of the same index.
builds_config = builds(Config)
config_1 = builds_config(a=a_1_1, b=b_1_1)
config_2 = builds_config(a=a_1_2, b=b_1_2)
config_3 = builds_config(a=a_1_3, b=b_1_3)

# Register the full Config variants under the "config" group.
config_store = store(group="config")
config_store(config_1, name="config_1")
config_store(config_2, name="config_2")
config_store(config_3, name="config_3")

# Top-level entry: a WrapsConfig whose `config` field is preset to config_1,
# stored without a group under the name "default_config".
builds_wraps_config = builds(WrapsConfig, config=config_1)
store(builds_wraps_config,  name="default_config")

def task_function(config: Config):
    """Wrap ``config`` in a ``Model`` and print the config the model exposes."""
    print(Model(config).config)

if __name__ == "__main__":
    # Publish everything registered on `store` to Hydra's config store, then
    # wrap the task function with zen() and launch it with the
    # "default_config" entry.
    store.add_to_hydra_store()
    zen(task_function).hydra_main(config_name="default_config", version_base="1.3", config_path=None)

This also allows us to use the CLI to run, e.g.:

python config.py +config/b/c=c_text_1,c_text_2 -m

or even

python config.py "+config/b/c=glob(*)" -m

For some reason, I have to use `+` to override, but it works for now. It would be nice to hear why you can't override the way you normally would.