omry / omegaconf

Flexible Python configuration system. The last one you will ever need.
BSD 3-Clause "New" or "Revised" License
1.97k stars 111 forks source link

Merging yaml and structured config with enum #699

Closed collinmccarthy closed 3 years ago

collinmccarthy commented 3 years ago

Describe the bug: My use case for OmegaConf is to define a default config and let the user specify a YAML file with overrides. This YAML file can be in standard form, with nested dictionaries, or a "flattened" dictionary with dot-notation keys like simple.num=10. I use flatten and inflate helpers to handle both cases, and when using enums the type information is lost in one specific situation, shown below.

The merging works if I inflate the YAML file, then merge with the default config. The merging "fails" (in the sense the type information is lost) if I flatten both configs, then merge, then inflate. The type information is important so I can compare with config.simple.height == Height.TALL or something like that. I'm submitting this bug report mainly because this was difficult for me to track down and I don't think this behavior is intentional.

To Reproduce

"""Flatten and inflate a nested dictionary."""
from dataclasses import dataclass, field
from enum import Enum
from collections.abc import MutableMapping
from typing import Any, Dict, List, Optional, Tuple

from omegaconf import OmegaConf

def flatten(
    nested_dict: MutableMapping,
    parent_key: Optional[str] = None,
    sep: str = '.'
) -> Dict:
    """Collapse a nested mapping into a single-level dict with joined keys.

    Any value that is itself a ``MutableMapping`` (plain ``dict``, OmegaConf
    ``DictConfig``, ...) is descended into, and its keys are joined to the
    parent key with ``sep``. Every other value is kept as a leaf. An empty
    sub-mapping contributes no keys. If two paths produce the same joined key,
    the value visited last wins, but the key keeps its first position.

    Based on https://stackoverflow.com/a/6027615/12422298.

    Example:
        >>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
        >>> flatten(d)
        {'a': 1, 'b.c': 2, 'b.d.e': 3}
    """
    flat: Dict = {}
    for name, child in nested_dict.items():
        path = name if parent_key is None else f'{parent_key}{sep}{name}'
        if not isinstance(child, MutableMapping):
            flat[path] = child
            continue
        # Recurse and splice the child's flattened keys into this level.
        flat.update(flatten(nested_dict=child, parent_key=path, sep=sep))
    return flat

def inflate(
    flattened_dict: MutableMapping,
    sep: str = '.'
) -> Dict:
    """Expand a flat mapping with ``sep``-joined keys back into nested dicts.

    Each key is split on ``sep``. Every component except the last names a
    nested dict, created on demand. The last component receives the value.
    This inverts ``flatten`` for mappings whose keys do not themselves contain
    ``sep`` and that have no empty sub-mappings.

    Note: the result is built from plain ``dict`` objects. Any structured
    (dataclass) type information of a source config, such as an enum-typed
    field, is therefore not carried over.

    Based on https://gist.github.com/fmder/494aaa2dd6f8c428cede, referenced
    in the comments of https://stackoverflow.com/a/6027615/12422298.

    Args:
        flattened_dict: Mapping whose keys are ``sep``-joined strings.
        sep: Separator used to split keys into path components.

    Returns:
        A new nested ``dict``.

    Raises:
        TypeError: If a key path must descend through an existing value that
            is not a mapping, e.g. ``{'a': 1, 'a.b': 2}``.

    Example:
        >>> d = {'a': 1, 'b.c': 2, 'b.d.e': 3}
        >>> inflate(d)
        {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
    """
    items: Dict = dict()
    for key, value in flattened_dict.items():
        *parents, leaf = key.split(sep)

        # Walk (and create as needed) one sub-dict per separator.
        node = items
        for depth, part in enumerate(parents):
            if part not in node:
                node[part] = dict()
            child = node[part]
            if not isinstance(child, MutableMapping):
                # Without this check the failure surfaces later as an opaque
                # "'int' object does not support item assignment".
                conflict = sep.join(parents[:depth + 1])
                raise TypeError(
                    f'Cannot inflate key {key!r}: {conflict!r} already holds '
                    f'a non-mapping value {child!r}'
                )
            node = child

        node[leaf] = value
    return items

class Height(Enum):
    """Example enum used to check whether enum typing survives a merge."""
    SHORT = 0
    TALL = 1

@dataclass
class SimpleTypesNoEnum:
    """Leaf schema with only primitive fields (the control case, no enum)."""
    num: int = 10
    description: str = "text"

@dataclass
class SimpleTypesEnum:
    """Leaf schema with the same primitive fields plus an enum-typed ``height``."""
    num: int = 10
    description: str = "text"
    # Declared as Height. When a structured config built from this class is
    # the merge base, a string such as "TALL" is stored as Height.TALL.
    height: Height = Height.SHORT

