lebrice / SimpleParsing

Simple, Elegant, Typed Argument Parsing with argparse
MIT License

HF Trainer TrainingArguments can't be used with `default_factory` #275

Open levmckinney opened 11 months ago

levmckinney commented 11 months ago

Describe the bug

I've really been loving using simple-parsing in my projects. It looks like you are trying to maintain compatibility with Hugging Face's dataclasses (#172). One use case I've been trying to get working is exposing the TrainingArguments dataclass on the command line with simple-parsing, so that I don't have to manually pass all the different configuration options through. This was working great until I tried to add default arguments, at which point I started running into errors of the form:

ValueError: IntervalStrategy.STEPS is not a valid IntervalStrategy, please select one of ['no', 'steps', 'epoch']

I believe this is because at some point simple-parsing converts IntervalStrategy.STEPS into the string literal

'IntervalStrategy.STEPS'
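The symptom can be reproduced without simple-parsing at all. The following is a minimal sketch, assuming the enum member is passed through str() somewhere along the way; where exactly that happens inside simple-parsing is not confirmed here. Color is a stand-in for IntervalStrategy.

```python
import enum


class Color(str, enum.Enum):
    """Stand-in for IntervalStrategy: an Enum that also subclasses str."""

    RED = "red"
    BLUE = "blue"


member = Color("red")  # lookup by value works
as_text = str(member)  # Enum.__str__ gives "Color.RED", not the value "red"

# Passing an existing member back to the constructor is fine.
round_trip = Color(member)

# Passing the str() form fails, because "Color.RED" is not one of the values.
try:
    Color(as_text)
    lookup_error = None
except ValueError as exc:
    lookup_error = str(exc)

print(as_text)
print(round_trip is member)
print(lookup_error)
```

So a member that reaches the constructor unchanged is harmless; the failure only appears once the member has been turned into its str() form.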

To Reproduce

# issue.py
from dataclasses import dataclass

from transformers import TrainingArguments
from simple_parsing import field, parse

@dataclass
class HParams:
    """You can use Enums"""

    sub_component: TrainingArguments = field(
        default_factory=lambda: TrainingArguments(evaluation_strategy="steps")
    )

if __name__ == "__main__":
    my_preferences: HParams = parse(HParams)
    print(my_preferences)

So you don't have to dig through Hugging Face's code, here is a minimal reproduction of what's happening.

See https://github.com/huggingface/transformers/pull/17933 for why the enum inherits from str.

# simplified.py
# ======================= Their Code =======================
from typing import Union
import enum
from dataclasses import dataclass

from simple_parsing import parse, field

class Color(str, enum.Enum):
    RED = "red"
    ORANGE = "orange"
    BLUE = "blue"

@dataclass
class SubComponent:
    color: Union[str, Color] = Color.BLUE

    def __post_init__(self):
        self.color = Color(self.color)

# ======================= My Code =======================
@dataclass
class HParams:
    """You can use Enums"""

    sub_component: SubComponent = field(
        default_factory=lambda: SubComponent(color="red")
    )

if __name__ == "__main__":
    hparams: HParams = parse(HParams)
    print(hparams)

Expected behavior

$ python issue.py
HParams(sub_component=TrainingArguments(...))
$ python simplified.py
HParams(sub_component=SubComponent(color=<Color.RED: 'red'>))

Actual behavior

I get errors of the form:

$ python issue.py
Traceback (most recent call last):
  File "/home/lev/Projects/robust-llm/test_enum_parsing.py", line 15, in <module>
    hparams: HParams = parse(HParams)
                       ^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 1021, in parse
    parsed_args = parser.parse_args(args)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/argparse.py", line 1869, in parse_args
    args, argv = self.parse_known_args(args, namespace)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 349, in parse_known_args
    parsed_args = self._postprocessing(parsed_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 581, in _postprocessing
    parsed_args = self._instantiate_dataclasses(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 849, in _instantiate_dataclasses
    value_for_dataclass_field = _create_dataclass_instance(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 1137, in _create_dataclass_instance
    return constructor(**constructor_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 111, in __init__
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/transformers/training_args.py", line 1199, in __post_init__
    self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/enum.py", line 714, in __call__
    return cls.__new__(cls, value)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/enum.py", line 1138, in __new__
    raise exc
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/enum.py", line 1115, in __new__
    result = cls._missing_(value)
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/transformers/utils/generic.py", line 348, in _missing_
    raise ValueError(
ValueError: IntervalStrategy.STEPS is not a valid IntervalStrategy, please select one of ['no', 'steps', 'epoch']

Here is the simplified example that replicates the basic issue without the HF stuff.

$ python simplified.py
Traceback (most recent call last):
  File "/home/lev/Projects/robust-llm/test_enum_parsing.py", line 30, in <module>
    hparams: HParams = parse(HParams)
                       ^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 1021, in parse
    parsed_args = parser.parse_args(args)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/argparse.py", line 1869, in parse_args
    args, argv = self.parse_known_args(args, namespace)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 349, in parse_known_args
    parsed_args = self._postprocessing(parsed_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 581, in _postprocessing
    parsed_args = self._instantiate_dataclasses(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 849, in _instantiate_dataclasses
    value_for_dataclass_field = _create_dataclass_instance(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 1137, in _create_dataclass_instance
    return constructor(**constructor_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 4, in __init__
  File "/home/lev/Projects/robust-llm/test_enum_parsing.py", line 18, in __post_init__
    self.color = Color(self.color)
                 ^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/enum.py", line 714, in __call__
    return cls.__new__(cls, value)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/enum.py", line 1130, in __new__
    raise ve_exc
ValueError: 'Color.RED' is not a valid Color

Additional context

My current understanding is that the problem is caused by the Enum class (Color or IntervalStrategy) inheriting from str. This seems to be a hack on Hugging Face's side to help with serialization; see https://github.com/huggingface/transformers/pull/17933.
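To illustrate why an enum might inherit from str for serialization, here is a small sketch that assumes JSON is the serialization in question. PlainColor and StrColor are hypothetical names used only for this comparison.

```python
import enum
import json


class PlainColor(enum.Enum):
    RED = "red"


class StrColor(str, enum.Enum):
    RED = "red"


# A str-mixin member is a real str, so json can encode it directly.
encoded = json.dumps({"color": StrColor.RED})

# A plain Enum member is not JSON serializable out of the box.
try:
    json.dumps({"color": PlainColor.RED})
    plain_error = None
except TypeError as exc:
    plain_error = str(exc)

print(encoded)
print(plain_error)
```

The str mixin makes serialization convenient, but it also means every member passes an isinstance(..., str) check, which is relevant to the parsing problem described above.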

lebrice commented 11 months ago

Hey @levmckinney, thanks for posting!

I'm familiar with this issue, let me try to recall what's going on. I believe what's happening is that SimpleParsing is parsing the value from str into an Enum, so in the __post_init__, you're calling the Color constructor with a Color instance, rather than a string.

I'll try to whip up a solution on Monday, but for now, I think you could fix it with something like:


@dataclass
class SubComponent:
    color: Union[str, Color] = Color.BLUE

    def __post_init__(self):
        if isinstance(self.color, str):
            self.color = Color(self.color)
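A more tolerant variant of this workaround could look like the sketch below. It assumes the value may arrive as a Color member, as a plain value such as "red", or as the str() form "Color.RED" seen in the traceback. coerce_color is a hypothetical helper, not part of simple-parsing or transformers. Because Color subclasses str, isinstance(value, str) is also true for Color members, so the sketch checks for Color first.

```python
import enum
from dataclasses import dataclass
from typing import Union


class Color(str, enum.Enum):
    RED = "red"
    ORANGE = "orange"
    BLUE = "blue"


def coerce_color(value: Union[str, Color]) -> Color:
    """Hypothetical helper: accept a member, a value, or a 'Color.NAME' string."""
    if isinstance(value, Color):
        return value
    prefix = Color.__name__ + "."
    if value.startswith(prefix):
        # Handles the str() form of a member, e.g. "Color.RED" (lookup by name).
        return Color[value[len(prefix):]]
    # Otherwise treat the string as an enum value, e.g. "red".
    return Color(value)


@dataclass
class SubComponent:
    color: Union[str, Color] = Color.BLUE

    def __post_init__(self):
        self.color = coerce_color(self.color)


print(SubComponent(color="red"))
print(SubComponent(color="Color.RED"))
```

This only helps for dataclasses you control; it does not change how TrainingArguments validates its own fields.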

I'm surprised though, I thought I had already nailed this issue down with my HuggingFace example / test. I guess one other approach would be to leave those HF classes as-is, but to add a custom handler for them.

I have to think about this, I'll get back to you, thanks again for posting!