lebrice / SimpleParsing

Simple, Elegant, Typed Argument Parsing with argparse
MIT License

Add `replace` function similar to `dataclasses.replace` #197

Closed zhiruiluo closed 1 year ago

zhiruiluo commented 1 year ago

The replace function of the dataclasses module has the signature dataclasses.replace(obj, /, **changes).

However, dataclasses.replace does not recurse into nested dataclasses, and it does not handle subgroups and other simple-parsing features. To solve this, simple_parsing.replace should be added as an extension of dataclasses.replace.

The proposed signature is simple_parsing.replace(obj: object, changes: Dict[str, Any]). It supports nested dataclasses, subgroups, and other features by going through the parsing API.
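To illustrate the limitation described above, here is a minimal standard-library sketch. The Optimizer and Config classes are hypothetical and are not part of simple-parsing. It shows that dataclasses.replace only swaps top-level fields, so a nested value has to be rebuilt by hand, one replace call per nesting level:

```python
from dataclasses import dataclass, field, replace


@dataclass(frozen=True)
class Optimizer:  # hypothetical nested config
    lr: float = 0.1


@dataclass(frozen=True)
class Config:  # hypothetical top-level config
    optimizer: Optimizer = field(default_factory=Optimizer)
    seed: int = 0


config = Config()

# Top-level fields can be replaced directly.
new_seed = replace(config, seed=1)

# A nested field needs one replace() call per nesting level.
new_lr = replace(config, optimizer=replace(config.optimizer, lr=0.01))

# Passing a dict does not recurse: the dict is stored as the field value.
not_recursive = replace(config, optimizer={"lr": 0.01})
```

In the last call the optimizer field ends up holding a plain dict rather than an Optimizer, which is the gap a nested-aware replace would close.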

lebrice commented 1 year ago

I'm thinking of something like this:

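from simple_parsing.helpers.hparams.hyperparameters import dict_union
from simple_parsing.helpers.serialization.serializable import from_dict, to_dict
from simple_parsing.utils import DataclassT

def replace(dataclass: DataclassT, **new_params) -> DataclassT:
    # Serialize to a nested dict, merge in the changes, then rebuild the dataclass.
    dataclass_dict = to_dict(dataclass, recurse=True)
    new_dataclass_dict = dict_union(dataclass_dict, new_params, recurse=True)
    return from_dict(type(dataclass), new_dataclass_dict, drop_extra_fields=True)

zhiruiluo commented 1 year ago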

Hey there @zhiruiluo , thanks for the PR.

Interesting idea, however I'm not convinced that this formula is the right way to go. I think it would perhaps be better to use the dict_union, to_dict and from_dict utility functions, as I've done here: https://github.com/lebrice/SimpleParsing/blob/master/simple_parsing/helpers/hparams/hyperparameters.py#L200

More importantly, in this context it doesn't make much sense in my opinion for the user to pass values as they would be in the command-line. We don't need to do the round-trip from value -> str -> value again (through the parsing logic). This is also not in-line with the dataclasses.replace function, in my opinion.

Let me know what you think.

My proposed implementation certainly needs more thought about efficiency, but I am more concerned about the use cases for calling simple_parsing.replace directly on dataclasses with a dictionary. Let me try your current implementation with my proposed tests first.
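The idea of merging a nested dictionary of changes into an existing dataclass, without a round-trip through command-line strings, can be sketched with the standard library alone. This is only a stand-in for the to_dict / dict_union / from_dict approach: replace_nested, Model, and Run are hypothetical names, and the sketch works on instances directly instead of serializing them.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, Dict


def replace_nested(obj: Any, changes: Dict[str, Any]) -> Any:
    """Return a copy of `obj` with `changes` applied, recursing into nested dataclasses."""
    names = {f.name for f in fields(obj) if f.init}
    unknown = set(changes) - names
    if unknown:
        raise TypeError(f"unknown fields: {sorted(unknown)}")
    kwargs = {}
    for name in names:
        value = getattr(obj, name)
        if name in changes:
            new_value = changes[name]
            # A dict aimed at a nested dataclass is merged field by field.
            if is_dataclass(value) and isinstance(new_value, dict):
                new_value = replace_nested(value, new_value)
            value = new_value
        kwargs[name] = value
    return type(obj)(**kwargs)


@dataclass
class Model:  # hypothetical nested config
    hidden: int = 32
    dropout: float = 0.0


@dataclass
class Run:  # hypothetical top-level config
    model: Model = field(default_factory=Model)
    name: str = "baseline"


run = Run()
new_run = replace_nested(run, {"model": {"hidden": 64}, "name": "wide"})
```

Unlike the drop_extra_fields=True behaviour in the snippet quoted in this thread, this sketch raises on unknown keys, which mirrors how dataclasses.replace treats unexpected field names.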

zhiruiluo commented 1 year ago

I'm thinking of something like this:

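from simple_parsing.helpers.hparams.hyperparameters import dict_union
from simple_parsing.helpers.serialization.serializable import from_dict, to_dict
from simple_parsing.utils import DataclassT

def replace(dataclass: DataclassT, **new_params) -> DataclassT:
    # Serialize to a nested dict, merge in the changes, then rebuild the dataclass.
    dataclass_dict = to_dict(dataclass, recurse=True)
    new_dataclass_dict = dict_union(dataclass_dict, new_params, recurse=True)
    return from_dict(type(dataclass), new_dataclass_dict, drop_extra_fields=True)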

Yeah, this is a much more efficient approach. But the subgroups entries need special care here. For example:

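from __future__ import annotations

from dataclasses import dataclass

from simple_parsing import subgroups

@dataclass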
class A:
    a: float = 0.0

@dataclass
class B:
    b: str = "bar"

@dataclass
class AB:
    a_or_b: A | B = subgroups({"a": A, "b": B}, default="a")

config = AB(a_or_b='a')
# Replacing with {'b': 'test'} alone is not possible;
# the subgroup choice has to be given as well: {'a_or_b': 'b', 'b': 'test'}

replace(config, {'a_or_b': 'b', 'b': 'test'})

Here, a_or_b has changed its type from A to B, which a plain dict_union cannot handle. As per #173, subgroups need additional information for serialization as well, because of their inherently dynamic structure. Run-time replacement of dataclasses is necessary when I am doing a grid search over different models: the master node generates different sets of search parameters for parallel training. Because subgroups are involved (a really good feature that many people need), replacement currently takes a lot of hand-written code.
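The type-switch problem can be reproduced without simple-parsing. In the sketch below, CHOICES is a hypothetical stand-in for the mapping passed to subgroups(), and replace_subgroup is a hypothetical helper, not an existing simple-parsing function. It shows that a plain dictionary merge mixes fields from both classes, while an explicit choice key makes the switch unambiguous:

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Union


@dataclass
class A:
    a: float = 0.0


@dataclass
class B:
    b: str = "bar"


# Hypothetical stand-in for the mapping passed to subgroups().
CHOICES = {"a": A, "b": B}


@dataclass
class AB:
    a_or_b: Union[A, B] = field(default_factory=A)


def replace_subgroup(config: AB, choice: str, values: Dict[str, Any]) -> AB:
    """Hypothetical helper: pick the subgroup class by key, then apply `values`."""
    target = CHOICES[choice]
    current = config.a_or_b
    # Keep the current values only when the subgroup type is unchanged.
    base = asdict(current) if type(current) is target else {}
    return AB(a_or_b=target(**{**base, **values}))


config = AB()

# A plain merge mixes fields of both classes and fits neither A nor B.
merged = {**asdict(config.a_or_b), "b": "test"}

# With an explicit choice, the switch from A to B is unambiguous.
switched = replace_subgroup(config, "b", {"b": "test"})
```

Whatever form the real implementation takes, it needs some equivalent of the choice argument, which is the extra information referred to in the comment above.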

lebrice commented 1 year ago

Hey @zhiruiluo , the replace function takes the instantiated subgroup as an input, not a serialized dict. Does that still not work?

zhiruiluo commented 1 year ago

That doesn't work.

from __future__ import annotations

from dataclasses import dataclass, field

from simple_parsing import subgroups

from .test_utils import TestSetup

from simple_parsing.helpers.hparams.hyperparameters import dict_union
from simple_parsing.helpers.serialization.serializable import from_dict, to_dict
from simple_parsing.utils import DataclassT

def replace(dataclass: DataclassT, new_params) -> DataclassT:
    dataclass_dict = to_dict(dataclass, recurse=True)
    new_dataclass_dict = dict_union(dataclass_dict, new_params, recurse=True)
    return from_dict(type(dataclass), new_dataclass_dict, drop_extra_fields=True)

def test_replace_subgroups():
    @dataclass
    class C:
        c: bool = False

    @dataclass
    class D:
        d: int = 0

    @dataclass(frozen=True)
    class CD(TestSetup):
        c_or_d: C | D = subgroups({"c": C, "d": D}, default="c")

        other_arg: str = "bob"

    config = CD.setup('--c_or_d c')
    # passed
    assert replace(config, {"c_or_d": {'c': True}}).c_or_d.c == True

    config = CD.setup('--c_or_d d')
    # failed "AttributeError: 'dict' object has no attribute 'd'"
    assert replace(config, {"c_or_d": {'d': 2}}).c_or_d.d == 2

    config = CD.setup('--c_or_d d')
    # failed AttributeError: 'C' object has no attribute 'd'
    assert replace(config, {"c_or_d": D(d=2)}).c_or_d.d == 2

zhiruiluo commented 1 year ago

Reworked my solution with to_dict, from_dict, and dict_union. Modified to_dict and from_dict to work with subgroups, based on the idea of #204.

zhiruiluo commented 1 year ago

Hey there @zhiruiluo , hope you're doing good.

Thanks @lebrice! My baby girl was born recently, so I'm very excited! I am not sure how many types there are, but that's definitely a good starting point we could work on later.

zhiruiluo commented 1 year ago

Hey @lebrice, thanks for your review and feedback. They helped a lot. I agree that some test cases were confusing and that __subgroups__@[key] makes them less readable. I've been trying to reduce its use in some cases. However, I don't think we can avoid it entirely in a few cases.

Here is an example:

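from __future__ import annotations

import unittest
from dataclasses import dataclass

from simple_parsing import subgroups
from simple_parsing.helpers.serialization.serializable import to_dict

from .test_utils import TestSetup

@dataclass
class A(TestSetup):
    a: float = 0.0

@dataclass
class B(TestSetup):
    b: str = "bar"

@dataclass
class AB(TestSetup):
    a_or_b: A | B = subgroups({"a": A, "b": B}, default="a")

def test():
    case = unittest.TestCase()
    # The subgroup choice is stored next to the nested dict, under its own key.
    case.assertDictEqual(
        to_dict(AB(a_or_b=B(b='test'))),
        {'__subgroups__@a_or_b': 'b', 'a_or_b': {'b': 'test'}}
    )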

The __subgroups__@[key] entry has to appear in the nested dict to avoid a key conflict. However, we can still represent the equivalent of AB(a_or_b=B(b='test')) as a flat dictionary without __subgroups__@[key], such as: {'a_or_b': 'b', 'a_or_b.b': 'test'}
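One way to read the flat form above is sketched here. The split_flat helper is hypothetical, and it rests on an assumption that is not stated in the thread: a bare key that is also a prefix of other dotted keys is treated as a subgroup choice, and every other key is a regular (possibly nested) value.

```python
from typing import Any, Dict, Set, Tuple


def split_flat(flat: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a dotted-key dict into (subgroup choices, nested values)."""
    # Collect every proper prefix of every dotted key, e.g. "x" and "x.y" for "x.y.z".
    prefixes: Set[str] = set()
    for key in flat:
        parts = key.split(".")
        for i in range(1, len(parts)):
            prefixes.add(".".join(parts[:i]))

    choices: Dict[str, Any] = {}
    nested: Dict[str, Any] = {}
    for key, value in flat.items():
        if key in prefixes:
            # Assumption: a bare key that prefixes other keys selects a subgroup.
            choices[key] = value
            continue
        node = nested
        *parents, leaf = key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return choices, nested


choices, nested = split_flat({"a_or_b": "b", "a_or_b.b": "test"})
```

The two returned dictionaries carry the same information as the '__subgroups__@a_or_b' entry and the nested 'a_or_b' dict in the test above, kept apart so the keys cannot collide.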

zhiruiluo commented 1 year ago

Hey @lebrice, thanks for summarizing the contributions and laying out a definite plan to improve the whole thing. I totally agree with your plan and will continue to work with you on making this available in a future release.

I will work on finalizing the first contribution as you indicated above. This PR will be suspended, as planned.

lebrice commented 1 year ago

Thanks a lot @zhiruiluo , closing this then.