octabytes / FireO

A modern, simple, and convenient ORM package for Google Cloud Firestore in Python. FireO is designed specifically for Google's Firestore.
https://fireo.octabyte.io
Apache License 2.0
247 stars 29 forks source link

FireO in FastAPI-filters #190

Open ADR-007 opened 1 year ago

ADR-007 commented 1 year ago

Hi!

I implemented FireO support for FastAPI-filters, but I'm not sure when I'll be able to open a PR against the library. Please let me know if you would like to use it, and I'll try to find time for it. :)

Code snippet:

from copy import deepcopy
from types import UnionType
from typing import Any, Dict, Generic, Literal, Optional, Tuple, Type, TypeVar, Union

import fastapi_filter.base.filter as filter_lib
from fastapi import Query, params
from fastapi_filter.base.filter import BaseFilterModel
from fireo.managers.managers import Manager
from fireo.queries.filter_query import FilterQuery
from fireo.queries.query_set import QuerySet
from pydantic import root_validator, validator
from pydantic.fields import SHAPE_LIST, FieldInfo, ModelField, Undefined

# Maps a filter-name suffix (the part after "__" in a filter field name, or ""
# when there is no suffix) to a callable that applies the corresponding
# Firestore filter. Every callable takes (query, field_name, value) and returns
# whatever ``query.filter(...)`` returns, so calls can be chained.
_orm_operator_filter = {
    "": lambda query, field_name, value: query.filter(field_name, "==", value),
    "not_eq": lambda query, field_name, value: query.filter(field_name, "!=", value),
    "gt": lambda query, field_name, value: query.filter(field_name, ">", value),
    "gte": lambda query, field_name, value: query.filter(field_name, ">=", value),
    "in": lambda query, field_name, value: query.filter(field_name, "in", value),
    # ``isnull=True`` matches fields equal to None; any other value matches non-None.
    "isnull": lambda query, field_name, value: query.filter(field_name, ("==" if value is True else "!="), None),
    # Less-than is "<"; "<>" is not a less-than operator.
    "lt": lambda query, field_name, value: query.filter(field_name, "<", value),
    "lte": lambda query, field_name, value: query.filter(field_name, "<=", value),
    "not_in": lambda query, field_name, value: query.filter(field_name, "not_in", value),
    "contains": lambda query, field_name, value: query.filter(field_name, "array-contains", value),
    "overlap": lambda query, field_name, value: query.filter(field_name, "array-contains-any", value),
    # Prefix match expressed as the range  value <= field < value + U+FFFD.
    "startswith": lambda query, field_name, value: (
        query.filter(field_name, ">=", value).filter(field_name, "<", value + "\ufffd")
    ),
}
# Operator suffixes treated as inequality filters when validating the sort
# order: every operator above except equality, "in", "not_in" and "isnull".
_orm_op_conflicts_with_sorting = set(_orm_operator_filter) - {"", "in", "not_in", "isnull"}

class FireoFilter(BaseFilterModel):
    """Base filter for Firestore related filters.

    Filter field names have the form ``<field>`` or ``<field>__<operator>``,
    where ``<operator>`` is one of the keys of ``_orm_operator_filter``.

    Example:
        ```python
        class MyModel(Model):
            name: TextField(required=True)
            count: NumberField(int_only=True)
            created_at: DatetimeField()

        class MyModelFilter(FireoFilter):
            id: Optional[int]
            id__in: Optional[str]
            count: Optional[int]
            count__lte: Optional[int]
            created_at__gt: Optional[datetime]
            name__not_eq: Optional[str]
            name__not_in: Optional[list[str]]
        ```
    """
    # NOTE(review): the validators and the filter()/sort() functions that follow
    # in this file sit at module level; they take ``cls``/``self`` and look like
    # they were meant to be indented into this class - confirm against the
    # original snippet.

@validator("*", pre=True)
def split_str(cls, value, field: ModelField):
    """Return every incoming field value unchanged.

    Registered as a ``pre`` validator for all fields (``"*"``); it performs no
    splitting or conversion of its own.
    """
    # NOTE(review): presumably this overrides a same-named validator on
    # BaseFilterModel that splits comma-separated strings - confirm.
    return value

