BeanieODM / beanie

Asynchronous Python ODM for MongoDB
http://beanie-odm.dev/
Apache License 2.0
1.94k stars 203 forks source link

[QUESTION] chained call behavior when use find+limit+skip+aggregate #313

Closed hd10180 closed 1 year ago

hd10180 commented 2 years ago

I tried chaining the methods find, sort, limit, skip and finally aggregate, but got an unexpected result.

a. insert 100 records into the db
b. find the data using `{}` as the condition
c. use `limit(10)` to cap the number of returned records: what I need is 10
d. use `aggregate` to query across collections

expected: return 10 records
actually: return 100 records

my env:

Package           Version
----------------- ---------
beanie            1.11.6
motor             3.0.0
pydantic          1.9.1
pymongo           4.2.0
pytest            7.1.2
pytest-asyncio    0.19.0
pytest-cov        3.0.0

Q1: can I chain-call the methods like `DocType.find().sort().limit().skip().aggregate()`? Q2: what is the correct result of the test case?

reproduce steps: the code is complete, it should run with pytest

codes

- models.py
```python
from beanie import PydanticObjectId, Document
from pydantic import BaseModel, Field
from pymongo import IndexModel, DESCENDING

class Cup(Document):
    """Beanie document stored in the "cup" collection."""

    name: str
    width: float
    height: float

    class Collection:
        # Beanie collection settings: collection name and index definitions.
        name = "cup"
        indexes = [
            # Cup names must be unique.
            IndexModel([("name", DESCENDING)], unique=True),
        ]

class Water(Document):
    """Beanie document stored in the "water" collection."""

    name: str

    class Collection:
        name = "water"
        indexes = [
            # Water names must be unique.
            IndexModel([("name", DESCENDING)], unique=True),
        ]

class WaterInCup(Document):
    """Join document linking a Cup to a Water (collection "water_in_cup")."""

    cup: PydanticObjectId  # _id of the related Cup document
    water: PydanticObjectId  # _id of the related Water document

    class Collection:
        name = "water_in_cup"
        indexes = [
            # Each (cup, water) pair may appear at most once.
            IndexModel([("cup", DESCENDING), ("water", DESCENDING)], unique=True),
        ]

class CupInfo(BaseModel):
    """Read-only projection of a Cup document."""

    id: PydanticObjectId = Field(alias="_id")  # mapped from MongoDB's _id
    name: str
    width: float
    height: float

class WaterInfo(BaseModel):
    """Read-only projection of a Water document."""

    id: PydanticObjectId = Field(alias="_id")  # mapped from MongoDB's _id
    name: str

class OutWaterInCup(BaseModel):
    """Projection model for the $lookup aggregation over WaterInCup."""

    id: PydanticObjectId = Field(alias="_id")  # mapped from MongoDB's _id
    cup: PydanticObjectId
    water: PydanticObjectId
    cup_info: CupInfo  # joined from the "cup" collection
    water_info: WaterInfo  # joined from the "water" collection
```

- conftest.py
```python
from .models import Cup, Water, WaterInCup

# NOTE(review): this snippet was pasted with its line breaks collapsed and its
# imports missing; reconstructed here. `logging.getLogger(name)` referenced an
# undefined `name` — the conventional module-level logger uses `__name__`.
LOGGER = logging.getLogger(__name__)


class Settings(BaseSettings):
    """Test settings: MongoDB connection URI and target database name."""

    uri: str = "mongodb://127.0.0.1:27017/pytest"
    db_name: str = "test_db"


settings = Settings()

@pytest.fixture def motor_client(): LOGGER.info(">>>>>>>>>>init_client") return motor.motor_asyncio.AsyncIOMotorClient(settings.uri)

@pytest.fixture def db(motor_client): LOGGER.info(">>>>>>>>>>init_db") db = motor_client[settings.db_name] return db

# NOTE(review): decorator and def were collapsed onto one line in the paste;
# reconstructed. `init_beanie` is not imported in the visible snippet —
# presumably `from beanie import init_beanie`; confirm against the repo.
@pytest.fixture(autouse=True)
async def lifespan(motor_client, db):
    """Initialise Beanie before each test and drop the database afterwards."""
    # !pre: register the document models with Beanie.
    LOGGER.info(">>>>>>>>>>init entity")
    await init_beanie(database=db, document_models=[Cup, Water, WaterInCup])  # type: ignore
    LOGGER.info(">>>>>>>>>>init entity completed")

    yield None

    # !after: clean up everything the test wrote.
    LOGGER.info(f">>>>>>>>>>clean database:{settings.db_name}")
    await motor_client.drop_database(settings.db_name)
    LOGGER.info(f">>>>>>>>>>clean finished.")
```

