lebrice / SimpleParsing

Simple, Elegant, Typed Argument Parsing with argparse
MIT License
386 stars 47 forks source link

Add `simple_parsing.replace` to make it easy to replace nested fields (rework of #197) #201

Closed lebrice closed 1 year ago

zhiruiluo commented 1 year ago

Hey @lebrice, I don't know where to suggest code changes, so here are the snippets. They insert a key-value pair recording the selected subgroup key in the `to_dict` method, and restore the selected subgroup from that key in the `from_dict` method. This should make the last test in #197 pass.

Possible solution for issue #204.

The changes are enclosed between `#########` comment lines.

def to_dict(dc, dict_factory: type[dict] = dict, recurse: bool = True) -> dict:
    """Serializes this dataclass to a dict.h

    NOTE: This 'extends' the `asdict()` function from
    the `dataclasses` package, allowing us to not include some fields in the
    dict, or to perform some kind of custom encoding (for instance,
    detaching `Tensor` objects before serializing the dataclass to a dict).
    """
    if not is_dataclass(dc):
        raise ValueError("to_dict should only be called on a dataclass instance.")

    d: dict[str, Any] = dict_factory()
    for f in fields(dc):
        name = f.name
        value = getattr(dc, name)

        # Do not include in dict if some corresponding flag was set in metadata.
        include_in_dict = f.metadata.get("to_dict", True)
        if not include_in_dict:
            continue

        custom_encoding_fn = f.metadata.get("encoding_fn")
        if custom_encoding_fn:
            # Use a custom encoding function if there is one.
            d[name] = custom_encoding_fn(value)
            continue

        ###### insert subgroups selected key ##########
        subgroups_dict = f.metadata.get('subgroups')
        if subgroups_dict:
            for g_name, g_cls in subgroups_dict.items():
                if isinstance(value, g_cls):
                    _target = f"__subgroups__@{name}"
                    d[_target] = g_name
        ########################################

        encoding_fn = encode
        # TODO: Make a variant of the serialization tests that use the static functions everywhere.
        if is_dataclass(value) and recurse:
            try:
                encoded = to_dict(value, dict_factory=dict_factory, recurse=recurse)
            except TypeError:
                encoded = to_dict(value)
            logger.debug(f"Encoded dataclass field {name}: {encoded}")
        else:
            try:
                encoded = encoding_fn(value)
            except Exception as e:
                logger.error(
                    f"Unable to encode value {value} of type {type(value)}! Leaving it as-is. (exception: {e})"
                )
                encoded = value
        d[name] = encoded
    return d
def from_dict(
    cls: type[Dataclass], d: dict[str, Any], drop_extra_fields: bool | None = None
) -> Dataclass:
    """Parses an instance of the dataclass `cls` from the dict `d`.

    Args:
        cls (Type[Dataclass]): A `dataclass` type.
        d (Dict[str, Any]): A dictionary of `raw` values, obtained for example
            when deserializing a json file into an instance of class `cls`.
        drop_extra_fields (bool, optional): Whether or not to drop extra
            dictionary keys (dataclass fields) when encountered. There are three
            options:
            - True:
                The extra keys are dropped, and this function returns an
                instance of `cls`.
            - False:
                The extra keys (if any) are kept, and we search through the
                subclasses of `cls` for the first dataclass which has all the
                required fields.
            - None (default):
                `drop_extra_fields = not cls.decode_into_subclasses`.

    Fields whose metadata contains a "subgroups" mapping are decoded with the
    help of the `"__subgroups__@<field name>"` entry (the selected subgroup's
    name). That entry is always consumed for such fields, even when the field
    itself is missing from `d`, so it is never treated as an extra field.

    Raises:
        KeyError: If a subgroup name found in `d` is not one of the field's subgroups.
        RuntimeError: If an error is encountered while instantiating the class.

    Returns:
        Dataclass: An instance of the dataclass `cls`, or None if `d` is None.
    """
    if d is None:
        return None

    obj_dict: dict[str, Any] = d.copy()

    init_args: dict[str, Any] = {}
    non_init_args: dict[str, Any] = {}

    if drop_extra_fields is None:
        drop_extra_fields = not getattr(cls, "decode_into_subclasses", False)
        logger.debug("drop_extra_fields is None. Using cls attribute.")

        if cls in {Serializable, FrozenSerializable, SerializableMixin}:
            # Passing `Serializable` means that we want to find the right
            # subclass depending on the keys.
            # We set the value to False when `Serializable` is passed, since
            # we use this mechanism when we don't know which dataclass to use.
            logger.debug("cls is `SerializableMixin`, drop_extra_fields = False.")
            drop_extra_fields = False

    logger.debug(f"from_dict for {cls}, drop extra fields: {drop_extra_fields}")
    for field in fields(cls) if is_dataclass(cls) else []:
        name = field.name
        subgroups = field.metadata.get("subgroups")
        # Consume the subgroup marker up-front: if the field itself is absent,
        # the marker must not end up in `extra_args`.
        selected_subgroup = obj_dict.pop(f"__subgroups__@{name}", None) if subgroups else None

        if name not in obj_dict:
            if (
                field.metadata.get("to_dict", True)
                and field.default is MISSING
                and field.default_factory is MISSING
            ):
                logger.warning(
                    f"Couldn't find the field '{name}' in the dict with keys " f"{list(d.keys())}"
                )
            continue

        raw_value = obj_dict.pop(name)
        if subgroups:
            field_value = _decode_subgroup_value(
                field, raw_value, subgroups, selected_subgroup, cls
            )
        else:
            field_value = decode_field(field, raw_value, containing_dataclass=cls)

        if field.init:
            init_args[name] = field_value
        else:
            non_init_args[name] = field_value

    extra_args = obj_dict

    # If there are arguments left over in the dict after taking all fields.
    if extra_args:
        if drop_extra_fields:
            logger.warning(f"Dropping extra args {extra_args}")
            extra_args.clear()

        elif issubclass(cls, (Serializable, FrozenSerializable, SerializableMixin)):
            # Use the first Serializable derived class that has all the required
            # fields.
            logger.debug(f"Missing field names: {extra_args.keys()}")

            # Find all the "registered" subclasses of `cls`. (from Serializable)
            derived_classes: list[type[SerializableMixin]] = []
            for subclass in cls.subclasses:
                if issubclass(subclass, cls) and subclass is not cls:
                    derived_classes.append(subclass)
            logger.debug(f"All serializable derived classes of {cls} available: {derived_classes}")

            # All the arguments that the dataclass should be able to accept in
            # its 'init'.
            req_init_field_names = set(chain(extra_args, init_args))

            # Sort the derived classes by their number of init fields, so that
            # we choose the first one with all the required fields.
            derived_classes.sort(key=lambda dc: len(get_init_fields(dc)))

            for child_class in derived_classes:
                logger.debug(f"child class: {child_class.__name__}, mro: {child_class.mro()}")
                child_init_fields: dict[str, Field] = get_init_fields(child_class)
                child_init_field_names = set(child_init_fields.keys())

                if child_init_field_names >= req_init_field_names:
                    # `child_class` is the first class with all required fields.
                    logger.debug(f"Using class {child_class} instead of {cls}")
                    return from_dict(child_class, d, drop_extra_fields=False)

    init_args.update(extra_args)
    try:
        instance = cls(**init_args)  # type: ignore
    except TypeError as e:
        # Keep the original TypeError as the cause to ease debugging.
        raise RuntimeError(
            f"Couldn't instantiate class {cls} using init args {init_args.keys()}: {e}"
        ) from e

    for name, value in non_init_args.items():
        logger.debug(f"Setting non-init field '{name}' on the instance.")
        setattr(instance, name, value)
    return instance


def _decode_subgroup_value(
    field: Field,
    raw_value: Any,
    subgroups: dict[str, Any],
    selected: str | None,
    containing_dataclass: type,
) -> Any:
    """Decodes the raw value of a "subgroups" field of `containing_dataclass`.

    Args:
        field: The dataclass field being decoded.
        raw_value: The raw value found in the dict for this field.
        subgroups: The field's "subgroups" metadata (subgroup name -> class).
        selected: The subgroup name recorded by `to_dict`, or None if absent.
        containing_dataclass: The dataclass that owns `field`.

    Raises:
        KeyError: If `selected` is not one of the names in `subgroups`.
    """
    if isinstance(raw_value, str):
        # The value is a subgroup name: instantiate that subgroup with its defaults.
        return subgroups[raw_value]()
    if selected is None:
        # No subgroup marker (e.g. a dict not produced by `to_dict`): fall back to
        # the regular decoding of the field instead of failing.
        return decode_field(field, raw_value, containing_dataclass=containing_dataclass)
    if selected not in subgroups:
        raise KeyError(
            f"Unknown subgroup {selected!r} for field {field.name!r}; "
            f"expected one of {list(subgroups)}."
        )
    return from_dict(subgroups[selected], raw_value, drop_extra_fields=True)
lebrice commented 1 year ago

Closing in favour of #197