aio-libs / aiohttp

Asynchronous HTTP client/server framework for asyncio and Python
https://docs.aiohttp.org
Other
14.99k stars 1.99k forks source link

Added an argument to TraceConfig that can be set to a constant. #6544

Open satodaiki opened 2 years ago

satodaiki commented 2 years ago

Is your feature request related to a problem?

Hi there.

TraceConfig currently accepts only a context factory; no other values can be set on it. This is a little inconvenient when you want to hide bodies that should not be output.

example:

import asyncio
import uuid
from datetime import datetime
from typing import List

import aiohttp
from pydantic import BaseModel, Field

class TraceConfigModel(BaseModel):
    """Trace data for one request: an ID, a start timestamp, and the hide list."""

    # Filled by uuid.uuid4 whenever a model instance is created.
    request_id: uuid.UUID = Field(default_factory=uuid.uuid4)
    # Filled by datetime.utcnow whenever a model instance is created.
    request_start_time: datetime = Field(default_factory=datetime.utcnow)

    # URLs for which _on_request_end prints a mask instead of the response body.
    hide_request_body_urls: List[str] = Field(["https://httpstat.us/200"])

class TraceConfigModelFactory:
    """Trace context object that carries the TraceConfigModel of a request."""

    def __init__(self, **kwargs):
        """Accept and ignore whatever keyword arguments are supplied."""

    def on_request_start(self, params: aiohttp.TraceRequestStartParams):
        """Build a fresh TraceConfigModel for the request that is starting."""
        model = TraceConfigModel()
        self._trace_config_model = model

    @property
    def trace_config_model(self):
        """Return the model built by on_request_start.

        Raises AttributeError if on_request_start has not been called yet,
        because the attribute is only assigned there.
        """
        return self._trace_config_model

async def _on_request_start(
    session: aiohttp.ClientSession,
    context: TraceConfigModelFactory,
    params: aiohttp.TraceRequestStartParams,
) -> None:
    """Trace callback registered on on_request_start.

    Forwards ``params`` to ``context.on_request_start``; ``session`` is unused.
    """
    context.on_request_start(params)

async def _on_request_end(
    session: aiohttp.ClientSession,
    context: TraceConfigModelFactory,
    params: aiohttp.TraceRequestEndParams,
) -> None:
    """Print the response body, or a mask when the request URL is on the hide list.

    Args:
        session: The client session that issued the request (unused).
        context: Trace context whose ``trace_config_model`` supplies
            ``hide_request_body_urls``.
        params: End-of-request parameters; ``params.url`` is compared against
            the hide list and ``params.response`` provides the body text.
    """
    request_url = str(params.url)
    hidden_urls = context.trace_config_model.hide_request_body_urls

    # Decide once over the whole list: a match anywhere must mask the body.
    # Reassigning the body per entry would let a later non-matching URL undo
    # the mask, and an empty list would never read the body at all.
    if any(url == request_url for url in hidden_urls):
        body = '#######'
    else:
        body = await params.response.text('utf-8')

    print(body)

async def main():
    """Send one GET request through a session that has the trace callbacks attached."""
    tracing = aiohttp.TraceConfig(TraceConfigModelFactory)
    tracing.on_request_start.append(_on_request_start)
    tracing.on_request_end.append(_on_request_end)

    # The response itself is not used here; the trace callbacks do the work.
    async with aiohttp.ClientSession(trace_configs=[tracing]) as client:
        async with client.get("https://httpstat.us/200"):
            pass

# Run the example only when this file is executed as a script.
if __name__ == '__main__':
    asyncio.run(main())

