jazzband / django-auditlog

A Django app that keeps a log of changes made to an object.
MIT License
1.07k stars 405 forks source link

Async middleware #639

Open TheSteveBurgess opened 2 months ago

TheSteveBurgess commented 2 months ago

Hi,

Are there any plans to add async support to the current middleware? At the moment, using the middleware forces the entire request pipeline to run synchronously, because the middleware doesn't support async.

Cheers,

Steve.

mantulen commented 2 months ago

I was able to achieve this by writing a custom middleware (based on the official one). It is both sync- and async-capable and works fine for my use cases:

Currently running on Django 5.0.4 and AuditLog 3.0.0

import contextlib
import time
import uuid
from contextvars import ContextVar
from functools import partial

from asgiref.sync import iscoroutinefunction
from auditlog.models import LogEntry
from django.contrib.auth import get_user_model
from django.db.models.signals import pre_save
from django.utils.decorators import sync_and_async_middleware

# Context-local state written by set_actor() and read by the _set_actor
# pre_save receiver: the dispatch id of the active set_actor() block, and the
# client address to record on LogEntry instances. ContextVars (rather than
# thread-locals) keep concurrent async requests on one thread separate.
duid = ContextVar('duid')
remote_address = ContextVar('remote_address')

# Resolved once, at import time; _set_actor only records actors that are
# instances of this model. NOTE(review): calling get_user_model() at import
# time requires Django's app registry to be ready — fine for a module loaded
# via settings.MIDDLEWARE, but confirm for any other import path.
UserModel = get_user_model()

@contextlib.contextmanager
def set_actor(actor, remote_addr=None):
    """Attribute LogEntry rows saved inside the ``with`` block to *actor*.

    While the block is active, a ``pre_save`` receiver for ``LogEntry`` is
    connected that fills in ``actor`` and ``remote_addr`` (see ``_set_actor``).
    On exit the receiver is disconnected and the context-local values are
    restored to what they were before, so nested ``set_actor`` blocks keep
    working and nothing leaks into later work in the same context.

    :param actor: The user to record as the actor of saved log entries.
    :param remote_addr: The client address to record, or ``None``.
    """
    # A random UUID (not a timestamp) keeps the dispatch id unique even when
    # two requests enter at the same clock tick. With a shared id Django would
    # silently skip the second connect(), and the first request's receiver
    # (bound to the first user) would match and tag the second request's
    # entries.
    signal_duid = ('set_actor', uuid.uuid4())
    duid_token = duid.set(signal_duid)
    remote_address_token = remote_address.set(remote_addr)

    receiver = partial(_set_actor, user=actor, signal_duid=signal_duid)
    try:
        pre_save.connect(
            receiver,
            sender=LogEntry,
            dispatch_uid=signal_duid,
            weak=False,  # store a strong reference to the per-call partial
        )
        yield
    finally:
        # Disconnecting an id that was never connected is a harmless no-op.
        pre_save.disconnect(sender=LogEntry, dispatch_uid=signal_duid)
        # Restore the previous values so an enclosing set_actor() block's
        # receiver matches again.
        remote_address.reset(remote_address_token)
        duid.reset(duid_token)

def _set_actor(user, sender, instance, signal_duid, **kwargs):
    """``pre_save`` receiver that stamps *instance* with the current actor.

    Bound to a specific *user* and *signal_duid* via ``functools.partial`` in
    ``set_actor``. It acts only when the context-local dispatch id equals
    *signal_duid*, i.e. when the save happens inside that same ``set_actor``
    block. Then it records *user* as the actor (if it is a user-model
    instance and no actor is set yet) and always records the client address.
    """
    try:
        active_duid = duid.get()
        client_addr = remote_address.get()
    except LookupError:
        # This context carries no set_actor() state.
        return

    if active_duid != signal_duid:
        # The receiver belongs to a different set_actor() block.
        return

    is_log_entry = sender == LogEntry
    if is_log_entry and isinstance(user, UserModel) and instance.actor is None:
        instance.actor = user

    instance.remote_addr = client_addr

def _actor_context(user, request):
    """Return a ``set_actor`` context for an authenticated *user*, else a no-op.

    ``REMOTE_ADDR`` is read with ``.get()`` because it is not guaranteed to be
    present (under ASGI it is only set when the server supplies a client
    address). A missing value is recorded as ``None`` instead of raising
    ``KeyError`` and failing the request.
    """
    if user is not None and user.is_authenticated:
        return set_actor(actor=user, remote_addr=request.META.get('REMOTE_ADDR'))
    return contextlib.nullcontext()


