lovasoa / marshmallow_dataclass

Automatic generation of marshmallow schemas from dataclasses.
https://lovasoa.github.io/marshmallow_dataclass/html/marshmallow_dataclass.html
MIT License

Forward reference NameError when loading a recursive json #135

Open AdamMabrouk97 opened 3 years ago

AdamMabrouk97 commented 3 years ago

Hello Everyone,

I'm using marshmallow-dataclass to load JSON that represents a sequence of rules. Each rule is a LogicalGroup that applies a logical operator to its child expressions, and an expression can itself be a LogicalGroup.

The input dict follows this structure:

import marshmallow_dataclass
from dataclasses import field
from api_handler import BaseSchema
from typing import Sequence, Union, Literal, Type, List, ForwardRef, TypeVar, Generic

filter_input = { "rules" :
  [{
    "groupOperator" : "and",
    "expressions" : [
      { "field": "xxxxx", "operator": "eq", "value": 'level1' },
      { "field": "xxxxx", "operator": "eq", "value": 'm'},
      { "field": "xxxxx", "operator": "eq", "value": "test"},
      {
        "groupOperator" : "or",
        "expressions" : [
          { "field": "xxxx", "operator": "eq", "value": 'level2' },
          { "field": "xxxx", "operator": "eq", "value": 'm' },
          { "field": "xxxx", "operator": "eq", "value": "test" }
        ]
      }
    ]
  }]
}

The dataclasses I'm using for this purpose are the following:

@marshmallow_dataclass.dataclass(base_schema=BaseSchema)
class Expression:
    field    : str
    operator : str
    value    : str 

@marshmallow_dataclass.dataclass(base_schema=BaseSchema)
class LogicalGroup:
    group_operator   : str
    expressions      : List[Union['LogicalGroup', Expression]] = field(default_factory=list)

@marshmallow_dataclass.dataclass(base_schema=BaseSchema)
class Filter:
    rules: List[LogicalGroup] = field(default_factory=list)

The problem is that when I try to load the dict using the Filter dataclass, I get the following error:

filt = Filter.Schema().load(filter_input)
Traceback (most recent call last):
  File "/home/adam/billing/billing/filter/filter.py", line 96, in <module>
    filt                = Filter.Schema().load(filter_input)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow_dataclass/__init__.py", line 628, in load
    all_loaded = super().load(data, many=many, **kwargs)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/schema.py", line 725, in load
    return self._do_load(
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/schema.py", line 859, in _do_load
    result = self._deserialize(
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/schema.py", line 667, in _deserialize
    value = self._call_and_store(
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/schema.py", line 496, in _call_and_store
    value = getter_func(data)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/schema.py", line 664, in <lambda>
    getter = lambda val: field_obj.deserialize(
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/fields.py", line 354, in deserialize
    output = self._deserialize(value, attr, data, **kwargs)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/fields.py", line 726, in _deserialize
    result.append(self.inner.deserialize(each, **kwargs))
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/fields.py", line 354, in deserialize
    output = self._deserialize(value, attr, data, **kwargs)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/fields.py", line 609, in _deserialize
    return self._load(value, data, partial=partial)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/fields.py", line 592, in _load
    valid_data = self.schema.load(value, unknown=self.unknown, partial=partial)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow_dataclass/__init__.py", line 628, in load
    all_loaded = super().load(data, many=many, **kwargs)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/schema.py", line 725, in load
    return self._do_load(
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/schema.py", line 859, in _do_load
    result = self._deserialize(
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/schema.py", line 667, in _deserialize
    value = self._call_and_store(
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/schema.py", line 496, in _call_and_store
    value = getter_func(data)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/schema.py", line 664, in <lambda>
    getter = lambda val: field_obj.deserialize(
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/fields.py", line 354, in deserialize
    output = self._deserialize(value, attr, data, **kwargs)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/fields.py", line 726, in _deserialize
    result.append(self.inner.deserialize(each, **kwargs))
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow/fields.py", line 354, in deserialize
    output = self._deserialize(value, attr, data, **kwargs)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/marshmallow_dataclass/union_field.py", line 56, in _deserialize
    typeguard.check_type(attr or "anonymous", result, typ)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/typeguard/__init__.py", line 655, in check_type
    expected_type = resolve_forwardref(expected_type, memo)
  File "/home/adam/thanos-envv/lib/python3.9/site-packages/typeguard/__init__.py", line 198, in resolve_forwardref
    return evaluate_forwardref(maybe_ref, memo.globals, memo.locals, frozenset())
  File "/usr/lib/python3.9/typing.py", line 533, in _evaluate
    eval(self.__forward_code__, globalns, localns),
  File "<string>", line 1, in <module>
NameError: name 'LogicalGroup' is not defined

I'm guessing the problem comes from declaring LogicalGroup as a forward reference inside the Union type hint, because when I use only Union['LogicalGroup'] and change my dict to a nested dict of LogicalGroups without the Expressions, it works fine.

Does anyone have an idea what causes this bug, or a suggestion for addressing the problem another way?

Thanks in advance !

tclv commented 3 years ago

I also ran into this bug. The problem is the forward reference resolution used by typeguard. It resolves references using the global and local namespaces of the calling frame, which in this case belongs to the marshmallow_dataclass.union_field module, and that module does not contain the names that the schema's forward references point to.
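To illustrate why the namespace matters, here is a minimal standard-library sketch. It does not use typeguard or marshmallow_dataclass; `takes_group` and the namespace dicts are hypothetical names made up for the illustration, and `typing.get_type_hints` stands in for the resolution step.

```python
import typing


class LogicalGroup:
    """Stand-in for the dataclass from the report."""


def takes_group(group: "LogicalGroup") -> None:
    """The string annotation is kept as an unevaluated forward reference."""


# A namespace that knows the name resolves the reference.
resolved = typing.get_type_hints(
    takes_group, globalns={"LogicalGroup": LogicalGroup}
)

# A namespace that lacks the name (like the globals of an unrelated
# module) cannot resolve it and raises NameError.
try:
    typing.get_type_hints(takes_group, globalns={"unrelated": 1})
    error = None
except NameError as exc:
    error = exc
```

The same string resolves or fails depending only on which namespace is handed to the evaluator, which matches the last frames of the traceback above.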

Example:

from marshmallow_dataclass import dataclass
from typing import ForwardRef, Union

X_ref = ForwardRef("X")
Y_ref = ForwardRef("Y")
U = Union[Y_ref, X_ref]

@dataclass
class X:
    a: str

@dataclass
class Y:
    a: U

schema = X.Schema()

If we run the following in the same module, we get a TypeError, as expected and desired:

import typeguard
typeguard.check_type('a', X(a="a"), Y_ref)
...
TypeError: type of a must be __main__.Y; got __main__.X instead

Dumping an instance, however, results in a NameError:

Y.Schema().dump(Y(X('a')))
...
NameError: name 'Y' is not defined

The call made from _serialize and _deserialize is identical to the previous call, but since typeguard.check_type uses the caller's frame, we get different results.
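The frame dependence can be reproduced without typeguard. The sketch below is only an imitation of the mechanism: `resolve_in_caller` and `library_check` are hypothetical helpers, and a separate dict passed to `exec` plays the role of another module's globals.

```python
import sys


def resolve_in_caller(name):
    """Evaluate `name` in the namespaces of whoever called this function."""
    frame = sys._getframe(1)
    return eval(name, frame.f_globals, frame.f_locals)


class Y:
    """Stand-in for the dataclass that the forward reference points to."""


# Direct call: the caller is this module, which defines Y.
direct_result = resolve_in_caller("Y")

# Simulated library module: a function whose globals do not contain Y.
library_ns = {"resolve_in_caller": resolve_in_caller}
exec(
    "def library_check(name):\n"
    "    return resolve_in_caller(name)\n",
    library_ns,
)

# Same arguments, but the caller is now the "library" function.
try:
    library_ns["library_check"]("Y")
    library_error = None
except NameError as exc:
    library_error = exc
```

The arguments are identical in both calls; only the frame that makes the call differs, and with it the set of names available for resolution.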

Based on this, forward references combined with union types appear to be broken at the moment. The library currently runs this type check during both serialization and deserialization, while I think it is only necessary during serialization, to make sure you pick the correct field. During deserialization it checks whether the nested field actually returned the correct type, which seems to be true by construction (I think?).

Based on this, there are a few possible solutions:

  1. There is a PR on typeguard that tries to address this specific issue. Its solution comes down to taking the string of the forward reference and comparing it to the class name of the value (and of its superclasses). I am not sure whether it will be merged, as it looks a little hacky, and the library supports passing global and local namespaces that override the defaults, so it's not necessarily a bug on their end. The PR boils down to:

    try:
        expected_type = resolve_forwardref(expected_type, memo)
    except NameError:
        # Try checking the class if class is cycle or value is ForwardRef
        if isinstance(expected_type, ForwardRef):
            class_name = expected_type.__forward_arg__
            class_values = get_super_class_names(value)
            class_values.append(get_class_name(value))
            if class_name in class_values:
                return
            else:
                if len(class_values) == 1:
                    class_values = class_values[0]
                msg = 'type of {} must be {}; got {} instead'.format(
                    argname, class_name, class_values)
                raise TypeError(msg)

    In theory this should be enough for picking the correct serialization field, so if this does not get merged, it could be a workaround in this library.

  2. Try to pass the correct global and local namespaces to typeguard. This seems a bit iffy, as it relies on the imports being available from the calling module and on knowing the exact number of stack frames for each of the different API calls. That is difficult to control and depends on marshmallow internals.

  3. Pass a reference to the class to marshmallow_dataclass schemas, which can be used for comparison checks on nested fields. This also provides some extra flexibility in validation by giving access to class variables.

def _base_schema(
    clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None
) -> Type[marshmallow.Schema]:
    """
    Base schema factory that creates a schema for `clazz` derived either from `base_schema`
    or `BaseSchema`
    """

    # Remove `type: ignore` when mypy handles dynamic base classes
    # https://github.com/python/mypy/issues/2813
    class BaseSchema(base_schema or marshmallow.Schema):  # type: ignore
        __clazz__ = clazz
...
    return BaseSchema

This could then be used in serialization for an isinstance check. I am not sure if this would completely implement all the different checks in typeguard, but would at least solve the forward reference case.
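As a rough comparison of options 1 and 3, the two checks could look something like this. This is a sketch under assumptions, not the typeguard PR or an existing marshmallow_dataclass API: `matches_forward_ref_by_name`, `matches_schema_class`, `FakeSchema`, and `YSchema` are hypothetical names, and only `__clazz__` comes from the proposal above.

```python
from typing import ForwardRef


def matches_forward_ref_by_name(value, ref):
    """Option 1 (sketch): compare the reference's name with the names of
    the value's class and its base classes, without evaluating anything."""
    class_names = {cls.__name__ for cls in type(value).__mro__}
    return ref.__forward_arg__ in class_names


class FakeSchema:
    """Hypothetical stand-in for a generated schema carrying __clazz__."""

    __clazz__ = None


def matches_schema_class(value, schema_cls):
    """Option 3 (sketch): isinstance check against the class on the schema."""
    return isinstance(value, schema_cls.__clazz__)


class X:
    """First member of the union."""


class Y:
    """Second member of the union, referenced by name."""


class YSchema(FakeSchema):
    __clazz__ = Y


by_name_y = matches_forward_ref_by_name(Y(), ForwardRef("Y"))
by_name_x = matches_forward_ref_by_name(X(), ForwardRef("Y"))
by_class_y = matches_schema_class(Y(), YSchema)
by_class_x = matches_schema_class(X(), YSchema)
```

Neither check needs a namespace lookup. The name comparison cannot tell apart two classes that share a name across modules, whereas the isinstance check compares against the actual class object.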

Thoughts?

tclv commented 3 years ago

@lovasoa I can write a PR for my previous comment, but I'd like your input before I get started. Could you comment when you have time?