Status: Open. Bohdan-Ilchyshyn opened this issue 6 months ago.
I created a singleton middleware. It checks whether an identical task has already been queued and, if so, should cancel sending the new one in the `pre_send` hook. What is the correct way to do this: return `None` or raise an exception?
Middleware code
import inspect
import time
from hashlib import md5
from typing import Any, Coroutine, Union

from cashews import cache
from loguru import logger
from orjson import orjson
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult


class DuplicateTaskError(Exception):
    """Raised from ``SingletonMiddleware.pre_send`` to abort sending a duplicate task.

    ``pre_send`` must return a ``TaskiqMessage``; taskiq's kicker passes that
    return value on to the broker, so returning ``None`` does not cancel the
    send (it only fails later with an unrelated error). Raising an exception
    from ``pre_send`` instead propagates it to the caller of ``.kiq()``.
    """

    def __init__(self, lock_key: str, existing_task_id: Any) -> None:
        super().__init__(
            f"Task with lock key {lock_key!r} is already queued or running "
            f"(existing task ID: {existing_task_id})",
        )
        # Cache key of the lock that is already held.
        self.lock_key = lock_key
        # Value stored under the lock (the holder's task ID), or None if the
        # lock disappeared between the failed acquire and the lookup.
        self.existing_task_id = existing_task_id


