Hey @lebrice, I wasn't sure where to propose a commit, so here are the code snippets instead. They insert a key-value pair recording the selected subgroup key in the `to_dict` method, and use it to restore the selected subgroup in the `from_dict` method. With these changes the last test in #197 passes.
This is a possible solution for issue #204.
The changes are enclosed by `#########` comment lines.
###### insert subgroups selected key (helper) ##########
def _selected_subgroup_key(value, subgroups_dict):
    """Return the key of the subgroup whose class `value` is an instance of, or None.

    The most specific registered class wins: the value's MRO is walked from its
    own type upwards, so a subclass instance is never recorded under its parent
    class's key when both are registered. Entries that are not classes are
    skipped, because `isinstance` cannot be called with them (presumably these
    are factory callables; verify against how subgroups are declared).
    """
    class_to_key = {}
    for key, candidate in subgroups_dict.items():
        if isinstance(candidate, type):
            # Keep the first key registered for a given class.
            class_to_key.setdefault(candidate, key)

    # Most specific match first.
    for klass in type(value).__mro__:
        if klass in class_to_key:
            return class_to_key[klass]

    # Virtual subclasses (registered via ABC.register) are not in the MRO.
    for candidate, key in class_to_key.items():
        if isinstance(value, candidate):
            return key
    return None
########################################


def to_dict(dc, dict_factory: type[dict] = dict, recurse: bool = True) -> dict:
    """Serializes this dataclass to a dict.

    NOTE: This 'extends' the `asdict()` function from
    the `dataclasses` package, allowing us to not include some fields in the
    dict, or to perform some kind of custom encoding (for instance,
    detaching `Tensor` objects before serializing the dataclass to a dict).

    For fields declared with ``subgroups`` metadata, an extra entry
    ``"__subgroups__@<field name>"`` is written. It holds the key of the
    subgroup whose class matches the field's value, so that `from_dict` can
    rebuild that same subgroup.

    Args:
        dc: The dataclass instance to serialize.
        dict_factory: Callable used to create the output mapping. When
            recursing, it is also used for the nested mappings.
        recurse: Whether nested dataclass values are converted with `to_dict`
            as well. When False, they are passed to `encode` instead.

    Returns:
        A mapping from field names (plus any subgroup marker keys) to encoded
        values.

    Raises:
        ValueError: If `dc` is not a dataclass.
    """
    if not is_dataclass(dc):
        raise ValueError("to_dict should only be called on a dataclass instance.")

    d: dict[str, Any] = dict_factory()
    for f in fields(dc):
        name = f.name
        value = getattr(dc, name)

        # Do not include in dict if some corresponding flag was set in metadata.
        if not f.metadata.get("to_dict", True):
            continue

        custom_encoding_fn = f.metadata.get("encoding_fn")
        if custom_encoding_fn:
            # Use a custom encoding function if there is one.
            d[name] = custom_encoding_fn(value)
            continue

        ###### insert subgroups selected key ##########
        subgroups_dict = f.metadata.get("subgroups")
        if subgroups_dict:
            selected_key = _selected_subgroup_key(value, subgroups_dict)
            if selected_key is not None:
                # Read back by `from_dict` to pick the class to decode into.
                d[f"__subgroups__@{name}"] = selected_key
        ########################################

        # TODO: Make a variant of the serialization tests that use the static functions everywhere.
        if is_dataclass(value) and recurse:
            try:
                encoded = to_dict(value, dict_factory=dict_factory, recurse=recurse)
            except TypeError:
                # Retry with the default factory if the custom one can't be used here.
                encoded = to_dict(value)
            logger.debug(f"Encoded dataclass field {name}: {encoded}")
        else:
            try:
                encoded = encode(value)
            except Exception as e:
                # Best-effort: keep the raw value rather than failing the whole dump.
                logger.error(
                    f"Unable to encode value {value} of type {type(value)}! Leaving it as-is. (exception: {e})"
                )
                encoded = value
        d[name] = encoded
    return d
########### decode subgroups (helper) ################
def _decode_subgroup_value(field, raw_value, selected_key, containing_dataclass):
    """Rebuild the value of a field declared with ``subgroups`` metadata.

    Args:
        field: The dataclass field. Its ``metadata["subgroups"]`` maps subgroup
            keys to classes (or other callables) that build the subgroup.
        raw_value: The serialized value taken from the input dict.
        selected_key: The value stored under ``"__subgroups__@<field name>"``
            by `to_dict`, or None if the input dict had no such entry.
        containing_dataclass: The dataclass that declares `field`.

    Returns:
        The decoded field value.
    """
    subgroups_dict = field.metadata["subgroups"]
    if isinstance(raw_value, str):
        # The raw value is itself a subgroup key: build that subgroup with no arguments.
        return subgroups_dict[raw_value]()
    if selected_key is not None:
        # `to_dict` recorded which subgroup was chosen; decode into exactly that one.
        return from_dict(subgroups_dict[selected_key], raw_value, drop_extra_fields=True)
    # No marker (e.g. a dict written without one, or a None value): use the
    # generic field decoder instead of failing on an unknown target class.
    return decode_field(field, raw_value, containing_dataclass=containing_dataclass)
######################################################


