snok / asgi-correlation-id

Request ID propagation for ASGI apps
MIT License

How to pass correlation_id to tasks executed in a multithreaded environment? #45

Closed philippefutureboy closed 1 year ago

philippefutureboy commented 1 year ago

EDIT: Changed the name of the issue for better searchability; you can find the solution to the question here


Hey there!

I feel pretty stupid asking this question, but can you explain to me how I should create my logger instance to have a correlation_id?

Currently I create my logger at the top of a router file:

import logging
from fastapi import APIRouter, HTTPException

LOG = logging.getLogger(__name__)

router = APIRouter(prefix="/my/route", responses={404: {"description": "Not found"}})

@router.get("/")
def handler():
    LOG.info("Hello!")

And I get

[2022-07-19T20:37:48] INFO [None] path.to.module | Hello!

when my logging configuration is as follows:

    "formatters": {
        "default": {
            "format": "[%(asctime)s] %(levelname)s [%(correlation_id)s] %(name)s | %(message)s",
            "datefmt": "%Y-%m-%dT%H:%M:%S",
        }
    },

and I add the middleware like this:

app.add_middleware(
    CorrelationIdMiddleware,
    header_name='X-Request-ID',
    generator=lambda: uuid4().hex,
    validator=is_valid_uuid4,
    transformer=lambda a: a,
)

I would like the correlation_id to show up in my log like so:

[2022-07-19T20:37:48] INFO [8fe9728a] path.to.module | Hello!

I can't find anything about this in either the Starlette or FastAPI documentation. It's like everybody knows this and it's not worth mentioning πŸ€”

Can you give me an example of how I should get a logger instance to have the request id show up?

Thanks for your help!

JonasKs commented 1 year ago

Did you add the filter? 😊 I'd recommend reading @sondrelg's blog article here, or this part of the readme 😊

philippefutureboy commented 1 year ago

Hi @JonasKs! Thanks for your answer :)

Did you add the filter?

I did:

LOGGING = {
    "version": 1,
    "disable_existing_loggers": 0,
    "filters": {  # correlation ID filter must be added here to make the %(correlation_id)s formatter work
        "correlation_id": {
            "()": "asgi_correlation_id.CorrelationIdFilter",
            "uuid_length": 8 if os.environ["ENVIRONMENT"] == "dev" else 32,
        },
    },
    "formatters": {
        "default": {
            "format": "[%(asctime)s] %(levelname)s [%(correlation_id)s] %(name)s | %(message)s",
            "datefmt": "%Y-%m-%dT%H:%M:%S",
        }
    },
    "handlers": {
        "stdout": {
            "level": "INFO",
            "class": "logging.StreamHandler",
            "formatter": "default",
            "filters": ["correlation_id"],
        },
    },
    "loggers": {
        "": {
            "handlers": ["stdout"],
            "propagate": True,
            "level": "INFO",
        },
    },
}
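
and the dict is loaded at startup with the standard library call:

import logging.config

logging.config.dictConfig(LOGGING)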

But that's not the whole story. I've simplified the example in the OP... in reality I use the logger as part of a ThreadPoolExecutor job called by the handler. Could that be the source of the issue? I run multiple blocking calls in parallel using multithreading, and I pass the logger along to the job:

import inspect
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Dict

from fastapi import APIRouter

LOG = logging.getLogger(__name__)

router = APIRouter(prefix="/my/route", responses={404: {"description": "Not found"}})

...

@router.get("/")
def handler(task_ids):
    task_results = run_all(
        logger=LOG,
        tasks={
            task_id: TaskTuple(
                callable=heavy_task,
                args=[],
                kwargs={},
            )
            for task_id in task_ids
        },
    )

    return task_results

def run_all(
    tasks: Dict[str, TaskTuple],
    logger: logging.Logger = None,
    max_workers: int = 20,
):
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        jobs = {}
        for name, task in tasks.items():
            fn, args, kwargs = task
            kwargs = kwargs.copy()
            # Hand a per-task child logger to tasks that accept one
            signature = inspect.signature(fn)
            if "logger" in signature.parameters:
                if logger is None:
                    raise TypeError(f"task {name!r} expects a logger, but none was given")
                kwargs["logger"] = logger.getChild(name)

            jobs[name] = executor.submit(fn, *args, **kwargs)

    return {name: job.result() for name, job in jobs.items()}

JonasKs commented 1 year ago

This package uses contextvars to store the UUID, which won't be accessible from another thread. I'm on vacation without a PC at the moment, but this is probably solvable 😊

A quick fix would be to pass the ID to the task manually and then set it from within that task.
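
Something like this (untested, off the top of my head), using the correlation_id context var from asgi_correlation_id.context:

import threading

from asgi_correlation_id.context import correlation_id

def task(cid: str):
    # Re-set the ID inside the worker thread; contextvars set in the
    # request thread aren't inherited by threads you spawn yourself.
    correlation_id.set(cid)
    ...  # work done here now logs with the correct correlation ID

# In the endpoint, capture the ID while still inside the request context:
threading.Thread(target=task, args=(correlation_id.get(),)).start()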

philippefutureboy commented 1 year ago

Thanks for the additional info Jonas! That was enough to solve the issue πŸ‘Œ

I took the solution presented in this Medium article as a way to transfer the contextvars.

Here's my final impl for the run_all ThreadPoolExecutor function:

import contextvars
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict

def run_all(
    tasks: Dict[str, TaskTuple],
    pass_contextvars: bool = False,
    max_workers: int = 20,
):
    exec_kwargs = {}
    if pass_contextvars:
        # Snapshot the caller's context (correlation ID included) so each
        # worker thread replays it when the pool spins the thread up.
        parent_context = contextvars.copy_context()
        exec_kwargs = {"initializer": _set_context, "initargs": (parent_context,)}

    with ThreadPoolExecutor(max_workers=max_workers, **exec_kwargs) as executor:
        jobs = {}
        for name, task in tasks.items():
            fn, args, kwargs = task
            kwargs = kwargs.copy()
            jobs[executor.submit(fn, *args, **kwargs)] = name

    # Look each future up in the mapping; zipping as_completed() against
    # jobs.values() would pair completion order with submission order and
    # mismatch names with results.
    return {jobs[future]: future.result() for future in as_completed(jobs)}

def _set_context(context):
    # contextvars.Context is a Mapping of ContextVar -> value
    for var, value in context.items():
        var.set(value)
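
For completeness, the handler now just opts in with pass_contextvars=True (heavy_task and TaskTuple as in my earlier snippet):

@router.get("/")
def handler(task_ids):
    return run_all(
        tasks={
            task_id: TaskTuple(callable=heavy_task, args=[], kwargs={})
            for task_id in task_ids
        },
        pass_contextvars=True,
    )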

Closing the issue!

JonasKs commented 1 year ago

Perfect 😊