lebrice / SimpleParsing

Simple, Elegant, Typed Argument Parsing with argparse
MIT License

'choices' does not seem to be checked in load() method #103

Status: Open. aoikaneko opened this issue 2 years ago.

aoikaneko commented 2 years ago

Is your feature request related to a problem? Please describe.
'choices' does not seem to be checked by the load() method, so an invalid value can be set through it.

To reproduce:

config.yml

animal: bird

argtest.py

import dataclasses
import simple_parsing as sp
from simple_parsing.helpers import Serializable

@dataclasses.dataclass
class Config(Serializable):
    animal: str = sp.field(default='dog', choices=['dog', 'cat'])

if __name__ == '__main__':
    config = Config.load('config.yml')
    print(config)
    # Output: Config(animal='bird')

Describe the solution you'd like
Check 'choices' in the load() method and show an error message such as: invalid choice: 'bird' (choose from 'dog', 'cat')
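The requested message mirrors argparse's own "invalid choice" wording, which is built from the repr of the rejected value and the repr of each allowed choice. A minimal sketch of producing the same text; `invalid_choice_message` is a hypothetical helper, not part of simple_parsing:

```python
def invalid_choice_message(value, choices):
    """Hypothetical helper: format an argparse-style 'invalid choice' message."""
    # Join the repr of each allowed choice with ", " (e.g. "'dog', 'cat'").
    allowed = ", ".join(repr(choice) for choice in choices)
    return f"invalid choice: {value!r} (choose from {allowed})"


print(invalid_choice_message("bird", ["dog", "cat"]))
# invalid choice: 'bird' (choose from 'dog', 'cat')
```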

Describe alternatives you've considered
Currently, I check 'choices' manually in __post_init__:

def __post_init__(self):
    for field in dataclasses.fields(self):
        if sp.utils.is_choice(field):
            choices = field.metadata.get('custom_args', {}).get('choices', {})

            # Check if the value is in choices
            value = getattr(self, field.name)
            if value not in choices:
                raise ValueError(f'{field.name}: invalid choice: {value} (choose from {choices})')
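The same check can be packaged as a reusable mixin so that any dataclass inheriting from it validates its choice fields on construction. This is a minimal sketch: `ChoicesCheckMixin` is a hypothetical name, and it assumes the choices live under field.metadata['custom_args']['choices'], the layout read by the __post_init__ above. A plain dataclasses.field stands in for sp.field so the example runs without simple_parsing installed.

```python
import dataclasses


class ChoicesCheckMixin:
    """Hypothetical mixin: validate 'choices' metadata after __init__ runs.

    Assumes choices are stored under field.metadata['custom_args']['choices'].
    """

    def __post_init__(self):
        for f in dataclasses.fields(self):
            choices = f.metadata.get("custom_args", {}).get("choices")
            if not choices:
                continue  # this field has no 'choices' restriction
            value = getattr(self, f.name)
            if value not in choices:
                allowed = ", ".join(repr(c) for c in choices)
                raise ValueError(
                    f"{f.name}: invalid choice: {value!r} (choose from {allowed})"
                )


@dataclasses.dataclass
class Config(ChoicesCheckMixin):
    # Stand-in for sp.field(default='dog', choices=['dog', 'cat']) (assumed layout).
    animal: str = dataclasses.field(
        default="dog", metadata={"custom_args": {"choices": ["dog", "cat"]}}
    )


print(Config(animal="cat"))  # Config(animal='cat')
try:
    Config(animal="bird")
except ValueError as err:
    print(err)  # animal: invalid choice: 'bird' (choose from 'dog', 'cat')
```

The dataclass-generated __init__ calls an inherited __post_init__, so subclasses get the check without writing their own loop.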

Btw, thanks for the great library!

lebrice commented 2 years ago

Hey there @aoikaneko, thanks for posting!

Interesting proposition.

The decoding function that load() applies to each field is currently chosen based on the field's type annotation. If you want the behaviour you describe, you can already get it by using an Enum field.
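In plain Python, an Enum restricts a value to a fixed set because looking up a member by value fails for anything else. A minimal sketch, independent of how simple_parsing itself decodes fields:

```python
import enum


class Animal(enum.Enum):
    DOG = "dog"
    CAT = "cat"


# Calling the Enum class with a value returns the matching member.
print(Animal("cat"))  # Animal.CAT

# The allowed values can be listed from the members.
print([member.value for member in Animal])  # ['dog', 'cat']

# An unknown value raises ValueError instead of being accepted silently.
try:
    Animal("bird")
except ValueError as err:
    print(err)  # 'bird' is not a valid Animal
```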

Also, if we fix #2 (which may or may not already be supported), then the decode function for a Literal["cat", "dog"] should ideally behave as you described above.
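A decoder for a Literal["cat", "dog"] annotation could read the allowed values straight from the annotation with typing.get_args and reject anything else. This is a sketch only; `decode_literal` is a hypothetical function, not simple_parsing's decoding API:

```python
from typing import Literal, get_args


def decode_literal(annotation, value):
    """Hypothetical decoder: accept `value` only if it is one of the Literal's values."""
    allowed = get_args(annotation)  # e.g. ('cat', 'dog')
    if value not in allowed:
        choices = ", ".join(repr(a) for a in allowed)
        raise ValueError(f"invalid choice: {value!r} (choose from {choices})")
    return value


AnimalName = Literal["cat", "dog"]
print(decode_literal(AnimalName, "dog"))  # dog
try:
    decode_literal(AnimalName, "bird")
except ValueError as err:
    print(err)  # invalid choice: 'bird' (choose from 'cat', 'dog')
```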

This is an interesting idea. I'll add some tags to this in case someone wants to jump in and help :)

aoikaneko commented 2 years ago

Hi @lebrice, thank you for your suggestion!

I tried an Enum, but got the following output:

.../simple_parsing/helpers/serialization/decoding.py:158: UserWarning: Unable to find a decoding function for type <enum 'Animal'>. Will try to use the type as a constructor.
  f"Unable to find a decoding function for type {t}. "
Config(animal='bird')

config.yml

animal: bird

argtest2.py

import dataclasses
import enum

import simple_parsing as sp
from simple_parsing.helpers import Serializable

class Animal(enum.Enum):
    DOG = 'dog'
    CAT = 'cat'

@dataclasses.dataclass
class Config(Serializable):
    animal: Animal = Animal.DOG

if __name__ == '__main__':
    config = Config.load('config.yml')
    print(config)

I guess I'm doing something wrong. Could you give me some advice?
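Judging from the output above, the loaded value stays the raw string 'bird' rather than becoming an Animal member. One possible workaround (a sketch, not the library's recommended approach) is to convert the field explicitly in __post_init__, so that an invalid value fails loudly. It is shown with a plain dataclass, with the Serializable base omitted, so it runs without simple_parsing:

```python
import dataclasses
import enum


class Animal(enum.Enum):
    DOG = "dog"
    CAT = "cat"


@dataclasses.dataclass
class Config:  # Serializable base omitted so the sketch is self-contained
    animal: Animal = Animal.DOG

    def __post_init__(self):
        # Animal(...) accepts an existing member or one of its values ("dog"/"cat")
        # and raises ValueError for anything else, such as "bird".
        self.animal = Animal(self.animal)


print(Config(animal="cat"))  # Config(animal=<Animal.CAT: 'cat'>)
```

Because Animal(Animal.DOG) simply returns the member, the default value passes through unchanged.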