redis / redis-om-python

Object mapping, and more, for Redis and Python
MIT License
1.06k stars 108 forks source link

Issue with validate_model using FieldInfo in pydantic v2 #572

Closed CodeToFreedom closed 2 months ago

CodeToFreedom commented 8 months ago

Environment:

redis-om: 0.2.1
pydantic: 2.4.2
pydantic-core: 2.11.0
fastapi: 0.104.0

Description:

When calling the check method on a custom RedisModel, which internally calls the validate_model function from pydantic v1, validation fails for pydantic v2 models that use FieldInfo.

Problematic Code:

The issue seems to arise from the line: value = field.get_default()

Possible solution: changing the code to the following seems to fix that:

def _unwrap_v2_default(value: object) -> object:
    """Reduce a pydantic v2 ``FieldInfo``-style default to a plain value.

    ``field.get_default()`` may hand back a wrapper object exposing
    ``default`` / ``default_factory`` instead of the default itself. This
    helper unwraps such an object and maps the "undefined" sentinel to
    ``None``. Values that are not wrappers are returned unchanged.
    """
    # Duck-typed on purpose: the wrapper is recognised by its attributes so
    # nothing from pydantic v2 has to be imported here.
    # NOTE(review): any default object that happens to expose a ``default``
    # attribute is unwrapped as well - confirm this is acceptable.
    if hasattr(value, 'default'):
        factory = getattr(value, 'default_factory', None)
        value = factory() if callable(factory) else value.default

    # The sentinel is detected by its string form. Real ``str`` defaults are
    # excluded so that a literal "PydanticUndefined" default is preserved.
    if not isinstance(value, str) and str(value) == 'PydanticUndefined':
        return None
    return value


def validate_model(  # noqa: C901 (ignore complexity)
    model: Type[BaseModel], input_data: 'DictStrAny', cls: 'ModelOrDc' = None
) -> Tuple['DictStrAny', 'SetStr', Optional[ValidationError]]:
    """
    validate data against a model.

    :param model: model class whose fields and root validators are applied
    :param input_data: raw input, keyed by field alias (or by field name when
        the config allows population by field name)
    :param cls: class passed to validators and attached to the resulting
        ``ValidationError``; falls back to ``model`` when not given
    :return: tuple of (validated values, names of the fields that were set,
        a ``ValidationError`` collecting all errors or ``None`` on success)
    """
    values = {}
    errors = []
    # input_data names, possibly alias
    names_used = set()
    # field names, never aliases
    fields_set = set()
    config = model.__config__
    check_extra = config.extra is not Extra.ignore
    cls_ = cls or model

    # A failing pre-root validator aborts validation immediately.
    for validator in model.__pre_root_validators__:
        try:
            input_data = validator(cls_, input_data)
        except (ValueError, TypeError, AssertionError) as exc:
            return {}, set(), ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], cls_)

    for name, field in model.__fields__.items():
        value = input_data.get(field.alias, _missing)
        using_name = False
        if value is _missing and config.allow_population_by_field_name and field.alt_alias:
            value = input_data.get(field.name, _missing)
            using_name = True

        if value is _missing:
            if field.required:
                errors.append(ErrorWrapper(MissingError(), loc=field.alias))
                continue

            # ADDED CODE: get_default() can return a pydantic v2 FieldInfo
            # (or the undefined sentinel) rather than the actual default.
            value = _unwrap_v2_default(field.get_default())

            # Defaults are stored as-is unless validation of them is requested.
            if not config.validate_all and not field.validate_always:
                values[name] = value
                continue
        else:
            fields_set.add(name)
            if check_extra:
                names_used.add(field.name if using_name else field.alias)

        v_, errors_ = field.validate(value, values, loc=field.alias, cls=cls_)
        if isinstance(errors_, ErrorWrapper):
            errors.append(errors_)
        elif isinstance(errors_, list):
            errors.extend(errors_)
        else:
            values[name] = v_

    if check_extra:
        # Anything in the input that was not consumed by a field is "extra".
        if isinstance(input_data, GetterDict):
            extra = input_data.extra_keys() - names_used
        else:
            extra = input_data.keys() - names_used
        if extra:
            fields_set |= extra
            if config.extra is Extra.allow:
                for f in extra:
                    values[f] = input_data[f]
            else:
                # sorted() keeps the error order deterministic
                for f in sorted(extra):
                    errors.append(ErrorWrapper(ExtraError(), loc=f))

    for skip_on_failure, validator in model.__post_root_validators__:
        if skip_on_failure and errors:
            continue
        try:
            values = validator(cls_, values)
        except (ValueError, TypeError, AssertionError) as exc:
            errors.append(ErrorWrapper(exc, loc=ROOT_KEY))

    if errors:
        return values, fields_set, ValidationError(errors, cls_)
    else:
        return values, fields_set, None
slorello89 commented 3 months ago

Should be fixed by #603