(I'm using pydantic for readability.)

When you do this, you will see the hidden body on the standard output. However, since the context_factory is regenerated with each request, putting "hide_request_body_urls" in TraceConfigModel is not the preferred implementation. I thought it would be better to have it as a constant when creating the TraceConfig instance if possible.

Describe the solution you'd like

The easiest solution would be to add an argument for constants to the TraceConfig initializer. Alternatively, it could be added on the ClientSession side.

example:

# Proposed usage: the `constant` keyword is the requested addition.
TraceConfig(trace_config_ctx_factory=TraceConfigModelFactory, constant={"url": ["https://httpstat.us/200"]})

(The variable names could be a bit better.)

Describe alternatives you've considered

As a workaround for now, dependency injection can be used to get around this. We use python-dependency-injector for this.

example:

import asyncio
import os
import uuid
from datetime import datetime
from typing import List

import aiohttp
from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
from pydantic import AnyHttpUrl, BaseModel, BaseSettings, Field

# Supply the hide list through the "url" environment variable, which
# TraceConfigSetting.hide_request_body_urls is declared to read (env='url').
os.environ["url"] = '["https://httpstat.us/200"]'

class TraceConfigSetting(BaseSettings):
    """Settings holding the URLs for which _on_request_end masks the body."""

    # Required field (no default), bound to the "url" environment variable.
    hide_request_body_urls: List[AnyHttpUrl] = Field(..., env='url')

class TraceConfigContainer(containers.DeclarativeContainer):
    """Dependency-injection container exposing TraceConfigSetting as ``config``."""

    # Singleton provider for the settings object injected into the callbacks.
    config = providers.Singleton(TraceConfigSetting)

class TraceConfigModel(BaseModel):
    """Trace data for one request: an ID and a start timestamp."""

    # Filled by uuid.uuid4 whenever a model instance is created.
    request_id: uuid.UUID = Field(default_factory=uuid.uuid4)
    # Filled by datetime.utcnow whenever a model instance is created.
    request_start_time: datetime = Field(default_factory=datetime.utcnow)

class TraceConfigModelFactory:
    """Trace context object that carries the TraceConfigModel of a request."""

    def __init__(self, **kwargs):
        """Accept and ignore whatever keyword arguments are supplied."""

    def on_request_start(self, params: aiohttp.TraceRequestStartParams):
        """Build a fresh TraceConfigModel for the request that is starting."""
        model = TraceConfigModel()
        self._trace_config_model = model

    @property
    def trace_config_model(self):
        """Return the model built by on_request_start.

        Raises AttributeError if on_request_start has not been called yet,
        because the attribute is only assigned there.
        """
        return self._trace_config_model

@inject
async def _on_request_start(
    session: aiohttp.ClientSession,
    context: TraceConfigModelFactory,
    params: aiohttp.TraceRequestStartParams,
    config: TraceConfigSetting = Provide[TraceConfigContainer.config],
) -> None:
    """Trace callback registered on on_request_start.

    Forwards ``params`` to ``context.on_request_start``. ``session`` and the
    injected ``config`` are not used by this callback.
    """
    context.on_request_start(params)

@inject
async def _on_request_end(
    session: aiohttp.ClientSession,
    context: TraceConfigModelFactory,
    params: aiohttp.TraceRequestEndParams,
    config: TraceConfigSetting = Provide[TraceConfigContainer.config],
) -> None:
    """Print the response body, or a mask when the request URL is on the hide list.

    Args:
        session: The client session that issued the request (unused).
        context: Trace context for the request (unused here).
        params: End-of-request parameters; ``params.url`` is compared against
            the hide list and ``params.response`` provides the body text.
        config: Injected settings whose ``hide_request_body_urls`` lists the
            URLs to mask.
    """
    request_url = str(params.url)

    # Decide once over the whole list: a match anywhere must mask the body.
    # Reassigning the body per entry would let a later non-matching URL undo
    # the mask, and an empty list would never read the body at all.
    if any(url == request_url for url in config.hide_request_body_urls):
        body = '#######'
    else:
        body = await params.response.text('utf-8')

    print(body)

async def main():
    """Wire the DI container, then send one GET request with tracing attached."""
    container = TraceConfigContainer()
    container.wire(modules=[__name__])

    tracing = aiohttp.TraceConfig(TraceConfigModelFactory)
    tracing.on_request_start.append(_on_request_start)
    tracing.on_request_end.append(_on_request_end)

    # The response itself is not used here; the trace callbacks do the work.
    async with aiohttp.ClientSession(trace_configs=[tracing]) as client:
        async with client.get("https://httpstat.us/200"):
            pass

# Run the example only when this file is executed as a script.
if __name__ == '__main__':
    asyncio.run(main())

Related component

Client

Additional context

No response

Code of Conduct

Dreamsorcerer commented 3 weeks ago

In your example, aren't you planning to use the params from the context to set the request_id/request_start_time? I don't see how a constant would be of use to you...

Dreamsorcerer commented 3 weeks ago

However, since the context_factory is regenerated with each request, putting "hide_request_body_urls" in TraceConfigModel is not the preferred implementation.

I really don't see why you don't just put a reference in the class:

# Module-level constant: defined once, so it is not rebuilt per request.
HIDE_REQUEST_BODY_URLS = ("https://httpstat.us/200",)

class TraceConfigModelFactory:
    """Trace context that stores a reference to the module-level hide list."""

    def __init__(self, **kwargs):
        # Only a reference to the shared tuple is stored; kwargs are ignored.
        self._hide_request_body_urls = HIDE_REQUEST_BODY_URLS

or

class TraceConfigModelFactory:
    """Trace context exposing the hide list as a class attribute."""

    # Class-level attribute, shared by every instance of the factory.
    hide_request_body_urls = ("https://httpstat.us/200",)