def from_dict(
    cls: type[Dataclass], d: dict[str, Any], drop_extra_fields: bool | None = None
) -> Dataclass:
    """Parses an instance of the dataclass `cls` from the dict `d`.

    Args:
        cls (Type[Dataclass]): A `dataclass` type.
        d (Dict[str, Any]): A dictionary of `raw` values, obtained for example
            when deserializing a json file into an instance of class `cls`.
        drop_extra_fields (bool, optional): Whether or not to drop extra
            dictionary keys (dataclass fields) when encountered. There are three
            options:
            - True:
                The extra keys are dropped, and this function returns an
                instance of `cls`.
            - False:
                The extra keys (if any) are kept, and we search through the
                subclasses of `cls` for the first dataclass which has all the
                required fields.
            - None (default):
                `drop_extra_fields = not cls.decode_into_subclasses`.

    Subgroup fields: if `d` contains a ``"__subgroups__@<field name>"`` entry
    (as written by `to_dict`), the field is decoded into the subgroup it names.
    That entry is always consumed, so it is never treated as an extra key.

    Raises:
        RuntimeError: If an error is encountered while instantiating the class.

    Returns:
        Dataclass: An instance of the dataclass `cls` (or None if `d` is None).
    """
    if d is None:
        return None

    obj_dict: dict[str, Any] = d.copy()
    init_args: dict[str, Any] = {}
    non_init_args: dict[str, Any] = {}

    if drop_extra_fields is None:
        drop_extra_fields = not getattr(cls, "decode_into_subclasses", False)
        logger.debug("drop_extra_fields is None. Using cls attribute.")

        if cls in {Serializable, FrozenSerializable, SerializableMixin}:
            # Passing `Serializable` means that we want to find the right
            # subclass depending on the keys.
            # We set the value to False when `Serializable` is passed, since
            # we use this mechanism when we don't know which dataclass to use.
            logger.debug("cls is `SerializableMixin`, drop_extra_fields = False.")
            drop_extra_fields = False

    logger.debug(f"from_dict for {cls}, drop extra fields: {drop_extra_fields}")

    for field in fields(cls) if is_dataclass(cls) else []:
        name = field.name

        ########### decode subgroups or decode_field ################
        subgroups_dict = field.metadata.get("subgroups", None)
        # Pop the marker before the missing-field check, so that it cannot
        # leak into `extra_args` when the field value itself is absent.
        selected_key = (
            obj_dict.pop(f"__subgroups__@{name}", None) if subgroups_dict else None
        )

        if name not in obj_dict:
            if (
                field.metadata.get("to_dict", True)
                and field.default is MISSING
                and field.default_factory is MISSING
            ):
                logger.warning(
                    f"Couldn't find the field '{name}' in the dict with keys " f"{list(d.keys())}"
                )
            continue

        raw_value = obj_dict.pop(name)
        if subgroups_dict:
            field_value = _decode_subgroup_value(field, raw_value, selected_key, cls)
        else:
            field_value = decode_field(field, raw_value, containing_dataclass=cls)
        ######################################################

        if field.init:
            init_args[name] = field_value
        else:
            non_init_args[name] = field_value

    extra_args = obj_dict

    # If there are arguments left over in the dict after taking all fields.
    if extra_args:
        if drop_extra_fields:
            logger.warning(f"Dropping extra args {extra_args}")
            extra_args.clear()

        elif issubclass(cls, (Serializable, FrozenSerializable, SerializableMixin)):
            # Use the first Serializable derived class that has all the required
            # fields.
            logger.debug(f"Missing field names: {extra_args.keys()}")

            # Find all the "registered" subclasses of `cls`. (from Serializable)
            derived_classes: list[type[SerializableMixin]] = []
            for subclass in cls.subclasses:
                if issubclass(subclass, cls) and subclass is not cls:
                    derived_classes.append(subclass)
            logger.debug(f"All serializable derived classes of {cls} available: {derived_classes}")

            # All the arguments that the dataclass should be able to accept in
            # its 'init'.
            req_init_field_names = set(chain(extra_args, init_args))

            # Sort the derived classes by their number of init fields, so that
            # we choose the first one with all the required fields.
            derived_classes.sort(key=lambda dc: len(get_init_fields(dc)))

            for child_class in derived_classes:
                logger.debug(f"child class: {child_class.__name__}, mro: {child_class.mro()}")
                child_init_fields: dict[str, Field] = get_init_fields(child_class)
                child_init_field_names = set(child_init_fields.keys())

                if child_init_field_names >= req_init_field_names:
                    # `child_class` is the first class with all required fields.
                    logger.debug(f"Using class {child_class} instead of {cls}")
                    # Re-parse from the original `d` (not the consumed copy).
                    return from_dict(child_class, d, drop_extra_fields=False)

    init_args.update(extra_args)
    try:
        instance = cls(**init_args)  # type: ignore
    except TypeError as e:
        raise RuntimeError(
            f"Couldn't instantiate class {cls} using init args {init_args.keys()}: {e}"
        ) from e

    for name, value in non_init_args.items():
        logger.debug(f"Setting non-init field '{name}' on the instance.")
        setattr(instance, name, value)
    return instance
Hey @lebrice, I wasn't sure where to propose a commit, so here are the code snippets instead. They insert a key-value pair recording the selected subgroup key in the `to_dict` method, and use it to restore the selected subgroup in the `from_dict` method. With these changes the last test in #197 passes. This is a possible solution for issue #204. The changes are enclosed by `#########` comment lines.