- test_chain_call.py
```python
import pytest
from .models import Cup, Water, WaterInCup, OutWaterInCup
import logging

LOGGER = logging.getLogger(__name__)

class TestPaginate:
    """Reproduces the issue: sort/skip/limit chained before aggregate() are lost."""

    @pytest.fixture(autouse=True)
    async def prepare_data(self):
        # Seed 100 linked (Cup, Water, WaterInCup) triples before every test.
        for i in range(100):
            cup = await Cup(name=f"Cup_{i}", width=40, height=100).save()
            water = await Water(name=f"water_{i}").save()
            await WaterInCup(cup=cup.id, water=water.id).save()

    async def test_aggregate_query(self):
        query = {}
        # Pipeline joins each WaterInCup with its Cup and Water documents.
        pipeline = [
            # Join the referenced Cup into "cup_info".
            {
                "$lookup": {
                    "from": "cup",
                    "let": {"cid": "$cup"},
                    "pipeline": [
                        {"$match": {"$expr": {"$eq": ["$_id", "$$cid"]}}},
                    ],
                    "as": "cup_info",
                }
            },
            # Flatten the single-element lookup array; drop rows with no match.
            {
                "$unwind": {
                    "path": "$cup_info",
                    "preserveNullAndEmptyArrays": False,
                }
            },
            # Join the referenced Water into "water_info".
            {
                "$lookup": {
                    "from": "water",
                    "let": {"wid": "$water"},
                    "pipeline": [
                        {"$match": {"$expr": {"$eq": ["$_id", "$$wid"]}}},
                    ],
                    "as": "water_info",
                }
            },
            # Flatten the single-element lookup array; drop rows with no match.
            {
                "$unwind": {
                    "path": "$water_info",
                    "preserveNullAndEmptyArrays": False,
                }
            },
        ]
        from beanie.odm.enums import SortDirection

        # Chained sort/limit/skip, then aggregate — the crux of the report:
        # the limit(10) is expected to cap the aggregation output at 10 rows.
        res = (
            await WaterInCup.find(query)
            .sort([("_id", SortDirection.DESCENDING)])
            .limit(10)
            .skip(0)
            .aggregate(pipeline, projection_model=OutWaterInCup)
            .to_list()
        )
        LOGGER.critical(
            ">>>>>>  the expected result length is 10 but got 100 instead.  <<<<<<"
        )
        assert len(res) == 100  # PASSED — demonstrates the bug: the limit was ignored
        assert len(res) == 10  # FAILED — this is the behaviour the reporter expected
```

hd10180 commented 2 years ago

after debugging I found: the find query's limit, skip and sort_expressions are not passed on to the aggregation. see https://github.com/roman-right/beanie/blob/main/beanie/odm/queries/find.py#L527-L556

    def aggregate(
        self,
        aggregation_pipeline: List[Any],
        projection_model: Optional[Type[FindQueryProjectionType]] = None,
        session: Optional[ClientSession] = None,
        ignore_cache: bool = False,
        **pymongo_kwargs,
    ) -> Union[
        AggregationQuery[Dict[str, Any]],
        AggregationQuery[FindQueryProjectionType],
    ]:
        """
        Provide search criteria to the [AggregationQuery](https://roman-right.github.io/beanie/api/queries/#aggregationquery)
        :param aggregation_pipeline: list - aggregation pipeline. MongoDB doc:
        <https://docs.mongodb.com/manual/core/aggregation-pipeline/>
        :param projection_model: Type[BaseModel] - Projection Model
        :param session: Optional[ClientSession] - PyMongo session
        :param ignore_cache: bool
        :return:[AggregationQuery](https://roman-right.github.io/beanie/api/queries/#aggregationquery)
        """
        self.set_session(session=session)
        # NOTE(review): only the filter (get_filter_query()) is forwarded to the
        # AggregationQuery below. The sort/skip/limit state accumulated on this
        # find query is silently dropped — this is the bug reported in this issue.
        return self.AggregationQueryType(
            aggregation_pipeline=aggregation_pipeline,
            document_model=self.document_model,
            projection_model=projection_model,
            find_query=self.get_filter_query(),
            ignore_cache=ignore_cache,
            **pymongo_kwargs,
        ).set_session(session=self.session)

maybe we can pass them like

        return self.AggregationQueryType(
            aggregation_pipeline=aggregation_pipeline,
            document_model=self.document_model,
            projection_model=projection_model,
            find_query=self.get_filter_query(),
            ignore_cache=ignore_cache,
            limit=self.limit_number,  # pass the limit
            skip=self.skip_number,  # pass the skip
            sort_expressions=self.sort_expressions,  # pass the sort exp
            **pymongo_kwargs,
        ).set_session(session=self.session)

and then use the params in aggregation.py's get_aggregation_pipeline function

    def get_aggregation_pipeline(
        self,
    ) -> List[Mapping[str, Any]]:
        # Start with the find filter as a leading $match stage (if any).
        match_pipeline: List[Mapping[str, Any]] = (
            [{"$match": self.find_query}] if self.find_query else []
        )
        # --- proposed patch: honour the sort/skip/limit forwarded from find ---
        # $sort must come before $skip/$limit so pagination is deterministic.
        sort_pipeline = {"$sort": {i[0]: i[1] for i in self.sort_expressions}}
        if sort_pipeline["$sort"]:
            match_pipeline.append(sort_pipeline)
        if self.skip_number and self.skip_number != 0:
            match_pipeline.append({"$skip": self.skip_number})
        if self.limit_number and self.limit_number != 0:
            match_pipeline.append({"$limit": self.limit_number})
        # --- end of patch ---
        projection_pipeline: List[Mapping[str, Any]] = []
        if self.projection_model:
            projection = get_projection(self.projection_model)
            if projection is not None:
                projection_pipeline = [{"$project": projection}]
        # User stages run after match/sort/skip/limit, before the projection.
        return match_pipeline + self.aggregation_pipeline + projection_pipeline

@roman-right

roman-right commented 1 year ago

Hey! Sorry for the delay. It looks like a bug. I'll pick it up soon. Thank you

github-actions[bot] commented 1 year ago

This issue is stale because it has been open 30 days with no activity.

github-actions[bot] commented 1 year ago

This issue was closed because it has been stalled for 14 days with no activity.