mit-ll-responsible-ai / hydra-zen

Create powerful Hydra applications without the yaml files and boilerplate code.
https://mit-ll-responsible-ai.github.io/hydra-zen/
MIT License

Support for specifying arbitrary configs given kwargs in a method signature #654

Open alexanderswerdlow opened 8 months ago

alexanderswerdlow commented 8 months ago

I have a somewhat unusual use case in which I merge global configs quite often (from #621). Sometimes I want to override one of the nested fields entirely (e.g., to change the target class). However, even when I specify a new builds(...) for that field, any nested arguments previously set for that config are still kept.
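The retention described above is what a recursive (deep) merge does. The sketch below is a simplified, stdlib-only model of that behavior, not OmegaConf's implementation. It assumes nested dict nodes are merged key by key, as OmegaConf.merge does for dict nodes. The `my_pkg.*` target paths are hypothetical. Swapping `_target_` does not remove `other_param`, because keys present only in the base survive the merge.

```python
from copy import deepcopy
from typing import Any


def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Recursively merge `override` into a copy of `base`.

    Nested dicts are merged key by key; any other value in `override`
    replaces the one in `base`. Keys present only in `base` survive.
    """
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# Hypothetical config trees: mode_1 added `other_param` to the original
# dataset config, then a new target is swapped in without that key.
base = {
    "train_dataset": {
        "_target_": "my_pkg.OriginalDatasetCls",
        "split": "train",
        "other_param": True,
    }
}
override = {
    "train_dataset": {
        "_target_": "my_pkg.NewDatasetCls",
        "split": "train",
    }
}

merged = deep_merge(base, override)
print(merged["train_dataset"])
# {'_target_': 'my_pkg.NewDatasetCls', 'split': 'train', 'other_param': True}
```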

This causes a problem when the new class has a different signature that lacks a previously specified argument. In that case I get an error like this:

hydra.errors.ConfigCompositionException: In 'modes/mode_1': ConfigKeyError raised while composing config:
Key 'other_param' not in 'PartialBuilds_NewDatasetCls'
    full_key: dataset.train_dataset.other_param
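One way to read this error is that the node being merged into is typed by a generated dataclass (here PartialBuilds_NewDatasetCls), and a typed node only accepts keys that are declared fields. The model below is a hypothetical, stdlib-only illustration of that rule, not OmegaConf's code. PartialBuildsNewDatasetCls and merge_into_typed are made-up names, and the model ignores fields such as _target_ that a real generated config would also carry.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class PartialBuildsNewDatasetCls:
    """Hypothetical stand-in for a generated config that declares only `split`."""

    split: str = "train"


def merge_into_typed(cfg_type: type, values: dict[str, Any]) -> Any:
    """Model of a typed (structured) merge: every incoming key must be a declared field."""
    allowed = {f.name for f in fields(cfg_type)}
    unknown = sorted(set(values) - allowed)
    if unknown:
        # Mirrors the shape of the ConfigKeyError message above.
        raise KeyError(f"Key {unknown[0]!r} not in {cfg_type.__name__!r}")
    return cfg_type(**values)


# The key carried over from mode_1 is rejected by the typed node.
try:
    merge_into_typed(PartialBuildsNewDatasetCls, {"split": "train", "other_param": True})
except KeyError as err:
    print(f"rejected: {err}")

# Declared keys merge without complaint.
print(merge_into_typed(PartialBuildsNewDatasetCls, {"split": "val"}))
```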

The two solutions I can think of are:

1) A way to specify that a builds() should completely overwrite any previous config for that key. This would be ideal, but I'm not sure it's possible given how merging is set up.

2) A slightly hackier (but entirely workable) approach: simply allow those values to be passed through as kwargs. However, hydra_zen doesn't seem to allow this; it errors out unless the fields are explicitly declared in the __init__ signature.
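One way to picture option 2 is a signature check that lets undeclared keys through whenever the target declares a **kwargs parameter, which Python's inspect module can detect. The sketch below is a hypothetical validator, not hydra_zen's actual check. validate_kwargs and accepts_arbitrary_kwargs are made-up names, and both classes are hypothetical stand-ins; NewDatasetCls mirrors the __init__(self, split: str, **kwargs) signature described in the issue.

```python
import inspect
from typing import Any, Callable


def accepts_arbitrary_kwargs(target: Callable[..., Any]) -> bool:
    """Return True if `target`'s signature declares a **kwargs parameter."""
    params = inspect.signature(target).parameters.values()
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params)


def validate_kwargs(target: Callable[..., Any], kwargs: dict[str, Any]) -> None:
    """Hypothetical option-2 check: undeclared keys pass only if `target` has **kwargs."""
    if accepts_arbitrary_kwargs(target):
        return
    declared = set(inspect.signature(target).parameters)
    unknown = sorted(set(kwargs) - declared)
    if unknown:
        raise TypeError(f"{target.__name__} does not accept: {', '.join(unknown)}")


class OriginalDatasetCls:
    """Hypothetical target that declares `other_param` explicitly."""

    def __init__(self, split: str, other_param: bool = False) -> None:
        self.split = split
        self.other_param = other_param


class NewDatasetCls:
    """Hypothetical target mirroring `def __init__(self, split: str, **kwargs)`."""

    def __init__(self, split: str, **kwargs: Any) -> None:
        self.split = split
        self.extra = kwargs  # undeclared values such as other_param land here


# other_param is forwarded to NewDatasetCls through **kwargs.
validate_kwargs(NewDatasetCls, {"split": "train", "other_param": True})
ds = NewDatasetCls(split="train", other_param=True)
print(ds.extra)
# {'other_param': True}
```

The trade-off is that a stale key is silently absorbed by **kwargs instead of being removed, which is why option 1 (a full overwrite) would be the cleaner fix if merging allowed it.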

It's not a minimal reproduction (apologies for that), but here's a reasonably concise example stripped down from my codebase.

from dataclasses import is_dataclass
from functools import partial
from typing import Any, Optional

from hydra_zen import builds, make_config, store
from hydra_zen.wrapper import default_to_config
from omegaconf import OmegaConf

# BaseConfig, DatasetConfig, OriginalDatasetCls, and NewDatasetCls are defined
# elsewhere in my codebase and are omitted from this excerpt.

def destructure(x):
    x = default_to_config(x)  # apply the default auto-config logic of `store`
    if is_dataclass(x):
        # Recursively converts:
        # dataclass -> omegaconf-dict (backed by dataclass types)
        #           -> dict -> omegaconf dict (no types)
        return OmegaConf.create(OmegaConf.to_container(OmegaConf.create(x)))  # type: ignore
    return x

# A store whose entries are converted to untyped OmegaConf dicts by `destructure`.
destructure_store = store(to_config=destructure)

def global_store(name: str, group: str, hydra_defaults: Optional[list[Any]] = None, **kwargs):
    cfg = make_config(
        hydra_defaults=hydra_defaults if hydra_defaults is not None else ["_self_"],
        bases=(BaseConfig,),
        zen_dataclass={"kw_only": True},
        **kwargs,
    )
    destructure_store(
        cfg,
        group=group,
        package="_global_",
        name=name,
    )
    return cfg

auto_store = store(group=lambda cfg: cfg.name)
mode_store = partial(global_store, group="modes")

auto_store(
    DatasetConfig,
    train_dataset=builds(OriginalDatasetCls, populate_full_signature=True, zen_partial=True, split="train"),
    name="movi_e",
)

mode_store(name="mode_1", dataset=dict(train_dataset=dict(other_param=True)))

# Note: NewDatasetCls does not declare an "other_param" argument; its signature is
# def __init__(self, split: str, **kwargs)
mode_store(
    name="mode_2",
    dataset=dict(
        train_dataset=builds(
            NewDatasetCls, populate_full_signature=True, zen_partial=True, split="train"
        )
    ),
)