@root_validator()
def validate_filter_and_sort_combinations(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Check that the requested filters can be combined with the requested ordering.

    Rules applied:
        - Inequality filters may target at most one field.
        - With an inequality filter and an explicit ordering, the first
          ordering entry (ignoring a leading "+" or "-") must be that field.
        - With an inequality filter and no ordering, the filtered field is
          stored as the ordering.

    Args:
        values: Validated field values keyed by filter field name.

    Returns:
        The same ``values`` mapping, possibly with the ordering entry filled in.

    Raises:
        ValueError: If inequality filters span several fields, or the first
            ordering entry differs from the inequality-filtered field.
    """
    ordering_key = cls.Constants.ordering_field_name
    orders = values.get(ordering_key)

    # Base names (text before "__") of every populated field whose operator
    # suffix is listed in _orm_op_conflicts_with_sorting.
    unequal_filter_fields = {
        name.partition("__")[0]
        for name, val in values.items()
        if val is not None and name.partition("__")[2] in _orm_op_conflicts_with_sorting
    }

    count = len(unequal_filter_fields)
    if count == 0:
        return values
    if count > 1:
        raise ValueError(
            f"Cannot have inequality on multiple fields: {unequal_filter_fields}"
        )

    if orders:
        first_order = orders[0].lstrip("+-")
        (filter_field,) = unequal_filter_fields
        if filter_field != first_order:
            raise ValueError(
                f"Inequality filter property and first sort order must be the same: {filter_field} and {first_order}"
            )
    else:
        # Pagination does not work without this ordering
        values[ordering_key] = list(unequal_filter_fields)

    return values

def filter(self, query: FilterQuery | QuerySet | Manager) -> FilterQuery:
    """Apply every entry of ``self.filtering_fields`` to ``query``.

    An attribute that is itself a ``FireoFilter`` is applied recursively via
    its own ``filter``. Any other entry is split at the first "__" into a
    field name and an operator suffix, which selects the callable from
    ``_orm_operator_filter``.

    Args:
        query: The query object to narrow down.

    Returns:
        The query after all filters have been applied.
    """
    for raw_field_name, value in self.filtering_fields:
        nested = getattr(self, raw_field_name)
        if isinstance(nested, FireoFilter):
            query = nested.filter(query)
        else:
            name, _, suffix = raw_field_name.partition("__")
            apply_operator = _orm_operator_filter[suffix]
            query = apply_operator(query, name, value)

    return query

def sort(self, query: FilterQuery | QuerySet | Manager) -> FilterQuery:
    """Apply each entry of ``self.ordering_values`` to ``query`` via ``query.order``.

    Args:
        query: The query object to order.

    Returns:
        The query with all orderings applied in sequence, or the query
        unchanged when ``self.ordering_values`` is empty or falsy.
    """
    # A falsy ordering (None, empty list) means there is nothing to apply.
    for ordering in self.ordering_values or ():
        query = query.order(ordering)

    return query

ListItem = TypeVar('ListItem')


class CommaSepList(list, Generic[ListItem]):
    """List type that accepts a single comma-separated string as its items.

    A one-element list such as ``["a,b,c"]`` is expanded to
    ``["a", "b", "c"]``; every other input is passed through unchanged.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the validators for this type; ``validate`` is the only one."""
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Split a single comma-separated string item into separate items.

        Args:
            v: The raw value to validate.

        Returns:
            ``v[0].split(',')`` when ``v`` is a list holding exactly one
            string; otherwise ``v`` unchanged.
        """
        # Only a lone string item can be a comma-separated payload; guarding
        # on ``str`` avoids an AttributeError for non-string single items.
        if isinstance(v, list) and len(v) == 1 and isinstance(v[0], str):
            v = v[0].split(',')

        return v

class Order(str, Generic[ListItem]):
    """Helper that builds a ``Literal`` type of allowed ordering strings.

    ``Order["a", "b"]`` evaluates to ``Literal["a", "b", "-a", "-b"]``: every
    given name, followed by every given name prefixed with "-".
    """

    def __class_getitem__(cls, items) -> Type[Literal[ListItem]]:  # type: ignore
        """Return the ``Literal`` of each name in ``items`` with and without "-".

        Args:
            items: A single string or a tuple of strings.

        Returns:
            ``Literal[...]`` over the plain names followed by the "-" names.

        Raises:
            TypeError: If ``items`` is empty or contains a non-string.
        """
        if not isinstance(items, tuple):
            items = (items,)

        # Raise explicitly rather than ``assert`` so the check survives ``-O``.
        if not items or not all(isinstance(item, str) for item in items):
            raise TypeError(f"Order[...] expects one or more strings, got {items!r}")

        options = tuple(
            f'{neg}{arg}'
            for neg in ('', '-')
            for arg in items
        )
        return Literal[options]  # type: ignore

def _list_to_comma_list(type_):
    """Convert a ``list[X]`` annotation to ``CommaSepList[X]``.

    Args:
        type_: A type annotation.

    Returns:
        ``CommaSepList[X]`` when ``type_`` is a parametrised list
        (``list[X]`` / ``List[X]``); otherwise ``type_`` unchanged.
    """
    if getattr(type_, "__origin__", None) is list:
        args = getattr(type_, "__args__", ())
        if args:
            return CommaSepList[args[0]]
    return type_


# Alias for the other spelling of this helper's name, so both resolve.
_list_to_commalist = _list_to_comma_list

def _list_to_str_fields(Filter: Type[BaseFilterModel]):
    """Prepare filter fields to be used in query params.

    Unlike the original implementation, this one:
        - allows to use lists in query params as multiple values for the same field
        - split comma separated values in query params to lists, so "split_str"
            is no longer needed

    Args:
        Filter: The filter model class whose fields are converted.

    Returns:
        A dict mapping each field name to a ``(type, field_info)`` tuple, where
        ``field_info`` is a copy whose default is wrapped in ``Query``.
    """
    ret: Dict[str, Tuple[Union[object, Type], Optional[FieldInfo]]] = {}
    for f in Filter.__fields__.values():
        # Work on a copy so the model's own FieldInfo is left untouched.
        field_info = deepcopy(f.field_info)
        if not isinstance(field_info.default, params.Query):
            if field_info.default is not Undefined:
                default = field_info.default
            elif f.required:
                default = ...
            else:
                default = None

            field_info.default = Query(default)

        field_type = Filter.__annotations__.get(f.name, f.outer_type_)
        if f.shape == SHAPE_LIST:
            # Handle both ``X | Y`` (types.UnionType) and ``Union[X, Y]`` /
            # ``Optional[X]`` so list members of either spelling are converted.
            is_union = (
                isinstance(field_type, UnionType)
                or getattr(field_type, "__origin__", None) is Union
            )
            if is_union:
                items = tuple(_list_to_comma_list(arg) for arg in field_type.__args__)
                new_field_type = Union[items]  # type: ignore
            else:
                new_field_type = _list_to_comma_list(field_type)
            ret[f.name] = (new_field_type, field_info)
        else:
            ret[f.name] = (field_type if f.required else Optional[field_type], field_info)

    return ret

# Monkey-patch fastapi_filter.base.filter so that it uses the
# _list_to_str_fields defined above in place of its own implementation.
filter_lib._list_to_str_fields = _list_to_str_fields