s3rius / FastAPI-template

Feature rich robust FastAPI template.
MIT License
1.79k stars 161 forks

How to use DAO in a websocket router? #172

Closed: eggb4by closed this issue 1 year ago

eggb4by commented 1 year ago

I have a websocket router like this:

@router.websocket(
    path="/ws",
)
async def websocket(
    websocket: WebSocket,
):
    await websocket.accept()
    ...

and I want to use a DAO to save the messages received over the websocket. But if I use

async def websocket(
    websocket: WebSocket,
    dao: MessageDAO = Depends(),
):

then when a client connects to the websocket, I get an error:

File "/Users/xxls/Desktop/Project/db/dependencies.py", line 17, in get_db_session
    session: AsyncSession = request.app.state.db_session_factory()
    └ <taskiq_dependencies.dependency.Dependency object at 0x111e2bd10>

AttributeError: 'Dependency' object has no attribute 'app'

s3rius commented 1 year ago

The problem with websockets is entirely FastAPI-related. You have multiple options here.

As far as I can see, it's fixed in the latest version: https://github.com/tiangolo/fastapi/releases/tag/0.97.0.

But if you still want to use DAOs in websockets with older versions of FastAPI, I would suggest initializing them explicitly.

I'll create an example later today.

eggb4by commented 1 year ago

Thanks! It seems this template uses fastapi = "^0.89.1", but after I upgraded FastAPI to 0.97.0, an error occurred at startup:

File "/Users/xxls/opt/anaconda3/envs/news_assistant/lib/python3.11/site-packages/starlette/routing.py", line 677, in lifespan
    async with self.lifespan_context(app) as maybe_state:
File "/Users/xxls/opt/anaconda3/envs/news_assistant/lib/python3.11/site-packages/starlette/routing.py", line 566, in __aenter__
    await self._router.startup()
File "/Users/xxls/opt/anaconda3/envs/news_assistant/lib/python3.11/site-packages/starlette/routing.py", line 654, in startup
    await handler()
File "/Users/xxls/Desktop/Project/proj/web/lifetime.py", line 63, in _startup
    setup_prometheus(app)
File "/Users/xxls/Desktop/Project/proj/web/lifetime.py", line 39, in setup_prometheus
    PrometheusFastApiInstrumentator(should_group_status_codes=False).instrument(
File "/Users/xxls/opt/anaconda3/envs/news_assistant/lib/python3.11/site-packages/prometheus_fastapi_instrumentator/instrumentation.py", line 121, in instrument
    app.add_middleware(
File "/Users/xxls/opt/anaconda3/envs/news_assistant/lib/python3.11/site-packages/starlette/applications.py", line 139, in add_middleware
    raise RuntimeError("Cannot add middleware after an application has started")
RuntimeError: Cannot add middleware after an application has started

s3rius commented 1 year ago

This is another issue with FastAPI. It's still not fixed. But we already figured out a workaround for that.

https://github.com/s3rius/FastAPI-template/issues/149

eggb4by commented 1 year ago

I tried this. App startup is OK, but something is still wrong with the websocket:

TypeError: get_db_session() missing 1 required positional argument: 'request'

s3rius commented 1 year ago

That's weird. Can you try depending on the app instead? Like this: app: FastAPI = Depends().

eggb4by commented 1 year ago

Sorry, I'm new to FastAPI. Do you mean code like this, in views.py?

from fastapi import APIRouter, Depends, FastAPI, WebSocket

router = APIRouter()

@router.websocket(
    path="/ws",
)
async def websocket(
    websocket: WebSocket,
    app: FastAPI = Depends(),
):
    await websocket.accept()

I got:

File "/Users/xxls/opt/anaconda3/envs/news_assistant/lib/python3.11/site-packages/fastapi/utils.py", line 103, in create_response_field
    raise fastapi.exceptions.FastAPIError(
fastapi.exceptions.FastAPIError: Invalid args for response field! Hint: check that typing.Optional[typing.List[starlette.routing.BaseRoute]] is a valid Pydantic field type. If you are using a return type annotation that is not a valid Pydantic field (e.g. Union[Response, dict, None]) you can disable generating the response model from the type annotation with the path operation decorator parameter response_model=None. Read more: https://fastapi.tiangolo.com/tutorial/response-model/

s3rius commented 1 year ago

It complains about return type annotations. Can you fix it and try again?

eggb4by commented 1 year ago

Yep, it's weird. I added a short test:

import loguru
from fastapi import APIRouter, Depends, FastAPI, WebSocket, WebSocketDisconnect

router = APIRouter()

@router.websocket(
    path="/ws_test",
)
async def websocket_test(
    websocket: WebSocket,
    app: FastAPI = Depends(),
) -> str:
    await websocket.accept()
    headers = websocket.headers
    try:
        while True:
            message_text = await websocket.receive_text()
            loguru.logger.info(f" --> get message: {message_text}")
            try:
                await websocket.send_text(f" get message {message_text}")
            except Exception as e:
                loguru.logger.error("error: {}", e)
    except WebSocketDisconnect:
        loguru.logger.info(" --> disconnect!")
    return "hello world"

The error is

fastapi.exceptions.FastAPIError: Invalid args for response field! Hint: check that typing.Optional[typing.List[starlette.routing.BaseRoute]] is a valid Pydantic field type. If you are using a return type annotation that is not a valid Pydantic field (e.g. Union[Response, dict, None]) you can disable generating the response model from the type annotation with the path operation decorator parameter response_model=None. Read more: https://fastapi.tiangolo.com/tutorial/response-model/

s3rius commented 1 year ago

Can you fix formatting using code blocks?

```python
{your code here}
```
eggb4by commented 1 year ago

Fixed, sorry.

s3rius commented 1 year ago

That's really weird, indeed.

eggb4by commented 1 year ago

Could you please write an example of using a DAO in a websocket router in this template? Thanks.

eggb4by commented 1 year ago

I tried something and found a solution. In FastAPI, WebSocket and Request both inherit from starlette.requests.HTTPConnection, and in db/dependencies.py the function get_db_session takes a request parameter:

async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession, None]:

So I added a separate dependency

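from typing import AsyncGenerator

from fastapi import WebSocket
from sqlalchemy.ext.asyncio import AsyncSession

async def get_ws_session(websocket: WebSocket) -> AsyncGenerator[AsyncSession, None]: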
    session: AsyncSession = websocket.app.state.db_session_factory()

    try:  # noqa: WPS501
        yield session
    finally:
        await session.commit()
        await session.close()

to get the database session from the WebSocket object. Then, in the DAO module, I have a new MessageWSDAO like this:

class MessageWSDAO:

    def __init__(self, session: AsyncSession = Depends(get_ws_session)):
        self.session = session

    async def get_latest_message_by_user(self, user_id) -> MessageModel:
        ...

Now it works

eggb4by commented 1 year ago

This is just a makeshift solution, not that elegant.

HEKUCHAN commented 5 months ago

I encountered the same error and successfully resolved it using the following approach.

In the documentation, I came across the following information:

> When you want to define dependencies that should be compatible with both HTTP and WebSockets, you can define a parameter that takes an HTTPConnection instead of a Request or a WebSocket.

(Linked reference pages: WebSocket class, Request class, HTTPConnection class.)

Consequently, I made a modification, substituting Request with HTTPConnection:

from typing import AsyncGenerator

from fastapi.requests import HTTPConnection
from sqlalchemy.ext.asyncio import AsyncSession

async def get_db_session(connection: HTTPConnection) -> AsyncGenerator[AsyncSession, None]:
    """
    Create and get database session.

    :param connection: current HTTPConnection.
    :yield: database session.
    """
    session: AsyncSession = connection.app.state.db_session_factory()

    try:  # noqa: WPS501
        yield session
    finally:
        await session.commit()
        await session.close()

This modification allows you to use the same DAO in your websocket route. For example:

@router.websocket(path="/ws")
async def simple_ws(
    websocket: WebSocket,
    user_dao: UserDAO = Depends(),
):
    await websocket.accept()

    try:
        while True:
            _data = await websocket.receive_text()
            user = await user_dao.create_user()
            await user_dao.session.commit()
    except WebSocketDisconnect:
        print('The connection was disconnected.')

Hope this helps someone.