lebrice / SimpleParsing

Simple, Elegant, Typed Argument Parsing with argparse
MIT License

Subgroups choice isn't serialized in yaml file #173

Closed kyamagu closed 1 year ago

kyamagu commented 1 year ago

Describe the bug

When a dataclass has a subgroups field, parse fails to load the YAML file produced by dump_yaml.

To Reproduce

from dataclasses import dataclass

from tempfile import NamedTemporaryFile
from simple_parsing import parse, subgroups, Serializable

@dataclass
class ModelConfig(Serializable):
    pass

@dataclass
class ModelAConfig(ModelConfig):
    lr: float = 3e-4
    num_blocks: int = 2

@dataclass
class ModelBConfig(ModelConfig):
    lr: float = 1e-3
    dropout: float = 0.1

@dataclass
class Config(Serializable):
    model: ModelConfig = subgroups(
        {"model_a": ModelAConfig, "model_b": ModelBConfig},
        default=ModelAConfig(),
    )

config = Config()
print(config)
print()

with NamedTemporaryFile(mode="w+", suffix=".yaml") as f:
    config.dump_yaml(f)
    print(open(f.name).read())
    print()
    config2 = parse(Config, config_path=f.name, args=[])
    print(config2)

Expected behavior

The parsed config should equal the original one, so the code should print:

Config(model=ModelAConfig(lr=0.0003, num_blocks=2))

model:
  lr: 0.0003
  num_blocks: 2

Config(model=ModelAConfig(lr=0.0003, num_blocks=2))

Actual behavior

The config is dumped correctly, but parse raises a RuntimeError:

Config(model=ModelAConfig(lr=0.0003, num_blocks=2))

model:
  lr: 0.0003
  num_blocks: 2

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-30-a77499e72e1d> in <module>
     35     print(open(f.name).read())
     36     print()
---> 37     config2 = parse(Config, config_path=f.name, args=[])
     38     print(config2)

... 2 frames omitted ...
/usr/local/lib/python3.7/dist-packages/simple_parsing/parsing.py in parse_known_args(self, args, namespace, attempt_to_reorder)
    375                         # A dynamic kind of default value?
    376                         raise RuntimeError(
--> 377                             f"Don't yet know how to detect which subgroup in {subgroup_dict} was "
    378                             f"chosen, if the default value is {value} and the field default is "
    379                             f"{field_wrapper.field.default}"

RuntimeError: Don't yet know how to detect which subgroup in {'model_a': <class '__main__.ModelAConfig'>, 'model_b': <class '__main__.ModelBConfig'>} was chosen, if the default value is {'lr': 0.0003, 'num_blocks': 2} and the field default is ModelAConfig(lr=0.0003, num_blocks=2)


Additional context

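The serialization module, on the other hand, only logs a warning instead of raising a RuntimeError, and silently falls back to the base ModelConfig, dropping the subgroup's fields: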

Config.loads_yaml(config.dumps_yaml())
WARNING:simple_parsing.helpers.serialization.serializable:Dropping extra args {'lr': 0.0003, 'num_blocks': 2}
Config(model=ModelConfig())
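To see why the choice gets lost, here is a minimal stdlib-only sketch. It uses dataclasses.asdict as a stand-in for Serializable.to_dict, and naive_from_dict is a hypothetical helper that, like the loader in the warning above, rebuilds each nested dataclass from the field's annotation alone. Neither is simple_parsing's actual code; the point is only that the dict carries no trace of ModelAConfig.

```python
from dataclasses import asdict, dataclass, field, fields


@dataclass
class ModelConfig:
    pass


@dataclass
class ModelAConfig(ModelConfig):
    lr: float = 3e-4
    num_blocks: int = 2


@dataclass
class Config:
    # Stand-in for the subgroups(...) field: the annotation is the base class.
    model: ModelConfig = field(default_factory=ModelAConfig)


def naive_from_dict(cls, data):
    """Hypothetical helper: rebuild a dataclass using only the annotated field types."""
    kwargs = {}
    for fld in fields(cls):
        if fld.name not in data:
            continue
        value = data[fld.name]
        if isinstance(value, dict):
            # Only the annotation (ModelConfig) is known here, so the subclass is lost
            # and any keys it does not declare are dropped.
            known = {sub.name for sub in fields(fld.type)}
            value = fld.type(**{k: v for k, v in value.items() if k in known})
        kwargs[fld.name] = value
    return cls(**kwargs)


d = asdict(Config())
print(d)  # {'model': {'lr': 0.0003, 'num_blocks': 2}}
print(naive_from_dict(Config, d))  # Config(model=ModelConfig())
```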
lebrice commented 1 year ago

Hey there @kyamagu , thanks for posting!

Yeah, the parse and subgroups functions are relatively new, so I expected some bugs to show up eventually. This seems to only occur when using a dataclass with subgroups. Nice catch!

I see one problem with your expected behaviour, though: how are we supposed to figure out which ModelConfig subclass to instantiate in this case, since the fields are the same and no subgroup was chosen with the --model argument?

In this case, I suppose we could fall back to using the default subgroup choice (ModelAConfig), but more generally, the contents of the config file have to match the fields of the chosen subgroup. This might become an issue when the subgroups have different fields (e.g. AdamConfig vs SGDConfig in ML scripts). For example:

# this is fine:
$ python issue.py --model=model_a --config_path=model_a.yaml
Config(model=ModelAConfig(lr=0.0003, num_blocks=2))

# this is fine-ish, because the fields are the same  
$ python issue.py --model=model_b --config_path=model_a.yaml
Config(model=ModelBConfig(lr=0.0003, num_blocks=2))

# What should we do here?
$ python issue.py --config_path=config.yaml 
??

What do you think? Do you have an opinion on this?

Thanks again for posting! :)
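One possible policy for the fallback mentioned above is: use the --model choice if one was given, otherwise the default subgroup, and require the file's keys to be fields of whichever class was picked. This is a minimal sketch of that policy; resolve_subgroup and SUBGROUPS are hypothetical names, not part of simple_parsing, and the YAML file is represented by the plain dict it would load into.

```python
from dataclasses import dataclass, fields
from typing import Dict, Optional


@dataclass
class ModelConfig:
    pass


@dataclass
class ModelAConfig(ModelConfig):
    lr: float = 3e-4
    num_blocks: int = 2


@dataclass
class ModelBConfig(ModelConfig):
    lr: float = 1e-3
    dropout: float = 0.1


# Same mapping as in the subgroups(...) call from the issue.
SUBGROUPS: Dict[str, type] = {"model_a": ModelAConfig, "model_b": ModelBConfig}


def resolve_subgroup(
    data: dict, subgroups: Dict[str, type], default: str, chosen: Optional[str] = None
) -> ModelConfig:
    """Hypothetical helper: pick the subgroup from --model, else fall back to the default."""
    key = chosen if chosen is not None else default
    cls = subgroups[key]
    allowed = {f.name for f in fields(cls)}
    extra = sorted(set(data) - allowed)
    if extra:
        # The file's contents must match the fields of the chosen subgroup.
        raise ValueError(f"{extra} are not fields of {cls.__name__} (subgroup {key!r})")
    return cls(**data)


from_file = {"lr": 0.0003, "num_blocks": 2}  # what the YAML file in the issue loads into
print(resolve_subgroup(from_file, SUBGROUPS, default="model_a"))
# ModelAConfig(lr=0.0003, num_blocks=2)
try:
    resolve_subgroup(from_file, SUBGROUPS, default="model_a", chosen="model_b")
except ValueError as err:
    print(err)  # ['num_blocks'] are not fields of ModelBConfig (subgroup 'model_b')
```

This strict version rejects a mismatched file outright; a more lenient one could drop unknown keys with a warning, as Serializable.loads_yaml does in the issue.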

kyamagu commented 1 year ago

@lebrice For this, I believe the config file needs an additional field to disambiguate the subgroup, though this might require adding a field to any dataclass that uses subgroups. In any case, the class information has to be serialized somewhere. For example:

_model: ModelAConfig
model:
  lr: 0.0003
  num_blocks: 2

@dataclass
class Config(Serializable):
    _model: str = "ModelAConfig"
    model: ModelConfig = subgroups(
        {"model_a": ModelAConfig, "model_b": ModelBConfig},
        default=ModelAConfig(),
    )
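A minimal sketch of this proposal, with plain dicts standing in for the YAML file; to_dict_with_choice, from_dict_with_choice and MODEL_SUBGROUPS are hypothetical names. Unlike the dataclass above, it keeps the _model tag out of the dataclass and writes it only into the serialized dict, so Config itself needs no extra field.

```python
from dataclasses import asdict, dataclass, field


@dataclass
class ModelConfig:
    pass


@dataclass
class ModelAConfig(ModelConfig):
    lr: float = 3e-4
    num_blocks: int = 2


@dataclass
class ModelBConfig(ModelConfig):
    lr: float = 1e-3
    dropout: float = 0.1


@dataclass
class Config:
    model: ModelConfig = field(default_factory=ModelAConfig)


# Same mapping as in the subgroups(...) call.
MODEL_SUBGROUPS = {"model_a": ModelAConfig, "model_b": ModelBConfig}


def to_dict_with_choice(config: Config) -> dict:
    """Hypothetical: store the chosen class name next to the field, as in the _model proposal."""
    return {"_model": type(config.model).__name__, "model": asdict(config.model)}


def from_dict_with_choice(data: dict) -> Config:
    """Hypothetical: use the _model tag to pick which subgroup class to rebuild."""
    by_name = {cls.__name__: cls for cls in MODEL_SUBGROUPS.values()}
    model_cls = by_name[data["_model"]]
    return Config(model=model_cls(**data["model"]))


config = Config(model=ModelBConfig(dropout=0.2))
d = to_dict_with_choice(config)
print(d)  # {'_model': 'ModelBConfig', 'model': {'lr': 0.001, 'dropout': 0.2}}
print(from_dict_with_choice(d) == config)  # True
```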
lebrice commented 1 year ago

@kyamagu I just realized that if you dump objects to a YAML file directly, without first converting them to dictionaries, PyYAML automatically saves the types of the objects along with their attributes.

Perhaps save_yaml could take a convert_to_dicts=False argument; that way, the object types would be preserved when saving and restored when loading with PyYAML. I'll check how easy this would be to do. It seems easier to me than adding an extra key to the YAML files.

I'll keep you posted, let me know what you think! :)

norabelrose commented 1 year ago

Have we decided how we want to implement this functionality? I just ran into this bug, and it's a pretty serious problem for my intended use case: I'd like to use subgroups and be able to fully deserialize a config from YAML and run it. I'm willing to contribute some code to make this happen.

lebrice commented 1 year ago

Hey there @norabelrose, is serializing the dataclasses with yaml directly not a good solution in your case?

norabelrose commented 1 year ago

Oh, sorry, I don't know; I've never actually used yaml directly before. Maybe that works. But it still seems a bit odd to have partially broken YAML serialization in this library, so shouldn't we either fix it or scrap it in favor of letting yaml do it directly?

lebrice commented 1 year ago

Hmm, yeah, good question. We could probably do something like what Hydra does and save the type as a _target_: "<module>.<type(obj).__name__>" key in the corresponding entry of the YAML file! That would be interesting.

Adding this kind of "_target_" entry to record the type of each dataclass would have to be opt-in, and when enabled, it would have to apply to all dataclass fields, not just subgroups. This is because this deserialization issue arises whenever the stored data belongs to a subclass of the field's annotated type. Adding it whenever we serialize a dataclass would therefore also fix the issue with subgroups.

The last time someone brought this up was @zhiruiluo in #211, and the solution I proposed then was to add an argument to to_dict and from_dict. Now I'm thinking we might want to re-evaluate exactly where and how to change the API to enable this.
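A minimal sketch of the opt-in "_target_" idea, assuming top-level dataclasses, real (non-string) annotations, and trusted input files. encode, decode and the add_target flag are hypothetical names for illustration; they are not simple_parsing's to_dict/from_dict and do not necessarily match what #233 implemented. As proposed above, the tag is written for every dataclass, not only for subgroup fields.

```python
import dataclasses
import importlib
from dataclasses import dataclass, field
from typing import Any


def _qualified_name(cls: type) -> str:
    return f"{cls.__module__}.{cls.__qualname__}"


def _locate(path: str) -> type:
    """Resolve "<module>.<ClassName>" back to a class (top-level classes only)."""
    module_name, _, name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), name)


def encode(obj: Any, add_target: bool = False) -> Any:
    """Turn (nested) dataclasses into dicts, optionally tagging each one with "_target_"."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        out = {"_target_": _qualified_name(type(obj))} if add_target else {}
        for f in dataclasses.fields(obj):
            out[f.name] = encode(getattr(obj, f.name), add_target)
        return out
    return obj


def decode(annotation: Any, data: Any) -> Any:
    """Rebuild dataclasses; a "_target_" key takes precedence over the annotated type."""
    if not isinstance(data, dict):
        return data
    data = dict(data)  # don't mutate the caller's dict
    target = data.pop("_target_", None)
    cls = annotation
    if target is not None:
        cls = _locate(target)
        if dataclasses.is_dataclass(annotation) and not issubclass(cls, annotation):
            raise TypeError(f"{target} is not a subclass of {annotation.__name__}")
    if not dataclasses.is_dataclass(cls):
        return data  # plain dict field: nothing to rebuild
    kwargs = {
        f.name: decode(f.type, data[f.name]) for f in dataclasses.fields(cls) if f.name in data
    }
    return cls(**kwargs)


@dataclass
class ModelConfig:
    pass


@dataclass
class ModelAConfig(ModelConfig):
    lr: float = 3e-4
    num_blocks: int = 2


@dataclass
class ModelBConfig(ModelConfig):
    lr: float = 1e-3
    dropout: float = 0.1


@dataclass
class Config:
    model: ModelConfig = field(default_factory=ModelAConfig)


config = Config(model=ModelBConfig(dropout=0.2))
tagged = encode(config, add_target=True)  # every dataclass dict now carries "_target_"
print(decode(Config, tagged) == config)  # True: ModelBConfig survives the round trip
print(decode(Config, encode(config)))  # Config(model=ModelConfig()): untagged, the subclass is lost
```

Because decode imports whichever module the file names, this should only be used with trusted config files, which is the same trade-off as loading PyYAML's python/object tags.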

lebrice commented 1 year ago

Hey @kyamagu, @norabelrose, @zhiruiluo, I made #233 to address this. Closing this for now, but please do let me know if you have any feedback. 🙂