Fatal1ty / mashumaro

Fast and well tested serialization library
Apache License 2.0

Allow for more complex logic for subclass Discriminator field matching #184

samongoose closed this issue 10 months ago

samongoose commented 10 months ago

Is your feature request related to a problem? Please describe.

While building a deserializer for a large, complex configuration file, I have some subclasses with identical shapes that it would still be useful to distinguish by a type field. Unfortunately, in our situation multiple values of type map to the same subclass.

Describe the solution you'd like

I haven't dug into the existing code deeply enough to know what's feasible, so there are probably several possible solutions (or maybe none, I suppose).

The least intrusive option I can see would be having variant_tagger_fn return a list[str] instead of a str.

A more involved solution might be adding an inverse of variant_tagger_fn that takes a str and returns a class.
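To make the "inverse" idea concrete, here is a minimal sketch in plain Python, with no mashumaro involved. The names resolve_client_event and load_event, the alias table, and the fallback to the base class are all hypothetical; they only illustrate the shape of a str-to-class hook where several raw tags can select the same class.

```python
from dataclasses import dataclass


@dataclass
class ClientEvent:
    client_ip: str
    type: str


@dataclass
class ClientConnectedEvent(ClientEvent):
    pass


@dataclass
class ClientDisconnectedEvent(ClientEvent):
    pass


# Hypothetical alias table: several raw "type" values may share one class.
_EVENT_CLASSES = {
    "connected": ClientConnectedEvent,
    "disconnected": ClientDisconnectedEvent,
    "d/c": ClientDisconnectedEvent,
}


def resolve_client_event(tag: str) -> type:
    """Hypothetical inverse of a tagger: map a raw tag to a class.

    Tags that are not in the table fall back to the base class.
    """
    return _EVENT_CLASSES.get(tag, ClientEvent)


def load_event(d: dict) -> ClientEvent:
    # Pick the class from the tag, then build it from the whole dict,
    # so the original "type" value is preserved on the instance.
    cls = resolve_client_event(d["type"])
    return cls(**d)
```

For example, load_event({"type": "d/c", "client_ip": "10.0.0.42"}) builds a ClientDisconnectedEvent whose type is still "d/c".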

Finally, the workaround below could be improved upon if there were a hook in the base class that could accomplish the same thing.

Describe alternatives you've considered

I have been able to work around this using __pre_deserialize__:

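from dataclasses import dataclass
from typing import Annotated, Any, List

from mashumaro import DataClassDictMixin
from mashumaro.types import Discriminator

@dataclass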
class ClientEvent(DataClassDictMixin):
    client_ip: str
    type: str
    _type = "unknown"

@dataclass
class ClientConnectedEvent(ClientEvent):
    _type = "connected"

@dataclass
class ClientDisconnectedEvent(ClientEvent):
    _type = "disconnected"

# Map a raw "type" value (including aliases such as "d/c") to a subclass tag.
def get_type(typ: str) -> str:
    if typ in ["disconnected", "connected"]:
        return typ
    if typ == "d/c":
        return "disconnected"
    return "unknown"

@dataclass
class AggregatedEvents(DataClassDictMixin):
    list: List[
        Annotated[
            ClientEvent, Discriminator(field="_type", include_subtypes=True, include_supertypes=True)
        ]
    ]

    @classmethod
    def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]:
        # Add a "_type" tag derived from "type"; the original "type" value is kept.
        d["list"] = [dict({"_type": get_type(event["type"])}, **event) for event in d["list"]]
        return d

events = AggregatedEvents.from_dict(
    {
        "list": [
            {"type": "connected", "client_ip": "10.0.0.42"},
            {"type": "disconnected", "client_ip": "10.0.0.42"},
            {"type": "N/A", "client_ip": "10.0.0.42"},
            {"type": "d/c", "client_ip": "10.0.0.42"},
        ]
    }
)

# Produces:
AggregatedEvents(list=[ClientConnectedEvent(client_ip='10.0.0.42', type='connected'), ClientDisconnectedEvent(client_ip='10.0.0.42', type='disconnected'), ClientEvent(client_ip='10.0.0.42', type='N/A'), ClientDisconnectedEvent(client_ip='10.0.0.42', type='d/c')])

This works reasonably well (and preserves the original type, which can be useful). The major downside is that the __pre_deserialize__ method has to be implemented in every class that uses ClientEvent, and in our situation there are several of those. It would be more convenient if there were a hook in the ClientEvent base class that could accomplish the same thing.
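One way to avoid repeating __pre_deserialize__ in every container class is to define it once in a small mixin that each container inherits. This sketch assumes, as the workaround above does, that mashumaro calls the classmethod __pre_deserialize__ of the class being deserialized, and that an inherited definition is picked up as well; the names TaggedEventsMixin and _tagged_list_fields are hypothetical. It is shown without mashumaro so the hook can be called directly; in real code each container would also be a @dataclass inheriting DataClassDictMixin, with the same Annotated Discriminator on its event fields.

```python
from typing import Any, ClassVar, Dict, Tuple


def get_type(typ: str) -> str:
    # Same alias normalisation as in the workaround.
    if typ in ("disconnected", "connected"):
        return typ
    if typ == "d/c":
        return "disconnected"
    return "unknown"


class TaggedEventsMixin:
    """Hypothetical mixin that holds the pre-deserialization hook once.

    Each container lists the names of its fields that hold event lists.
    """

    _tagged_list_fields: ClassVar[Tuple[str, ...]] = ()

    @classmethod
    def __pre_deserialize__(cls, d: Dict[Any, Any]) -> Dict[Any, Any]:
        for name in cls._tagged_list_fields:
            if name in d:
                # Add the derived "_type" tag, keeping the original "type".
                d[name] = [
                    {"_type": get_type(event["type"]), **event}
                    for event in d[name]
                ]
        return d


class AggregatedEvents(TaggedEventsMixin):
    # In real code this would also be a @dataclass with DataClassDictMixin.
    _tagged_list_fields = ("list",)
```

Each container then only declares which fields hold events, and the alias logic lives in a single function.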

Additional context

I'm just getting started with this library, and it's great so far. It's possible I'm missing something and this is already doable. Alternatively, the workaround, well, works, so feel free to close this if it's not a feature you're looking to add. Thanks!

Fatal1ty commented 10 months ago

The least intrusive option I can see would be having variant_tagger_fn return a list[str] instead of a str.

At first glance, this is the best way to go. The only difficulty is deciding what code should be generated when variant_tagger_fn is used. Right now it looks like this:

...
for variant in (*iter_all_subclasses(__main__.ClientEvent), __main__.ClientEvent):
    try:
        variants_map[variant_tagger_fn(variant)] = variant
    except KeyError:
        continue
...

The important part here is variants_map[variant_tagger_fn(variant)] = variant. If variant_tagger_fn returns a list, we will have to iterate through all its items and register the variant for each of them. If we do this with runtime introspection (an isinstance check, for example), it will reduce performance. We could add a new parameter to Discriminator to configure this instead, but that seems overcomplicated.
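To make the generated loop easier to follow, here is a plain-Python model of it. The iter_all_subclasses helper is written out as a simple recursive walk (an assumption about its behaviour, not mashumaro's own implementation), and the example tagger reads the tag from the class's own __dict__, so a class that does not define _type itself raises KeyError and is skipped, mirroring the except KeyError: continue branch.

```python
from typing import Any, Callable, Dict, Iterator


def iter_all_subclasses(cls: type) -> Iterator[type]:
    # Depth-first walk over every direct and indirect subclass.
    for sub in cls.__subclasses__():
        yield sub
        yield from iter_all_subclasses(sub)


def build_variants_map(
    base: type, variant_tagger_fn: Callable[[type], Any]
) -> Dict[Any, type]:
    variants_map: Dict[Any, type] = {}
    for variant in (*iter_all_subclasses(base), base):
        try:
            # Exactly one tag per class: the tagger's result is the key.
            variants_map[variant_tagger_fn(variant)] = variant
        except KeyError:
            # The tagger signals "this class has no tag" with KeyError.
            continue
    return variants_map


class ClientEvent:
    pass


class ClientConnectedEvent(ClientEvent):
    _type = "connected"


class ClientDisconnectedEvent(ClientEvent):
    _type = "disconnected"


def tag_by_own_attribute(cls: type) -> str:
    # Only look at attributes defined on the class itself, not inherited ones.
    return cls.__dict__["_type"]
```

With one key per class there is no way to register an alias such as "d/c", and a tagger that returned a list would fail in this loop with a TypeError, because a list cannot be a dict key.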

As an alternative, I'd suggest creating an explicit ClientUnknownEvent and using a class-level discriminator. There is an undocumented attribute, __mashumaro_subtype_variants__, that is set on a class with a class-level discriminator. You can manually register all "type" aliases in it.

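from dataclasses import dataclass
from typing import List

from mashumaro import DataClassDictMixin
from mashumaro.types import Discriminator

@dataclass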
class ClientEvent(DataClassDictMixin):
    client_ip: str
    type: str

    class Config:
        debug = True
        discriminator = Discriminator(field="type", include_subtypes=True)

@dataclass
class ClientUnknownEvent(ClientEvent):
    type = "unknown"

@dataclass
class ClientConnectedEvent(ClientEvent):
    type = "connected"

@dataclass
class ClientDisconnectedEvent(ClientEvent):
    type = "disconnected"

# Register the extra "type" aliases in the (undocumented) subtype registry.
for key, value in (("N/A", ClientUnknownEvent), ("d/c", ClientDisconnectedEvent)):
    ClientEvent.__mashumaro_subtype_variants__[key] = value

@dataclass
class AggregatedEvents(DataClassDictMixin):
    list: List[ClientEvent]

events = AggregatedEvents.from_dict(
    {
        "list": [
            {"type": "connected", "client_ip": "10.0.0.42"},
            {"type": "disconnected", "client_ip": "10.0.0.42"},
            {"type": "N/A", "client_ip": "10.0.0.42"},
            {"type": "d/c", "client_ip": "10.0.0.42"},
        ]
    }
)

# Produces:
AggregatedEvents(list=[ClientConnectedEvent(client_ip='10.0.0.42', type='connected'), ClientDisconnectedEvent(client_ip='10.0.0.42', type='disconnected'), ClientUnknownEvent(client_ip='10.0.0.42', type='N/A'), ClientDisconnectedEvent(client_ip='10.0.0.42', type='d/c')])

mishamsk commented 10 months ago

btw, I had the same case where multiple tag values map to the same class (the tag field is a literal with multiple possible values). I was thinking about opening a PR to allow variant_tagger_fn to return multiple tags, but for now I thought I had enough PRs open already :-) Also, in my use case the number of classes is so small that I actually built a fully manual deserializer; that seemed quicker than doing a PR.
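For a handful of classes, a manual deserializer of the kind described here can stay short. A sketch using only the standard library, assuming each class declares its tag field as a Literal listing every accepted value; the class names and the helpers build_tag_map and from_dict_manual are hypothetical.

```python
from dataclasses import dataclass
from typing import Dict, Literal, Union, get_args, get_type_hints


@dataclass
class Connected:
    type: Literal["connected"]
    client_ip: str


@dataclass
class Disconnected:
    # Several tag values map to the same class.
    type: Literal["disconnected", "d/c"]
    client_ip: str


Event = Union[Connected, Disconnected]


def build_tag_map(*classes: type) -> Dict[str, type]:
    tag_map: Dict[str, type] = {}
    for cls in classes:
        # Every value allowed by the Literal becomes a key for this class.
        for tag in get_args(get_type_hints(cls)["type"]):
            tag_map[tag] = cls
    return tag_map


TAG_MAP = build_tag_map(Connected, Disconnected)


def from_dict_manual(d: dict) -> Event:
    try:
        cls = TAG_MAP[d["type"]]
    except KeyError:
        raise ValueError(f"unknown event type: {d.get('type')!r}") from None
    return cls(**d)
```

Deriving the map from the Literal annotations keeps the accepted tags in one place, next to the field they describe.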

Fatal1ty commented 10 months ago

If we decide to do some introspection at runtime (like isinstance for example), it will lead to a decrease in performance.

I've been thinking about it and came to the conclusion that this will not have much of an impact: we iterate over variants and register them only when a tag is not yet in the registry. I'm going to allow variant_tagger_fn to return a list, so that the following code will handle it:

variant_tags = variant_tagger_fn(variant)
if type(variant_tags) is list:
    for variant_tag in variant_tags:
        variants_map[variant_tag] = variant
else:
    variants_map[variant_tags] = variant
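The snippet above is an excerpt of generated code, so it cannot run on its own. Below is a self-contained sketch of the same branching, wrapped in a function; register_variant, the TAGS table, and tagger are hypothetical, and in mashumaro this logic is part of the generated code rather than a helper.

```python
from typing import Any, Callable, Dict


def register_variant(
    variants_map: Dict[Any, type],
    variant: type,
    variant_tagger_fn: Callable[[type], Any],
) -> None:
    variant_tags = variant_tagger_fn(variant)
    if type(variant_tags) is list:
        # A list means "every one of these tags selects this class".
        for variant_tag in variant_tags:
            variants_map[variant_tag] = variant
    else:
        # Any other value is used as a single tag, as before.
        variants_map[variant_tags] = variant


class ClientEvent:
    pass


class ClientConnectedEvent(ClientEvent):
    pass


class ClientDisconnectedEvent(ClientEvent):
    pass


# Hypothetical tagger: one class may answer to several raw "type" values.
TAGS = {
    ClientEvent: ["unknown", "N/A"],
    ClientConnectedEvent: "connected",
    ClientDisconnectedEvent: ["disconnected", "d/c"],
}


def tagger(cls: type) -> Any:
    return TAGS[cls]


variants_map: Dict[Any, type] = {}
for variant in (ClientConnectedEvent, ClientDisconnectedEvent, ClientEvent):
    register_variant(variants_map, variant, tagger)
```

Because the check is type(...) is list rather than isinstance, only a real list is expanded: a tuple returned by the tagger would be registered as one composite key, and a set would raise TypeError because it is unhashable.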