lovasoa / marshmallow_dataclass

Automatic generation of marshmallow schemas from dataclasses.
https://lovasoa.github.io/marshmallow_dataclass/html/marshmallow_dataclass.html
MIT License

Support for type-generic custom fields #269

Open thomashk0 opened 3 months ago


Hi,

I have a use case where a custom field, declared in the base schema, must handle a generic type. For example:

import typing
from dataclasses import dataclass

import marshmallow
import marshmallow_dataclass

_T = typing.TypeVar("_T")

class CustomType(typing.Generic[_T]):
    def __init__(self, v: _T):
        self._value = v

    def value(self) -> _T:
        return self._value

class CustomTypeField(marshmallow.fields.Field):
    def _serialize(self, value: CustomType, attr, obj, **kwargs):
        return {"value": value.value()}

    def _deserialize(self, value, attr, data, **kwargs):
        return CustomType(value["value"])

In this example, I want every field annotated with CustomType, whether bare (CustomType) or parameterized (CustomType[int]), to be handled by CustomTypeField. The natural approach would be to register it in the TYPE_MAPPING of a base schema:

class BaseSchema(marshmallow.Schema):
    TYPE_MAPPING = {CustomType: CustomTypeField}

@dataclass
class Foo:
    x: CustomType
    y: CustomType[int]
    z: int

schema = marshmallow_dataclass.class_schema(Foo, base_schema=BaseSchema)()
obj = Foo(x=CustomType("aa"), y=CustomType(3), z=4)
schema.dump(obj)

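With the current marshmallow_dataclass version, this does not work. Field y is annotated with CustomType[int], a parameterized alias that is not a key of BaseSchema.TYPE_MAPPING (only the bare CustomType is). The lookup therefore misses, and _field_for_schema falls through to issubclass(typ, Enum), which raises because CustomType[int] is not a class. The following error is produced: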

Traceback (most recent call last):
  File "scratch_2.py", line 38, in main
    schema = marshmallow_dataclass.class_schema(Foo, base_schema=BaseSchema)()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "marshmallow_dataclass/__init__.py", line 462, in class_schema
    return _internal_class_schema(clazz, base_schema)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "marshmallow_dataclass/__init__.py", line 552, in _internal_class_schema
    attributes.update(
  File "marshmallow_dataclass/__init__.py", line 555, in <genexpr>
    _field_for_schema(
  File "marshmallow_dataclass/__init__.py", line 890, in _field_for_schema
    if issubclass(typ, Enum):
       ^^^^^^^^^^^^^^^^^^^^^
TypeError: issubclass() arg 1 must be a class
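The failure can be reproduced with the standard library alone. The sketch below uses a pared-down copy of CustomType and a plain dict standing in for TYPE_MAPPING (its string value is a placeholder, not a real field). It shows that the parameterized alias misses the dict lookup, that typing.get_origin recovers the bare class, and that issubclass rejects the alias with the same TypeError as in the traceback.

```python
import typing
from enum import Enum

_T = typing.TypeVar("_T")


class CustomType(typing.Generic[_T]):
    def __init__(self, v: _T):
        self._value = v


# Stand-in for BaseSchema.TYPE_MAPPING; the value is a placeholder, not a field class.
TYPE_MAPPING = {CustomType: "CustomTypeField"}

alias = CustomType[int]  # a typing alias object, not a class

# The alias hashes and compares differently from the bare class, so the lookup misses.
alias_in_mapping = alias in TYPE_MAPPING

# The bare class is still reachable through the alias's origin.
origin = typing.get_origin(alias)

# This mirrors the check that fails inside _field_for_schema.
try:
    issubclass(alias, Enum)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(alias_in_mapping)      # False
print(origin is CustomType)  # True
print(error_message)
```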

Proposed changes

My current workaround is to add a second dictionary to the base schema, GENERIC_TYPE_MAPPING, which maps a generic class to a field and ignores any type arguments, so both CustomType and CustomType[int] resolve to the same field.

class BaseSchema(marshmallow.Schema):
    GENERIC_TYPE_MAPPING = {CustomType: CustomTypeField}

The lookup is implemented as follows:

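from typing import Any, Optional, Type, Union

import marshmallow
import typing_inspect


def _field_by_generic_type(
    typ: Union[type, Any], base_schema: Optional[Type[marshmallow.Schema]]
) -> Optional[Type[marshmallow.fields.Field]]:
    """Look up a field override for typ, ignoring any generic arguments."""
    origin = typing_inspect.get_origin(typ)  # CustomType[int] -> CustomType
    type_mapping = getattr(base_schema, "GENERIC_TYPE_MAPPING", {})
    if origin is not None:
        return type_mapping.get(origin)
    return type_mapping.get(typ)  # bare class, e.g. CustomType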
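To try the lookup in isolation, here is a self-contained sketch of the same idea. It swaps typing_inspect.get_origin for the standard library's typing.get_origin (Python 3.8+), which returns None for a bare class and the unsubscripted class for an alias such as CustomType[int]. The names field_by_generic_type, StubField and the plain BaseSchema class are hypothetical stand-ins so the example runs without marshmallow.

```python
import typing
from typing import Any, Dict, Optional

_T = typing.TypeVar("_T")


class CustomType(typing.Generic[_T]):
    def __init__(self, v: _T):
        self._value = v


class StubField:
    """Hypothetical placeholder for a marshmallow field class."""


class BaseSchema:
    # Hypothetical stand-in for a marshmallow.Schema subclass.
    GENERIC_TYPE_MAPPING: Dict[Any, type] = {CustomType: StubField}


def field_by_generic_type(typ: Any, base_schema: Optional[type]) -> Optional[type]:
    """Return the field registered for typ, ignoring any type arguments."""
    mapping = getattr(base_schema, "GENERIC_TYPE_MAPPING", {})
    # get_origin(CustomType[int]) is CustomType; get_origin(CustomType) is None.
    origin = typing.get_origin(typ)
    return mapping.get(origin if origin is not None else typ)


for annotation in (CustomType, CustomType[int], CustomType[str], int):
    print(annotation, field_by_generic_type(annotation, BaseSchema))
```

Because the type arguments are discarded, every parameterization of CustomType shares one field; annotations with no entry (int here) yield None and fall through to the normal handling.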

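The _field_for_schema function is then modified to consult this mapping. The check has to run before issubclass(typ, Enum) is reached, so that a parameterized alias returns early:

# Inside _field_for_schema, before the issubclass(typ, Enum) check:
field = _field_by_generic_type(typ, base_schema)
if field is not None:
    return field(**metadata)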
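The ordering matters, and a simplified dispatcher makes that visible. The sketch below is hypothetical: resolve_field, StubField and StubEnumField are illustrative names, not marshmallow_dataclass internals, and the real _field_for_schema handles many more cases. It checks the generic mapping first, then exact TYPE_MAPPING entries, and only then runs class-only checks such as issubclass.

```python
import typing
from enum import Enum
from typing import Any, Dict, Mapping, Optional

_T = typing.TypeVar("_T")


class CustomType(typing.Generic[_T]):
    def __init__(self, v: _T):
        self._value = v


class StubField:
    """Hypothetical placeholder for a marshmallow field; records its kwargs."""

    def __init__(self, **kwargs: Any):
        self.kwargs = kwargs


class StubEnumField(StubField):
    """Hypothetical placeholder for an enum field."""


class BaseSchema:
    TYPE_MAPPING: Dict[Any, type] = {int: StubField}
    GENERIC_TYPE_MAPPING: Dict[Any, type] = {CustomType: StubField}


def resolve_field(
    typ: Any, base_schema: type, metadata: Optional[Mapping[str, Any]] = None
) -> StubField:
    kwargs = dict(metadata or {})
    # 1. Generic overrides first, so aliases such as CustomType[int] never reach issubclass().
    generic = getattr(base_schema, "GENERIC_TYPE_MAPPING", {})
    origin = typing.get_origin(typ)
    field_cls = generic.get(origin if origin is not None else typ)
    if field_cls is not None:
        return field_cls(**kwargs)
    # 2. Exact-type overrides.
    field_cls = getattr(base_schema, "TYPE_MAPPING", {}).get(typ)
    if field_cls is not None:
        return field_cls(**kwargs)
    # 3. Class-only checks, guarded so non-class annotations cannot raise here.
    if isinstance(typ, type) and issubclass(typ, Enum):
        return StubEnumField(**kwargs)
    raise TypeError(f"no field registered for {typ!r}")
```

The isinstance(typ, type) guard is an extra defensive measure that the workaround above does not include; with the generic lookup placed first, it only matters for other non-class annotations.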

My questions are:

Thanks in advance, Thomas.