lebrice / SimpleParsing

Simple, Elegant, Typed Argument Parsing with argparse
MIT License

Saving the chosen subgroup name somewhere on args #139

Closed lebrice closed 2 years ago

lebrice commented 2 years ago

Currently, when using subgroups, the value for the subgroup name is not saved anywhere on the args, since the subgroup overwrites the original argument in the second phase of parsing:


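from dataclasses import dataclass
from simple_parsing import ArgumentParser, subgroups

# Person, Bob and Alice are dataclasses (their fields appear in the --help output below).
@dataclass
class Config:
    person: Person = subgroups({"bob": Bob, "alice": Alice}, default=Bob)

parser = ArgumentParser()
parser.add_arguments(Config, "config")

args = parser.parse_args()
print(args)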

Recap of subgroups:

The value passed to --person selects which dataclass's fields are added to the parser, as the two --help outputs below show:

$ python test/test_issue_blab.py --help
usage: test_issue_blab.py [-h] [--person {bob,alice}] [--person.age int] [--person.cool bool]

optional arguments:
  -h, --help            show this help message and exit

Config ['config']:
  Configuration dataclass.

  --person {bob,alice}  (default: bob)

Bob ['config.person']:
  Person named Bob.

  --person.age int      (default: 32)
  --person.cool bool    (default: True)
$ python issue.py --person alice --help
usage: test_issue_blab.py [-h] [--person {bob,alice}] [--person.age int] [--person.popular bool]

optional arguments:
  -h, --help            show this help message and exit

Config ['config']:
  Configuration dataclass.

  --person {bob,alice}  (default: bob)

Alice ['config.person']:
  Person named Alice.

  --person.age int      (default: 13)
  --person.popular bool
                        (default: True)

The problem

$ python issue.py --person alice
Namespace(config=Config(person=Alice(age=13, popular=True)))

It would be useful to be able to extract the name of the chosen subgroup. For instance, when using a subgroup to choose between the HParams classes of different models, we want to recover which model was chosen:

@dataclass
class Config:
     model: Model.HParams = subgroups({
         "simple": SimpleModel.HParams,
         "advanced": AdvancedModel.HParams,
         }, default=SimpleModel.HParams)

parser = ArgumentParser()
parser.add_arguments(Config, "config")
args = parser.parse_args()
config: Config = args.config
model_hparams: Model.HParams = config.model

# ehhh, what kind of Model was chosen, exactly?
# --> Currently no way of knowing, except by doing something sneaky
#       like this:
# Extract type of model from type of HParams:
model_type_str = type(model_hparams).__qualname__.rpartition(".")[0]
# Find matching type
model_type = [
    model_type for model_type in Model.__subclasses__()
    if model_type.__qualname__ == model_type_str
][0]

This however doesn't work, since different models might share the same HParams class! It's not a viable solution.
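
To make the failure concrete, here is a minimal, self-contained sketch of the shared-HParams case. The class bodies, the `hidden_size` field and the `guess_model_type` helper are made up for illustration (they are not from the real code base); only the lookup logic is taken from the snippet above.

```python
from dataclasses import dataclass


class Model:
    @dataclass
    class HParams:
        hidden_size: int = 32  # hypothetical field, for illustration only


class SimpleModel(Model):
    # Defines its own nested class, so its qualname is "SimpleModel.HParams".
    @dataclass
    class HParams(Model.HParams):
        hidden_size: int = 16


class AdvancedModel(Model):
    # Assumption for this sketch: AdvancedModel reuses SimpleModel's HParams.
    HParams = SimpleModel.HParams


def guess_model_type(hparams: Model.HParams) -> type:
    """Hypothetical helper wrapping the qualname-based lookup shown above."""
    model_type_str = type(hparams).__qualname__.rpartition(".")[0]
    return [
        model_type for model_type in Model.__subclasses__()
        if model_type.__qualname__ == model_type_str
    ][0]


# Pretend the user selected the "advanced" subgroup on the command line.
chosen_hparams = AdvancedModel.HParams()
guessed = guess_model_type(chosen_hparams)
print(guessed.__name__)  # prints "SimpleModel", not "AdvancedModel"
```

Because both models point at the same class object, the qualname only ever names the class where HParams was first defined, so the information about which key was selected is lost.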

I'm thinking that having something like a subgroups dictionary in the args itself could be useful. Something like this:

$ python issue.py --person alice
Namespace(config=Config(person=Alice(age=13, popular=True)), subgroups={'config.person': 'alice'})
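
As a rough sketch of the idea, the two-phase parse could record the chosen key at the point where it is still known. The example below uses only the standard library's argparse, not SimpleParsing, and is not how the library is implemented; `parse`, `parse_bool`, `PERSON_CHOICES` and the `person_<field>` destinations are hypothetical names chosen for this illustration.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class Bob:
    age: int = 32
    cool: bool = True


@dataclass
class Alice:
    age: int = 13
    popular: bool = True


@dataclass
class Config:
    person: object


PERSON_CHOICES = {"bob": Bob, "alice": Alice}


def parse_bool(text: str) -> bool:
    return text.lower() in ("1", "true", "yes")


def parse(argv):
    # Phase 1: only find out which subgroup was chosen; ignore everything else.
    first = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    first.add_argument("--person", choices=list(PERSON_CHOICES), default="bob")
    known, _ = first.parse_known_args(argv)
    chosen_cls = PERSON_CHOICES[known.person]

    # Phase 2: add the fields of the chosen dataclass, then parse everything.
    second = argparse.ArgumentParser(parents=[first], allow_abbrev=False)
    for f in fields(chosen_cls):
        convert = parse_bool if f.type is bool else f.type
        second.add_argument(
            f"--person.{f.name}", dest=f"person_{f.name}",
            type=convert, default=f.default,
        )
    ns = second.parse_args(argv)

    person = chosen_cls(
        **{f.name: getattr(ns, f"person_{f.name}") for f in fields(chosen_cls)}
    )
    # The chosen key is stored next to the config instead of being discarded.
    return argparse.Namespace(
        config=Config(person=person),
        subgroups={"config.person": known.person},
    )


args = parse(["--person", "alice"])
print(args.subgroups["config.person"])
```

With something like this, the model-selection case would reduce to a dictionary lookup on `args.subgroups` rather than inspecting the type of the parsed dataclass.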