lebrice / SimpleParsing

Simple, Elegant, Typed Argument Parsing with argparse
MIT License

The deserialization failed on nested dataclasses. #210

Closed: zhiruiluo closed this issue 1 year ago

zhiruiluo commented 1 year ago

Describe the bug: Round-tripping nested dataclasses through to_dict and from_dict fails.

To Reproduce

from __future__ import annotations
from simple_parsing import Serializable
from simple_parsing.helpers.serialization import from_dict, to_dict
from dataclasses import dataclass, field
import functools
import pytest

@dataclass
class InnerConfig:
    arg1: int = 1
    arg2: str = 'foo'
    arg1_post_init: str = field(init=False)

    def __post_init__(self):
        self.arg1_post_init = str(self.arg1)

@dataclass
class OuterConfig1(Serializable):
    out_arg: int = 0
    inner: InnerConfig = field(default_factory=InnerConfig)

@dataclass
class OuterConfig2(Serializable):
    out_arg: int = 0
    inner: InnerConfig = field(default_factory=functools.partial(InnerConfig, arg2='bar'))

@dataclass
class Level1:
    arg: int = 1

@dataclass
class Level2:
    arg: int = 1
    prev: Level1 = field(default_factory=Level1)

@dataclass
class Level3:
    arg: int = 1
    prev: Level2 = field(default_factory=Level2)

@pytest.mark.parametrize(
    ('config'),
    [
        (OuterConfig1()),
        (OuterConfig2()),
        (Level2()),
        (Level3()),
    ]
)
def test_nested_dataclasses_serialization(config: object):
    config_dict = to_dict(config)
    print(config_dict)
    new_config = from_dict(
        config.__class__,
        config_dict,
        drop_extra_fields=True,
    )
    assert config == new_config

Expected behavior

All parametrized test cases pass.

Actual behavior

All test cases fail except `Level2()`.

lebrice commented 1 year ago

Hey there @zhiruiluo, I'm unable to reproduce the issue:

from __future__ import annotations
from dataclasses import dataclass, field
import functools

import pytest
from simple_parsing.helpers.serialization import from_dict, to_dict
from simple_parsing.utils import Dataclass

@dataclass
class Level1:
    arg: int = 1

@dataclass
class Level2:
    arg: int = 1
    prev: Level1 = field(default_factory=Level1)

@pytest.mark.parametrize("config", [Level1(arg=2), Level2(arg=2, prev=Level1(arg=3)),])
def test_nested_dataclass(config: Dataclass):
    _from_dict = functools.partial(from_dict, type(config))
    assert _from_dict(to_dict(config)) == config
    assert _from_dict(to_dict(config), drop_extra_fields=True)

    assert to_dict(_from_dict(to_dict(config))) == to_dict(config)
    assert _from_dict(to_dict(_from_dict(to_dict(config)))) == _from_dict(to_dict(config))

Do you have an example that doesn't work?

lebrice commented 1 year ago

Hmm, actually I am getting some errors when running your tests directly. I'll investigate.

lebrice commented 1 year ago

The tests raise the following UserWarning:

UserWarning: Unable to find a decoding function for the annotation <class 'test_issue_210.Level2'> (of type <class 'type'>). Will try to use the type as a constructor. Consider registering a decoding function using `register_decoding_fn`, or posting an issue on GitHub.

This is somewhat expected, given that the dataclasses don't inherit from Serializable. However, it makes sense that we should be able to create dataclasses from dictionaries, even if they don't inherit from Serializable. I'll take a closer look at how we can fix this.
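To illustrate what "creating dataclasses from dictionaries" means for plain dataclasses, here is a minimal standard-library sketch. It is not SimpleParsing's implementation: `dataclass_from_dict` is a hypothetical helper name, and `dataclasses.asdict` stands in for `to_dict`. It assumes every nested field is annotated with a concrete dataclass type (no `Optional`, `Union`, or containers of dataclasses).

```python
import dataclasses
from dataclasses import dataclass, field
from typing import Any, Dict, Type, TypeVar, get_type_hints

T = TypeVar("T")


def dataclass_from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
    """Hypothetical helper: rebuild a (possibly nested) dataclass from a plain dict."""
    # get_type_hints also resolves string annotations (for example those produced
    # by `from __future__ import annotations`) when the classes are importable.
    hints = get_type_hints(cls)
    kwargs = {}
    for f in dataclasses.fields(cls):
        # Skip init=False fields (they are recomputed in __post_init__) and
        # keys missing from the dict (the field's default is used instead).
        if not f.init or f.name not in data:
            continue
        value = data[f.name]
        field_type = hints[f.name]
        # Recurse when the annotation is itself a dataclass and the value is a dict.
        if dataclasses.is_dataclass(field_type) and isinstance(value, dict):
            value = dataclass_from_dict(field_type, value)
        kwargs[f.name] = value
    return cls(**kwargs)


@dataclass
class Level1:
    arg: int = 1


@dataclass
class Level2:
    arg: int = 1
    prev: Level1 = field(default_factory=Level1)


@dataclass
class Level3:
    arg: int = 1
    prev: Level2 = field(default_factory=Level2)


config = Level3(arg=3, prev=Level2(arg=2, prev=Level1(arg=5)))
restored = dataclass_from_dict(Level3, dataclasses.asdict(config))
assert restored == config
```

Because the recursion is driven by the field annotations rather than by a base class, the same function handles any nesting depth, which is the behavior one would expect from from_dict on dataclasses that don't inherit from Serializable.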

zhiruiluo commented 1 year ago

One level of nesting works, but two or more levels of nesting fail.

zhiruiluo commented 1 year ago

After making all the dataclasses inherit from Serializable, the tests passed. I misunderstood the actual behavior of Serializable. Something might be wrong for #211. I naturally assumed that to_dict and from_dict would work without Serializable.

lebrice commented 1 year ago

Yes, you're right @zhiruiluo: from_dict and to_dict are supposed to work with all dataclasses, not just those that subclass Serializable.

I'll make a bugfix PR for this soon. Thanks for pointing it out.