lebrice / SimpleParsing

Simple, Elegant, Typed Argument Parsing with argparse
MIT License

Add support for a config file with optional overriding #45

Closed yuvval closed 2 years ago

yuvval commented 3 years ago

Is your feature request related to a problem? Please describe.
I would like to be able to load the args configuration from a JSON/YAML file and, at the same time, override some values from the command line.

Describe the solution you'd like
E.g. python my_script.py --config=conf.yaml --lr=1e-2

Describe alternatives you've considered
Parsing the command-line argv for the overriding args and setting them manually. However, this becomes difficult when dataclasses are nested.

lebrice commented 3 years ago

Hi @yuvval , sorry for not getting back to you sooner!

Hmm, that's interesting. You can already achieve what you want without a new feature. For instance, you could do something like this (which, I admit, is a bit long, but gets the job done!):

from dataclasses import dataclass, fields, is_dataclass, asdict
from simple_parsing import ArgumentParser, Serializable
from typing import Dict, Union, Any
from simple_parsing.utils import dict_union

@dataclass
class HParams(Serializable):
    lr: float = 0.01
    foo: str = "hello"

    def replace(self, **new_params):
        new_hp_dict = dict_union(asdict(self), new_params, recurse=True)
        new_hp = type(self).from_dict(new_hp_dict)
        return new_hp

def differering_values(target: HParams, reference: HParams) -> Dict[str, Union[Any, Dict]]:
    """ Given a 'target' dataclass and a 'reference' dataclass, returns a possibly nested dict
    of all the values that are different in `target` compared to `reference`.
    """
    non_default_values = {}
    for field in fields(target):
        name = field.name
        target_value = getattr(target, name)
        reference_value = getattr(reference, name)
        if target_value == reference_value:
            continue
        if is_dataclass(target_value) and is_dataclass(reference_value):
            # Recurse in the case of unequal dataclasses.
            non_default_values[name] = differering_values(target_value, reference_value)
        else:
            non_default_values[name] = target_value
    return non_default_values

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_arguments(HParams, "hparams")
    parser.add_argument("--config_path", default="")
    args = parser.parse_args()

    config_path: str = args.config_path
    hparams: HParams = args.hparams

    default_hparams = HParams()

    if config_path:
        # IDEA: Create a new HParams object whose values are those from the config,
        # overridden by any hparam value parsed from the command line that differs from the default.
        config_hparams = HParams.load(config_path)
        new_kwargs = differering_values(hparams, default_hparams)
        hparams = config_hparams.replace(**new_kwargs)

    print(f"Hparams: {hparams}")

Then, assuming the file.yaml file contains this:

foo: bar
lr: 0.05

You can then get what you want:

$ python test/test_issue45.py
Hparams: HParams(lr=0.01, foo='hello')
$ python test/test_issue45.py --config_path file.yaml
Hparams: HParams(lr=0.05, foo='bar')
$ python test/test_issue45.py --config_path file.yaml --foo bob
Hparams: HParams(lr=0.05, foo='bob')
$ python test/test_issue45.py --config_path file.yaml --lr 4.123
Hparams: HParams(lr=4.123, foo='bar')

Hope this helps! I'll close the issue for now, let me know if you have any other questions.
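
For nested dataclasses, the `replace` step above depends on merging nested dicts recursively rather than overwriting whole sub-dicts. Below is a minimal, standard-library-only sketch of that idea. `merge_nested` is a hypothetical helper written for illustration, not part of simple_parsing, and it only assumes that this is roughly what a recursive dict union does.

```python
from typing import Any, Dict


def merge_nested(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `base` with `overrides` applied, recursing into nested dicts."""
    merged = dict(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge them instead of replacing the whole sub-dict.
            merged[key] = merge_nested(merged[key], value)
        else:
            merged[key] = value
    return merged


# Values loaded from a config file (hypothetical nested layout).
config_values = {"lr": 0.05, "optimizer": {"name": "adam", "momentum": 0.9}}
# Only the values that were overridden on the command line.
cli_overrides = {"optimizer": {"momentum": 0.5}}

merged = merge_nested(config_values, cli_overrides)
print(merged)
```

Only `optimizer.momentum` changes here. The sibling key `optimizer.name` and the top-level `lr` keep their config-file values, and the input dicts are left unmodified.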

stevebyan commented 2 years ago

Here's how I do it. Perhaps this would make a good example.

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from simple_parsing import ArgumentParser
from simple_parsing.helpers import Serializable

@dataclass
class CfgFileConfig:
    """Config file args"""
    load_config: Optional[Path] = None          # load config file
    save_config: Optional[Path] = None          # save config to specified file

@dataclass
class MyPreferences(Serializable):
    light_switch: bool = True # turn on light if true

def main(args=None):
    # First pass: parse only the config-file options and ignore everything else.
    cfgfile_parser = ArgumentParser(add_help=False)
    cfgfile_parser.add_arguments(CfgFileConfig, dest="cfgfile")
    cfgfile_args, _ = cfgfile_parser.parse_known_args(args)

    cfgfile: CfgFileConfig = cfgfile_args.cfgfile

    file_config: Optional[MyPreferences] = None
    if cfgfile.load_config is not None:
        file_config = MyPreferences.load(cfgfile.load_config)

    parser = ArgumentParser()

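    # Add the cfgfile args again so they appear in the help message.
    parser.add_arguments(CfgFileConfig, dest="cfgfile")
    # Second pass: values loaded from the file become the defaults,
    # so flags given on the command line override them.
    parser.add_arguments(MyPreferences, dest="my_preferences", default=file_config)

    parsed_args = parser.parse_args(args)

    prefs: MyPreferences = parsed_args.my_preferences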
    print(prefs)

    if cfgfile.save_config is not None:
        prefs.save(cfgfile.save_config, indent=4)

if __name__ == '__main__':
    main()

psirenny commented 2 years ago

A downside to the diff approach is that it fails when a value equal to the default is passed on the command line. The user will expect the command-line argument to supersede the config file's value, but it won't, because it happens to equal the dataclass's default value and so is not detected as an override. I think the best workaround is to parse twice.
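
To make the "parse twice" idea concrete, here is a minimal sketch that uses only the standard library's argparse rather than simple_parsing. It assumes a flat `HParams` dataclass and a JSON config file (JSON instead of YAML only to avoid extra dependencies). `parse_hparams` is a hypothetical helper name.

```python
import argparse
import json
import tempfile
from dataclasses import asdict, dataclass, fields
from typing import List, Optional


@dataclass
class HParams:
    lr: float = 0.01
    foo: str = "hello"


def parse_hparams(argv: Optional[List[str]] = None) -> HParams:
    # Pass 1: only look for --config_path and ignore every other argument.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--config_path", default="")
    pre_args, _ = pre_parser.parse_known_args(argv)

    # Start from the dataclass defaults, then layer the config file on top.
    defaults = asdict(HParams())
    if pre_args.config_path:
        with open(pre_args.config_path) as config_file:
            defaults.update(json.load(config_file))

    # Pass 2: the config values are now the parser defaults, so any flag given
    # on the command line wins, even when it equals the dataclass default.
    parser = argparse.ArgumentParser(parents=[pre_parser])
    for field in fields(HParams):
        parser.add_argument(f"--{field.name}", type=field.type, default=defaults[field.name])
    args = parser.parse_args(argv)
    return HParams(**{field.name: getattr(args, field.name) for field in fields(HParams)})


# Demo: the config file sets lr=0.05, and the command line asks for the default 0.01.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
    json.dump({"lr": 0.05, "foo": "bar"}, tmp)

result = parse_hparams(["--config_path", tmp.name, "--lr", "0.01"])
print(result)
```

In the demo, `--lr 0.01` equals the dataclass default, yet it still overrides the config file's `0.05`. That is the case the diff-based approach cannot detect.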

andrey-klochkov-liftoff commented 2 years ago

It'd be really nice to make this a built-in option, as in e.g. Pyrallis or Clout.

lebrice commented 2 years ago

Sure thing, this makes sense. I'll take a look.

Yevgnen commented 2 years ago

Any news?

lebrice commented 2 years ago

No news atm, I've been busy with other stuff. I'll push that onto my stack of TODOs, hopefully I'll have something to show for it in a week or two!

lebrice commented 2 years ago

Ok I've added some better support for this in #158. Let me know what you think! :)