@sync_and_async_middleware
def AuditlogMiddleware(get_response):
    """Sync/async middleware that attributes audit log entries to the request user.

    For authenticated requests the rest of the chain runs inside
    ``set_actor(...)``, so any ``LogEntry`` saved during the request records
    the user and client address. Anonymous requests run unchanged. It must
    come after ``AuthenticationMiddleware``, which provides ``request.user``
    and ``request.auser``.
    """
    if iscoroutinefunction(get_response):

        async def middleware(request):
            # request.auser() resolves the user without blocking the event
            # loop. Without it (no auth middleware) no actor is recorded.
            user = await request.auser() if hasattr(request, 'auser') else None
            with _actor_context(user, request):
                return await get_response(request)

    else:

        def middleware(request):
            user = getattr(request, 'user', None)
            with _actor_context(user, request):
                return get_response(request)

    return middleware
mantulen commented 2 months ago

In addition to my last post, I am also using the following sync/async middleware to normalize `request.META['REMOTE_ADDR']`, which the custom AuditlogMiddleware above reads:

from asgiref.sync import iscoroutinefunction
from django.utils.decorators import sync_and_async_middleware

def _normalize_remote_addr(meta):
    """Return the client IP from a request ``META`` mapping, with any port removed.

    ``REMOTE_ADDR`` is preferred. The first ``X-Forwarded-For`` hop is used
    only when ``REMOTE_ADDR`` is missing or empty. Note that X-Forwarded-For
    is client-supplied unless a trusted proxy overwrites it, so that fallback
    can be spoofed. A trailing port is stripped from ``a.b.c.d:port`` and
    ``[v6]:port`` forms. Returns ``''`` when neither source has a value.
    """
    remote_addr = (
        meta.get('REMOTE_ADDR', '')
        or meta.get('HTTP_X_FORWARDED_FOR', '').split(',')[0].strip()
    )

    if remote_addr.startswith('['):
        # Bracketed IPv6, optionally with a port: "[::1]:8000" -> "::1".
        remote_addr = remote_addr[1:].split(']')[0].strip()
    elif '.' in remote_addr and remote_addr.count(':') == 1:
        # IPv4 with a port: "1.2.3.4:8000" -> "1.2.3.4". Requiring exactly one
        # colon leaves IPv4-mapped IPv6 ("::ffff:1.2.3.4") intact; a plain
        # "contains '.' and ':'" test would truncate it to ''.
        remote_addr = remote_addr.split(':')[0].strip()

    return remote_addr


@sync_and_async_middleware
def RemoteAddressMiddleware(get_response):
    """Sync/async middleware that normalizes ``request.META['REMOTE_ADDR']``.

    Before calling the rest of the chain, it rewrites ``REMOTE_ADDR`` to the
    bare client IP computed by ``_normalize_remote_addr``. The key is always
    written, possibly as ``''``, so downstream code can index it directly.
    """
    if iscoroutinefunction(get_response):

        async def middleware(request):
            request.META['REMOTE_ADDR'] = _normalize_remote_addr(request.META)
            return await get_response(request)

    else:

        def middleware(request):
            request.META['REMOTE_ADDR'] = _normalize_remote_addr(request.META)
            return get_response(request)

    return middleware

This should be placed near the top of your `MIDDLEWARE` list so that `REMOTE_ADDR` is normalized before any other middleware reads it. For example:

# Example settings. Order matters:
# - RemoteAddressMiddleware runs first, so every later middleware (including
#   AuditlogMiddleware) sees the normalized REMOTE_ADDR.
# - AuditlogMiddleware runs after AuthenticationMiddleware, which attaches the
#   request.user / request.auser it reads.
MIDDLEWARE = [
    'config.middleware.RemoteAddressMiddleware',  # custom sync/async REMOTE_ADDR normalizer
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    # NOTE(review): RemoteUserMiddleware is not defined anywhere in this
    # thread — presumably a project-local middleware; confirm it exists or
    # drop this entry.
    'config.middleware.RemoteUserMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'config.middleware.AuditlogMiddleware',  # custom sync/async auditlog middleware
]