yuvval closed this issue 2 years ago.
Hi @yuvval , sorry for not getting back to you sooner!
Hmm, that's interesting. You can already achieve what you want without a new feature. For instance, you could do something like this (which, I admit, is a bit long, but gets the job done!):
```python
from dataclasses import dataclass, fields, is_dataclass, asdict
from simple_parsing import ArgumentParser, Serializable
from typing import Dict, Union, Any
from simple_parsing.utils import dict_union


@dataclass
class HParams(Serializable):
    lr: float = 0.01
    foo: str = "hello"

    def replace(self, **new_params):
        # Merge the new values into a dict of the current ones, then rebuild the dataclass.
        new_hp_dict = dict_union(asdict(self), new_params, recurse=True)
        new_hp = type(self).from_dict(new_hp_dict)
        return new_hp


def differing_values(target: HParams, reference: HParams) -> Dict[str, Union[Any, Dict]]:
    """Given a dataclass and a 'reference' dataclass, returns a possibly nested dict
    of all the values that are different in `target` compared to in `reference`.
    """
    non_default_values = {}
    for field in fields(target):
        name = field.name
        target_value = getattr(target, name)
        reference_value = getattr(reference, name)
        if target_value == reference_value:
            continue
        if is_dataclass(target_value) and is_dataclass(reference_value):
            # Recurse in the case of unequal dataclasses.
            non_default_values[name] = differing_values(target_value, reference_value)
        else:
            non_default_values[name] = target_value
    return non_default_values


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_arguments(HParams, "hparams")
    parser.add_argument("--config_path", default="")
    args = parser.parse_args()

    config_path: str = args.config_path
    hparams: HParams = args.hparams
    default_hparams = HParams()

    if config_path:
        # IDEA: Create a new HParams object whose values are those from the config,
        # plus any hparam value parsed from the command line that differs from the default.
        config_hparams = HParams.load(config_path)
        new_kwargs = differing_values(hparams, default_hparams)
        hparams = config_hparams.replace(**new_kwargs)

    print(f"Hparams: {hparams}")
```
Then, assuming the `file.yaml` file contains this:

```yaml
foo: bar
lr: 0.05
```

You can then get what you want:

```console
$ python test/test_issue45.py
Hparams: HParams(lr=0.01, foo='hello')
$ python test/test_issue45.py --config_path file.yaml
Hparams: HParams(lr=0.05, foo='bar')
$ python test/test_issue45.py --config_path file.yaml --foo bob
Hparams: HParams(lr=0.05, foo='bob')
$ python test/test_issue45.py --config_path file.yaml --lr 4.123
Hparams: HParams(lr=4.123, foo='bar')
```
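Since the original request mentions nested dataclasses, here is a stdlib-only sketch (no simple_parsing) of the same diff-then-merge idea with one level of nesting. The `Optim` and `Config` classes and the `merge` helper are hypothetical; `merge` uses `dataclasses.replace` recursively as a stand-in for the `dict_union` + `from_dict` round trip in the helper above, so treat it as an illustration rather than the library's own behavior.

```python
from dataclasses import dataclass, field, fields, is_dataclass, replace
from typing import Any, Dict


@dataclass
class Optim:  # hypothetical nested config
    lr: float = 0.01
    momentum: float = 0.9


@dataclass
class Config:  # hypothetical top-level config
    foo: str = "hello"
    optim: Optim = field(default_factory=Optim)


def differing_values(target: Any, reference: Any) -> Dict[str, Any]:
    """Return a (possibly nested) dict of the fields of `target` that differ from `reference`."""
    diff: Dict[str, Any] = {}
    for f in fields(target):
        t, r = getattr(target, f.name), getattr(reference, f.name)
        if t == r:
            continue
        if is_dataclass(t) and is_dataclass(r):
            diff[f.name] = differing_values(t, r)  # recurse into nested dataclasses
        else:
            diff[f.name] = t
    return diff


def merge(base: Any, overrides: Dict[str, Any]) -> Any:
    """Return a copy of `base` with `overrides` applied, recursing into nested dataclasses."""
    changes: Dict[str, Any] = {}
    for name, value in overrides.items():
        current = getattr(base, name)
        if is_dataclass(current) and isinstance(value, dict):
            changes[name] = merge(current, value)
        else:
            changes[name] = value
    return replace(base, **changes)


# Values as if loaded from a config file...
from_file = Config(foo="bar", optim=Optim(lr=0.05))
# ...and as if parsed from a command line that only overrode the momentum.
from_cli = Config(optim=Optim(momentum=0.5))

overrides = differing_values(from_cli, Config())
print(overrides)                    # {'optim': {'momentum': 0.5}}
print(merge(from_file, overrides))  # Config(foo='bar', optim=Optim(lr=0.05, momentum=0.5))
```

Only the nested field that actually changed is carried over, so `lr=0.05` from the file survives. Like the helper above, this cannot tell an omitted option apart from one passed explicitly with its default value.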
Hope this helps! I'll close the issue for now, let me know if you have any other questions.
Here's how I do it. Perhaps this would make a good example.
```python
from typing import List, Optional
from pathlib import Path
from dataclasses import dataclass
from simple_parsing import ArgumentParser
from simple_parsing.helpers import Serializable


@dataclass
class CfgFileConfig:
    """Config file args"""
    load_config: Optional[Path] = None  # load config file
    save_config: Optional[Path] = None  # save config to specified file


@dataclass
class MyPreferences(Serializable):
    light_switch: bool = True  # turn on light if true


def main(args: Optional[List[str]] = None):
    # First pass: only pick up the config-file options; ignore everything else for now.
    cfgfile_parser = ArgumentParser(add_help=False)
    cfgfile_parser.add_arguments(CfgFileConfig, dest="cfgfile")
    cfgfile_args, _ = cfgfile_parser.parse_known_args(args)
    cfgfile: CfgFileConfig = cfgfile_args.cfgfile

    file_config: Optional[MyPreferences] = None
    if cfgfile.load_config is not None:
        file_config = MyPreferences.load(cfgfile.load_config)

    # Second pass: the values loaded from the file become the defaults,
    # so anything given on the command line overrides them.
    parser = ArgumentParser()
    # add cfgfile args so they appear in the help message
    parser.add_arguments(CfgFileConfig, dest="cfgfile")
    parser.add_arguments(MyPreferences, dest="my_preferences", default=file_config)
    parsed = parser.parse_args(args)

    prefs: MyPreferences = parsed.my_preferences
    print(prefs)
    if cfgfile.save_config is not None:
        prefs.save(cfgfile.save_config, indent=4)


if __name__ == '__main__':
    main()
```
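For comparison, the same two-pass pattern can be sketched with only the standard library's `argparse`: a first parser (`add_help=False`) picks up just the config path via `parse_known_args`, the file's values are installed as defaults with `set_defaults`, and the full parser then parses all of argv. JSON stands in for YAML to stay within the stdlib, and the `--config`, `--lr` and `--foo` options are illustrative assumptions, not part of simple_parsing.

```python
import argparse
import json
import os
import tempfile
from typing import List, Optional


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # Pass 1: only look for --config and ignore every other argument.
    pre = argparse.ArgumentParser(add_help=False)
    pre.add_argument("--config", default=None)
    known, _ = pre.parse_known_args(argv)

    # Full parser; inheriting from `pre` keeps --config in the --help output.
    parser = argparse.ArgumentParser(parents=[pre])
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--foo", default="hello")

    if known.config:
        with open(known.config) as fh:
            # Values from the file replace the hard-coded defaults...
            parser.set_defaults(**json.load(fh))

    # Pass 2: ...and anything given explicitly on the command line wins.
    return parser.parse_args(argv)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "conf.json")
        with open(path, "w") as fh:
            json.dump({"foo": "bar", "lr": 0.05}, fh)

        args = parse(["--config", path])
        print(args.lr, args.foo)  # 0.05 bar
        args = parse(["--config", path, "--lr", "0.01"])
        print(args.lr, args.foo)  # 0.01 bar
```

Because the command line is parsed against the file's values as defaults, an explicit `--lr 0.01` still overrides the file even though it equals the hard-coded default.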
A downside to the diff approach is that it fails when a value equal to the default is passed explicitly on the command line. The user expects that argument to supersede the config file's value, but it won't: because it happens to equal the dataclass's default value, it is dropped from the diff. I think the best workaround is to parse twice.
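To make the pitfall concrete: with the diff approach, `--lr 0.01` and no `--lr` at all produce the same parsed value, so both diff to nothing. One stdlib way to see which options were actually typed, a different technique from both approaches in this thread, is `argparse`'s `argument_default=argparse.SUPPRESS`: options absent from argv are then left out of the namespace entirely. The option names, `DEFAULTS` and the `config` dict below are illustrative assumptions.

```python
import argparse
from typing import Any, Dict, List

DEFAULTS: Dict[str, Any] = {"lr": 0.01, "foo": "hello"}


def diff_from_defaults(values: Dict[str, Any]) -> Dict[str, Any]:
    # The diff approach: keep only values that differ from the defaults.
    return {k: v for k, v in values.items() if v != DEFAULTS[k]}


def explicit_args(argv: List[str]) -> Dict[str, Any]:
    """Return only the options that actually appear in argv."""
    # With SUPPRESS as the default, options that were not given get no attribute.
    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--foo")
    return vars(parser.parse_args(argv))


def resolve(config: Dict[str, Any], argv: List[str]) -> Dict[str, Any]:
    # Precedence: hard-coded defaults < config file < explicit command-line options.
    return {**DEFAULTS, **config, **explicit_args(argv)}


config = {"lr": 0.05, "foo": "bar"}  # as if loaded from file.yaml

# Diff approach: a normal parser turns "--lr 0.01" into the default value...
normal = argparse.ArgumentParser()
normal.add_argument("--lr", type=float, default=DEFAULTS["lr"])
normal.add_argument("--foo", default=DEFAULTS["foo"])
parsed = vars(normal.parse_args(["--lr", "0.01"]))
# ...so the diff is empty and the explicit override is lost.
print({**config, **diff_from_defaults(parsed)})  # {'lr': 0.05, 'foo': 'bar'}

# SUPPRESS approach: the explicit option is kept.
print(resolve(config, ["--lr", "0.01"]))  # {'lr': 0.01, 'foo': 'bar'}
```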
It'd be really nice to make this a built-in option, as in e.g. Pyrallis or Clout.
Sure thing, this makes sense. I'll take a look.
Any news?
No news atm, I've been busy with other stuff. I'll push that onto my stack of TODOs, hopefully I'll have something to show for it in a week or two!
Ok I've added some better support for this in #158. Let me know what you think! :)
Is your feature request related to a problem? Please describe.
I would like to be able to load the args configuration from a JSON/YAML file and, at the same time, override some values from the command line.

Describe the solution you'd like
e.g. `python my_script.py --config=conf.yaml --lr=1e-2`

Describe alternatives you've considered
Parsing the command-line argv for the overriding args and setting them manually. However, this makes things a lot more difficult when dataclasses are nested.