@dataclass
class ConfigEnum:
    """Top-level schema placing ``SimpleTypesEnum`` under the key ``simple``."""
    # default_factory gives each ConfigEnum its own SimpleTypesEnum instance.
    simple: SimpleTypesEnum = field(default_factory=SimpleTypesEnum)

@dataclass
class ConfigNoEnum:
    """Top-level schema placing ``SimpleTypesNoEnum`` under the key ``simple``."""
    # default_factory gives each ConfigNoEnum its own SimpleTypesNoEnum instance.
    simple: SimpleTypesNoEnum = field(default_factory=SimpleTypesNoEnum)

def get_config(with_enum: bool):
    """Return the default structured config, with or without the enum field."""
    schema = ConfigEnum if with_enum else ConfigNoEnum
    return OmegaConf.structured(schema)

def get_yaml(with_enum: bool, flattened: bool):
    """Build the YAML override text for one of the four test scenarios.

    Args:
        with_enum: Also override ``simple.height`` with the enum name ``TALL``.
        flattened: Use dot-notation keys (``simple.num``) instead of a nested
            ``simple:`` mapping.

    Returns:
        The YAML document as a string. It starts with a newline, every content
        line carries a 12-space margin, and it ends with that margin (the shape
        of an indented triple-quoted literal).
    """
    margin = ' ' * 12
    if flattened:
        body = ['simple.num: 22', 'simple.description: test']
        if with_enum:
            body.append('simple.height: TALL')
    else:
        body = ['simple:', '    num: 22', '    description: test']
        if with_enum:
            body.append('    height: TALL')
    return '\n' + ''.join(f'{margin}{line}\n' for line in body) + margin

def test_merge(with_enum: bool):
    """Print four flatten/inflate/merge round-trip comparisons.

    Each check prints a divider, a title, and the result of ``lhs == rhs`` on
    two configs. When ``with_enum`` is true, it also prints both
    ``simple.height`` values so an enum-vs-string mismatch is visible.

    Note: OmegaConf overrides ``__eq__`` on its containers, so the printed
    equality follows OmegaConf's comparison rules, not plain dict equality.

    Args:
        with_enum: Use the enum-bearing schema and YAML overrides.
    """
    label = 'with enum' if with_enum else 'no enum'
    divider = '- ' * 40

    def report(title, lhs, rhs):
        # Shared output format for every numbered check below.
        print(divider)
        print(title)
        print(lhs == rhs)
        if with_enum:
            print(f'LHS.height: {lhs.simple.height},  RHS.height: {rhs.simple.height}')

    default_config = get_config(with_enum)

    # Test 1: flatten/inflate round-trip of the structured defaults.
    report(
        f'Test 1 {label}: default_config == inflate(flatten(default_config)):',
        default_config,
        OmegaConf.create(inflate(flatten(default_config))),
    )

    # Test 2: merge nested YAML onto the defaults, then round-trip the result.
    nested_yaml = OmegaConf.create(get_yaml(with_enum=with_enum, flattened=False))
    merge_config = OmegaConf.merge(default_config, nested_yaml)
    report(
        f'Test 2 {label}: merge_config == inflate(flatten(merge_config)):',
        merge_config,
        OmegaConf.create(inflate(flatten(merge_config))),
    )

    # Test 3: inflate the flat YAML first, then merge onto the typed defaults.
    yaml_flat_config = OmegaConf.create(get_yaml(with_enum=with_enum, flattened=True))
    inflated_yaml = OmegaConf.create(inflate(yaml_flat_config))
    report(
        f'Test 3 {label}: merge_config == merge(default_config, inflate(yaml_flat_config)):',
        merge_config,
        OmegaConf.merge(default_config, inflated_yaml),
    )

    # Test 4: merge two flat configs, then inflate. The flat defaults are a
    # plain DictConfig with no dataclass schema, so 'TALL' stays a str here.
    default_flat_config = OmegaConf.create(flatten(default_config))
    merge_flat_config = OmegaConf.merge(default_flat_config, yaml_flat_config)
    report(
        f'Test 4 {label}: merge_config == inflate(merge(default_flat_config, yaml_flat_config))',
        merge_config,
        OmegaConf.create(inflate(merge_flat_config)),
    )
    print(divider)

if __name__ == '__main__':
    # Run the no-enum control case first, then the enum case; frame each with a rule.
    rule = '-' * 80
    print(rule)
    for flag in (False, True):
        test_merge(with_enum=flag)
        print(rule)

This produces the following output:

--------------------------------------------------------------------------------
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Test 1 no enum: default_config == inflate(flatten(default_config)):
True
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Test 2 no enum: merge_config == inflate(flatten(merge_config)):
True
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Test 3 no enum: merge_config == merge(default_config, inflate(yaml_flat_config)):
True
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Test 4 no enum: merge_config == inflate(merge(default_flat_config, yaml_flat_config))
True
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
--------------------------------------------------------------------------------
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Test 1 with enum: default_config == inflate(flatten(default_config)):
True
LHS.height: Height.SHORT,  RHS.height: Height.SHORT
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Test 2 with enum: merge_config == inflate(flatten(merge_config)):
True
LHS.height: Height.TALL,  RHS.height: Height.TALL
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Test 3 with enum: merge_config == merge(default_config, inflate(yaml_flat_config)):
True
LHS.height: Height.TALL,  RHS.height: Height.TALL
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Test 4 with enum: merge_config == inflate(merge(default_flat_config, yaml_flat_config))
False
LHS.height: Height.TALL,  RHS.height: TALL
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
--------------------------------------------------------------------------------

Expected behavior: As shown in the last few lines of the output, test 4 fails when using enums. The flattened YAML config has a string for simple.height and the flattened structured config has an enum, as expected. But merging them produces a string rather than an enum, unlike test 3. I expected tests 3 and 4 to work in both cases, with and without enums.

Additional context

Thank you for your help! Really loving omegaconf!

omry commented 3 years ago

Thanks for reporting. Can you boil this down to a minimal example that does not take a few pages of code? Ideally it would have an enum, a dataclass containing it, and an OmegaConf.merge call.

Your current report includes so much code that it's too time-consuming to tell whether you are reporting a bug in OmegaConf or in your own logic.

omry commented 3 years ago

In addition, please test this on OmegaConf 2.1, you can try one of the dev releases (latest is 2.1.0.dev26).

omry commented 3 years ago

As a side note: Please be aware that OmegaConf objects are overriding the __eq__ so it's dangerous to rely on it for testing.

collinmccarthy commented 3 years ago

Will do. I can't make it "simpler" per se, but I can remove the parts that work, which will cut it down a bit. I'll get to that later today.

omry commented 3 years ago

just use a bit of hard coded config (did not run this):

# Minimal reproduction: merge a string enum name onto the structured schema.
# Assumes OmegaConf, Height and SimpleTypesEnum from the script above are in scope.
cfg1 = OmegaConf.create({"height": "TALL"})
cfg2 = OmegaConf.merge(SimpleTypesEnum, cfg1)
# height is declared as Height in the schema, so "TALL" should become Height.TALL.
assert cfg2.height == Height.TALL
collinmccarthy commented 3 years ago

I think this is as simple as I can get it. Same thing in 2.1, so maybe this is expected behavior or more of a "feature request". I had thought the types would be preserved in both cases, but maybe there's no way to access the constructor to preserve the type in the second case? Not entirely sure...

# Minimal reproduction: an enum-typed field keeps its type through a merge only
# when the left-hand (base) config still carries the dataclass schema.
from enum import Enum
from dataclasses import dataclass
from omegaconf import OmegaConf

class Height(Enum):
    SHORT = 0
    TALL = 1

@dataclass
class Simple:
    height: Height = Height.SHORT

# Case 1: rebuild the config from a Simple instance, so the typed schema is kept.
struct_config = OmegaConf.structured(Simple)
struct_config = OmegaConf.structured(Simple(**{'height': struct_config['height']}))
dict_config = OmegaConf.create({'height': 'TALL'})
merge_config = OmegaConf.merge(struct_config, dict_config)
assert type(merge_config.height) == type(struct_config.height)  # Works: 'TALL' is stored as a Height

# Case 2: copy the value into a plain dict config. The node holds a Height
# value but has no declared type, so merging the string 'TALL' onto it just
# replaces it with a str (this is expected OmegaConf behavior).
struct_config = OmegaConf.structured(Simple)
struct_config = OmegaConf.create({'height': struct_config['height']})
merge_config = OmegaConf.merge(struct_config, dict_config)
assert type(merge_config.height) == type(struct_config.height)  # Fails: str vs Height

Thanks!

omry commented 3 years ago

This is expected. In your second struct_config, there is no type associated with height, which means the node can take any supported type. You are merging a string onto it, and this is what you end up with.

The correct way to do what you are trying to do is:

# Keep the dataclass-backed config as the merge base so its field types are
# enforced, and merge the plain (untyped) override config onto it.
schema = OmegaConf.structured(Simple)
cfg = OmegaConf.create({"height": Height.SHORT})
# cfg = OmegaConf.create({"height": "SHORT"})  # This will also work: the enum name is converted to Height.
merged = OmegaConf.merge(schema, cfg)
assert type(merged.height) == Height
collinmccarthy commented 3 years ago

Okay, got it. So basically I can't "flatten" a config created with OmegaConf.structured() into a dictionary, and then merge it, because during the merging process it won't have the typing information of the dataclass to retain the original types.

Thank you for your help in understanding this!

omry commented 3 years ago

Correct. To retain the type information, I recommend using pickle. As an alternative, take a look at Hydra, which allows you to drive the config composition via the config (among many other things).