eladrich / pyrallis

Pyrallis is a framework for structured configuration parsing from both cmd and files. Simply define your desired configuration structure as a dataclass and let pyrallis do the rest!
https://eladrich.github.io/pyrallis/
MIT License

Handling of typing.Literal #8

Open phelps-matthew opened 2 years ago

phelps-matthew commented 2 years ago

Hey there! I've been putting pyrallis into action lately and can't overemphasize how much cleaner it's been to integrate and apply configurations.

That said, there are a couple use cases I'm interested to hear your thoughts on. First being how pyrallis handles typing.Literal. I see in decoding/decode_field that the Literal field will decode to 'Any' and that Literal[arg1, arg2, ..] will decode into arg1. In the case of Literal["constant"] this decodes to "constant", which is not a type.

More succinctly, does the pattern Literal[object] make good sense to use and, if so, what might be the way for pyrallis to handle this type appropriately?

Traceback (most recent call last):
  File "/anaconda3/envs/spec/lib/python3.9/site-packages/pyrallis/parsers/decoding.py", line 65, in decode_dataclass
    field_value = decode_field(field, raw_value)
  File "/anaconda3/envs/spec/lib/python3.9/site-packages/pyrallis/parsers/decoding.py", line 102, in decode_field
    return decode(field_type, raw_value)
  File "/anaconda3/envs/spec/lib/python3.9/functools.py", line 877, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/anaconda3/envs/spec/lib/python3.9/site-packages/pyrallis/parsers/decoding.py", line 33, in decode
    return get_decoding_fn(t)(raw_value)
  File "/anaconda3/envs/spec/lib/python3.9/site-packages/pyrallis/parsers/decoding.py", line 176, in get_decoding_fn
    raise Exception(f"No decoding function for type {t}, consider using pyrallis.decode.register")
Exception: No decoding function for type typing.Literal['onecycle', 'constant'], consider using pyrallis.decode.register

eladrich commented 2 years ago

Hi @phelps-matthew, So happy to hear! Exactly what we were aiming for 😎

Interesting question! From the technical point of view, I think that supporting Literals shouldn't be too challenging. The expected behavior should be similar to enums: parsing fails for values from cmd/files that do not match one of the defined values.

As to the pattern itself, if pyrallis supported literals, how would you expect people to use them in their configurations? From my understanding of literals, I expect that people would mostly use them as a quick enum replacement. Personally, I'm less of a fan, as you don't get good validation and autocomplete, but defining a lot of enums can become quite cumbersome, so I can see the usefulness here.
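
To make the "similar to enums" behavior concrete, here is a minimal sketch of what decoding a Literal could look like. It uses only the standard `typing` helpers; `decode_literal` is a hypothetical name and not part of pyrallis, and matching on string form is an assumption about how command-line text would be compared against non-string options.

```python
from typing import Any, Literal, get_args, get_origin


def decode_literal(literal_type: Any, raw_value: Any) -> Any:
    """Hypothetical helper: return the allowed option matching raw_value, or raise."""
    if get_origin(literal_type) is not Literal:
        raise TypeError(f"{literal_type!r} is not a Literal type")
    allowed = get_args(literal_type)
    for option in allowed:
        # Accept an exact match, or a match on string form,
        # since values coming from the command line arrive as text.
        if raw_value == option or str(raw_value) == str(option):
            return option
    raise ValueError(f"{raw_value!r} is not one of {allowed}")


LRName = Literal["onecycle", "constant"]
print(decode_literal(LRName, "constant"))
```

A function along these lines could presumably be hooked in through pyrallis.decode.register, as the exception message above suggests, though whether registration works for Literal types is untested here.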

phelps-matthew commented 2 years ago

Ah, I was entirely unaware of enum! Thank you. With the working example from tests/test_choice.py, I now see that it pretty much covers all of my use cases that I was previously attempting to use Literal for (and with autocomplete and better validation as you mentioned).

With respect to enum, I'm wondering if it is possible to access enum attributes by key in the CLI. For example:

class LRMethod1(Enum):
    onecycle: torch.optim.lr_scheduler = OneCycleLR
    lambdalr: torch.optim.lr_scheduler = LambdaLR

class LRMethod2(Enum):
    onecycle: str = "onecycle"
    constant: str = "constant"

@dataclass()
class TrainConfig:
    """config for training instance"""

    # LRMethod stands for either LRMethod1 or LRMethod2 above
    lr_method: LRMethod = LRMethod.onecycle

When creating configuration instances in Python, LRMethod1 is great, but LRMethod2 seems to be required for use from the CLI.

In the end, my aim is to avoid, as much as possible, a separate step of parsing strings into the appropriate objects. What sort of pattern do you tend to use in this type of scenario?

brentyi commented 2 years ago

Hello!

As a datapoint on Literal, one pattern I've used for experiments where cross-validation is possible is:

@dataclass
class Experiment:
    experiment_name: str
    dataset_fold: Literal[0, 1, 2, 3]

to indicate that there are only 4 possible folds. IMO this makes more sense with a Literal than an enum.

To comment on enums, perhaps this line https://github.com/eladrich/pyrallis/blob/4204d4fba4a10f360646d90e0d8f07f54e8e75d8/pyrallis/parsers/decoding.py#L166 should be replaced with

    return lambda key: t[key]

to remove the need for the redundant string, enable support for enum.auto(), etc.
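
A small sketch of the difference, using only the standard `enum` module. `decode_by_name` is a hypothetical stand-in for the decoding function, not the actual pyrallis code:

```python
from enum import Enum, auto


class LRMethod(Enum):
    # auto() assigns integer values, so there is no string value to parse against
    onecycle = auto()
    constant = auto()


def decode_by_name(enum_type, key):
    """Hypothetical stand-in for the decoder: look a member up by its name."""
    return enum_type[key]


member = decode_by_name(LRMethod, "constant")
print(member.name)
```

Value-based lookup, i.e. `LRMethod("constant")`, raises a ValueError here because the values are integers, which is why name-based lookup is what makes `enum.auto()` usable.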

eladrich commented 2 years ago

@phelps-matthew While I understand the desire to avoid parsing of strings, and directly create the object, I think that in some cases it is still better to keep the two separated.

For example, an lr scheduler is a stateful object, and IMO shouldn't be part of your configuration.

@brentyi Thanks for joining the discussion!

> As a datapoint on Literal, one pattern I've used for experiments where cross-validation is possible is:
>
>     @dataclass
>     class Experiment:
>         experiment_name: str
>         dataset_fold: Literal[0, 1, 2, 3]
>
> to indicate that there are only 4 possible folds. IMO this makes more sense with a Literal than an enum.

This is actually a really good example, I can see many cases where this makes sense. Will give it some thought, but I think this convinces me that Literals would be a nice addition.

As to the enums, that's a great point. What I'm debating is whether I still want to support the string values if they're given. Some people would probably expect the parsing to be based on the string values, as in regular enum initialization, and many code examples use enums in the following manner:

class Colors(Enum):
    CYAN = 'cyan'
    PURPLE = 'purple'
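
One way to support both conventions, sketched under the assumption that names take priority and values are only a fallback. `decode_enum` is a hypothetical helper, not what pyrallis does:

```python
from enum import Enum


class Colors(Enum):
    CYAN = "cyan"
    PURPLE = "purple"


def decode_enum(enum_type, raw):
    """Hypothetical decoder: try the member name first, then fall back to the value."""
    try:
        return enum_type[raw]
    except KeyError:
        # Enum lookup by value raises ValueError if raw matches neither
        return enum_type(raw)


print(decode_enum(Colors, "CYAN"), decode_enum(Colors, "cyan"))
```

The trade-off is ambiguity: if one member's name equals another member's value, the name silently wins, which may be a reason to support only one of the two.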

eladrich commented 2 years ago

@phelps-matthew I see now that you're not trying to initialize the object itself, but rather to create an enum of the possible classes, so in your code you could do something like

scheduler = cfg.lr_method()

Without the need for checking the string values, am I right?

That's actually pretty neat, and could be supported with @brentyi's suggestion.

phelps-matthew commented 2 years ago

> Without the need for checking the string values, am I right?

@eladrich Yes exactly!

phelps-matthew commented 2 years ago

Currently I am decoding

class LRMethod(Enum):
    onecycle: torch.optim.lr_scheduler = OneCycleLR
    lambdalr: torch.optim.lr_scheduler = LambdaLR

@dataclass()
class TrainConfig:
    """config for training instance"""

    # lr schedule type: (onecycle, lambdalr)
    lr_method: LRMethod = LRMethod.onecycle

with

pyrallis.decode.register(LRMethod, lambda x: LRMethod[x].value)

which actually isn't too bad of a pattern at all.

eladrich commented 2 years ago

Nice! Smart thinking with the registry here.

I think I'm convinced now that using the Enum names makes more sense, and will patch it soon so that you won't need to register every new Enum in the future.

phelps-matthew commented 2 years ago

Yes, Enums are proving rather useful! Thanks for the discussion.

With the method of my previous post, I realized that one does not have a means to complete the loop encode -> decode -> encode. When one deserializes directly to a class or function that is not an Enum (e.g. the function torch.nn.functional.mse_loss or the class torch.optim.lr_scheduler.OneCycleLR), one loses the ability to serialize such objects back into the string names that came from the spawning Enum class.

Moreover, since functions defined in an Enum body are descriptors, Enum treats them as methods rather than members and passes the function through directly. Hence, we lose the outer properties .name and .value that would otherwise be useful, e.g.

class A(Enum):
    a = 3
    f = lambda x: x ** 2

>>> A.a
<A.a: 3>
>>> A.f
<function A.<lambda> at 0x7f7c5bf7d670>

This may seem a bit hacky, but to preserve serialization/deserialization, generalize to functions, and allow such functions to be directly callable (without using Enum.x.value) I'm using a wrapper.

class Wrap:
    """wrapper for serializing/deserializing functions/classes"""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __repr__(self):
        # could not use self.fn.__repr__() for some reason?
        return repr(self.fn)

class Criterion(Enum):
    """Enum class for criterion"""

    mse = Wrap(F.mse_loss)
    l1 = Wrap(F.l1_loss)

    # allows one to call directly
    def __call__(self, *args, **kwargs):
        return self.value(*args, **kwargs)

@dataclass()
class TrainConfig:
    """config for training instance"""
    # loss function: (mse, l1)
    criterion: Criterion = Criterion.mse

pyrallis.encode.register(Criterion, lambda x: x.name)
pyrallis.decode.register(Criterion, lambda x: Criterion[x])

Now one could do something like

cfg = TrainConfig()
cfg.criterion.name   # 'mse'
cfg.criterion(x, y)  # calls F.mse_loss(x, y)
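
For reference, the encode -> decode -> encode loop can be checked without torch or pyrallis. This is only a sketch: `squared_error` and `absolute_error` are made-up stand-ins for F.mse_loss and F.l1_loss, and `encode`/`decode` mirror the two lambdas registered above rather than calling pyrallis itself.

```python
from enum import Enum


def squared_error(x, y):
    # stand-in for F.mse_loss so the sketch runs without torch
    return (x - y) ** 2


def absolute_error(x, y):
    # stand-in for F.l1_loss
    return abs(x - y)


class Wrap:
    """Holds a function so that Enum stores it as a member value, not a method."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class Criterion(Enum):
    mse = Wrap(squared_error)
    l1 = Wrap(absolute_error)

    # lets a member be called directly, without going through .value
    def __call__(self, *args, **kwargs):
        return self.value(*args, **kwargs)


def encode(member):
    # same as the registered encoder: serialize to the member name
    return member.name


def decode(name):
    # same as the registered decoder: look the member up by name
    return Criterion[name]


round_tripped = encode(decode(encode(Criterion.mse)))
print(round_tripped, Criterion.l1(3, 5))
```

Because decoding returns the Enum member rather than the wrapped function, the name is still available for the next encode step.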

eladrich commented 2 years ago

Interesting, I wasn't aware of how functions are handled in enums. Your solution seems like a valid option. For completeness, this StackOverflow thread offers some other ways to overcome this, for example by using partial: https://stackoverflow.com/a/40339397/862394

From my understanding, the problem is specific to functions and not classes, so the scheduler example should work as-is.

eladrich commented 2 years ago

The latest release (v0.2.2) now uses name-based enums as suggested by @brentyi. @phelps-matthew Feel free to give it a try and let us know how it works for you.

phelps-matthew commented 2 years ago

@eladrich Beautiful. Tested and this allowed me to get rid of all my encode/decode registry statements for Enums, thank you!!