class SingletonMiddleware(TaskiqMiddleware):
    """Prevent duplicate copies of a task from being queued at the same time.

    A task opts in with the ``singleton`` label. In ``pre_send`` a lock keyed
    on the task name and its arguments (or only the arguments listed in the
    ``unique_on`` label) is taken in the ``cashews`` cache, storing the task
    ID as the lock value. If the lock is already held, ``DuplicateTaskError``
    is raised and the message is not sent. The lock is released in
    ``post_execute`` / ``on_error``, or when it expires.

    NOTE(review): if the broker fails to deliver the message after
    ``pre_send`` has acquired the lock, nothing here releases it; it is only
    freed by expiry.
    """

    SINGLETON_LABEL = "singleton"
    UNIQUE_ON_LABEL = "unique_on"
    LOCK_EXPIRE_LABEL = "lock_expire"
    KEY_PREFIX = "TKQ_SINGLETON_LOCK_"
    # Seconds added to a task's ``timeout`` label when deriving the lock expiry.
    TIMEOUT_LOCK_MARGIN = 5 * 60
    # String label values that mean "disabled" (labels may arrive as strings).
    _FALSY_LABEL_STRINGS = frozenset({"", "0", "false", "no", "none"})

    def __init__(
        self,
        default_lock_expire: int = 60,
    ) -> None:
        """Create the middleware.

        :param default_lock_expire: lock lifetime in seconds used when the
            task has neither a ``lock_expire`` nor a ``timeout`` label.
        """
        super().__init__()
        self.default_lock_expire = default_lock_expire

    def pre_send(
        self,
        message: "TaskiqMessage",
    ) -> "Union[TaskiqMessage, Coroutine[Any, Any, TaskiqMessage]]":
        """Pass non-singleton messages through; lock singleton ones first.

        For singleton tasks this returns a coroutine that resolves to the
        message once the lock is acquired, or raises ``DuplicateTaskError``.
        """
        if self.is_singleton_task(message):
            return self.lock_and_run(message)
        return message

    async def post_execute(
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
    ) -> "Union[None, Coroutine[Any, Any, None]]":
        """Release the singleton lock after the task has run."""
        if self.is_singleton_task(message):
            await self.release_lock(message)
        return None

    async def on_error(
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
        exception: BaseException,
    ) -> "Union[None, Coroutine[Any, Any, None]]":
        """Release the singleton lock when the task raised."""
        if self.is_singleton_task(message):
            await self.release_lock(message)
        return None

    def is_singleton_task(self, message: "TaskiqMessage") -> bool:
        """Return True if the message's ``singleton`` label is set to a truthy value.

        A label that is present but falsy (``False``, ``None``, ``0``, or a
        string such as ``"false"``) does not enable locking.
        """
        value = message.labels.get(self.SINGLETON_LABEL)
        if isinstance(value, str):
            return value.strip().lower() not in self._FALSY_LABEL_STRINGS
        return bool(value)

    @staticmethod
    async def unlock(lock_key: str, task_id: str) -> bool:
        """Release ``lock_key`` via ``cache.unlock``, passing ``task_id`` as the lock value."""
        return await cache.unlock(lock_key, task_id)

    @staticmethod
    async def lock(lock_key: str, task_id: str, expire: int) -> bool:
        """Try to take ``lock_key`` with ``task_id`` as its value for ``expire`` seconds."""
        return await cache.set_lock(key=lock_key, value=task_id, expire=expire)

    @staticmethod
    async def locked(lock_key: str) -> bool:
        """Return whether ``lock_key`` is currently locked."""
        return await cache.is_locked(key=lock_key)

    @staticmethod
    async def get_existing_task_id(lock_key: str) -> Any:
        """Return the value stored under ``lock_key`` (the holder's task ID)."""
        return await cache.get(key=lock_key)

    async def lock_and_run(self, message: TaskiqMessage) -> TaskiqMessage:
        """Acquire the singleton lock for ``message`` and return the message.

        :raises DuplicateTaskError: if an identical task already holds the lock.
        """
        if await self.acquire_lock(message):
            return message
        lock_key = self.generate_lock(message)
        # Must be awaited: the original code logged an un-awaited coroutine here.
        existing_task_id = await self.get_existing_task_id(lock_key)
        logger.warning(f"Attempted to queue a duplicate of task ID {existing_task_id}")
        # Returning None would not cancel the send; raising aborts ``.kiq()``.
        raise DuplicateTaskError(lock_key, existing_task_id)

    async def get_lock_expire(self, message: "TaskiqMessage") -> int:
        """Return the lock lifetime in seconds for ``message``.

        Uses the ``lock_expire`` label if present, otherwise the ``timeout``
        label plus ``TIMEOUT_LOCK_MARGIN``, otherwise ``default_lock_expire``.
        Label values may be numbers or numeric strings; ``None`` counts as absent.
        """
        labels = message.labels
        lock_expire = labels.get(self.LOCK_EXPIRE_LABEL)
        if lock_expire is not None:
            return int(float(lock_expire))
        timeout = labels.get("timeout")
        if timeout is not None:
            return int(float(timeout)) + self.TIMEOUT_LOCK_MARGIN
        return self.default_lock_expire

    async def release_lock(self, message: "TaskiqMessage") -> bool:
        """Release the lock for ``message``, using its task ID as the lock value."""
        lock_key = self.generate_lock(message)
        return await self.unlock(lock_key, message.task_id)

    async def acquire_lock(self, message: "TaskiqMessage") -> bool:
        """Try to take the lock for ``message``; return whether it was acquired."""
        lock_key = self.generate_lock(message)
        lock_expire = await self.get_lock_expire(message)
        return await self.lock(lock_key, message.task_id, lock_expire)

    @staticmethod
    def generate_lock_key(task_name: str, task_args: list, task_kwargs: dict, key_prefix: str) -> str:
        """Build a deterministic lock key from a task name and its arguments.

        Arguments are serialised with sorted keys and MD5-hashed together
        with the task name.
        """
        # str(bytes) yields the "b'...'" repr; kept as-is so existing lock keys
        # stay identical across deployments.
        str_args = str(orjson.dumps(task_args, option=orjson.OPT_SORT_KEYS))
        str_kwargs = str(orjson.dumps(task_kwargs, option=orjson.OPT_SORT_KEYS))
        task_hash = md5((task_name + str_args + str_kwargs).encode()).hexdigest()
        return key_prefix + task_hash

    def generate_lock(self, message: "TaskiqMessage") -> str:
        """Return the lock key for ``message``.

        With a ``unique_on`` label (a parameter name or a list of them), only
        those arguments are hashed, resolved by name against the task's
        signature. Otherwise all positional and keyword arguments are hashed.

        :raises ValueError: if ``unique_on`` is set but the task is not
            registered in this broker.
        """
        unique_on = message.labels.get(self.UNIQUE_ON_LABEL)
        if unique_on:
            if isinstance(unique_on, str):
                unique_on = [unique_on]
            task = self.broker.find_task(message.task_name)
            if task is None:
                raise ValueError(
                    f"Task {message.task_name!r} is not registered in the broker; "
                    f"cannot resolve {self.UNIQUE_ON_LABEL!r} arguments.",
                )
            bound = inspect.signature(task.original_func).bind(
                *message.args,
                **message.kwargs,
            )
            # Without this, naming a parameter the caller left at its default
            # raises KeyError.
            bound.apply_defaults()
            unique_args = []
            unique_kwargs = {key: bound.arguments[key] for key in unique_on}
        else:
            unique_args = message.args
            unique_kwargs = message.kwargs

        return self.generate_lock_key(
            task_name=str(message.task_name),
            task_args=unique_args,
            task_kwargs=unique_kwargs,
            key_prefix=self.KEY_PREFIX,
        )
Task example
@broker.task(
    singleton=True,
    unique_on=["id", "name"],
)
async def my_singleton_task(id: str, name: str) -> None:
    """Example task marked ``singleton`` and locked on its ``id`` and ``name`` arguments."""
I created a singleton middleware. It checks whether an identical task has already been queued and, if so, should cancel sending the new one in the `pre_send` hook. What is the correct way to do this: return `None` or raise an exception?
Middleware code
Task example