sbdchd / celery-types

:seedling: Type stubs for Celery and its related packages
Apache License 2.0

task_cls #95

Open christianbundy opened 2 years ago

christianbundy commented 2 years ago

Hi there. I'm working with a repository where we've overridden the default task class and added some methods. We pass that class as task_cls when initializing Celery so it becomes the default task class. The type checker reports lots of errors about these custom methods not existing, and I think this is because celery-types doesn't support task_cls.

Is this something you'd be open to supporting in the future?

christianbundy commented 2 years ago

Here's the Celery documentation we used to make these changes, in case it's useful: https://docs.celeryq.dev/en/latest/userguide/tasks.html#custom-task-classes
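A stdlib-only model of the problem may help. `App`, `BaseTask` and `CustomTask` below are hypothetical stand-ins for `Celery`, `celery.Task` and a custom task class, not the real classes. At runtime the app builds tasks from `task_cls`. However, `task()` is annotated to return the base class, so a type checker only ever sees the base class's methods.

```python
from typing import Any, Callable


class BaseTask:
    """Hypothetical stand-in for celery.Task: wraps a function."""

    def __init__(self, fun: Callable[..., Any]) -> None:
        self.fun = fun

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.fun(*args, **kwargs)


class CustomTask(BaseTask):
    """Hypothetical custom task class with an extra method."""

    def describe(self) -> str:
        return f"custom task {self.fun.__name__}"


class App:
    """Hypothetical stand-in for Celery: the task class is chosen at runtime."""

    def __init__(self, task_cls: type[BaseTask] = BaseTask) -> None:
        self.task_cls = task_cls

    def task(self, fun: Callable[..., Any]) -> BaseTask:
        # Statically this always returns BaseTask, whatever task_cls was passed
        return self.task_cls(fun)


app = App(task_cls=CustomTask)


@app.task
def add(x: int, y: int) -> int:
    return x + y


print(add(2, 3))  # 5
# Runs fine, but a type checker flags it: BaseTask has no attribute "describe"
print(add.describe())  # custom task add
```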

sbdchd commented 2 years ago

Yeah, totally open to it. Seems a little tricky: we might have to add an explicit generic type param to Celery.

https://docs.celeryq.dev/en/latest/userguide/tasks.html#app-wide-usage
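One way to read the "explicit generic param" idea, sketched with hypothetical stand-ins (this is not the celery-types API): make the app generic over its task class, so the type passed as `task_cls` flows through to whatever `task()` returns.

```python
from typing import Any, Callable, Generic, TypeVar


class BaseTask:
    """Hypothetical stand-in for celery.Task."""

    def __init__(self, fun: Callable[..., Any]) -> None:
        self.fun = fun

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.fun(*args, **kwargs)


class CustomTask(BaseTask):
    def describe(self) -> str:
        return f"custom task {self.fun.__name__}"


# The app's task class becomes a type parameter of the app itself
TaskT = TypeVar("TaskT", bound=BaseTask)


class App(Generic[TaskT]):
    """Hypothetical stand-in for a Celery class parametrised by task_cls."""

    def __init__(self, task_cls: type[TaskT]) -> None:
        self.task_cls = task_cls

    def task(self, fun: Callable[..., Any]) -> TaskT:
        # Statically returns whichever class the app was built with
        return self.task_cls(fun)


app = App(task_cls=CustomTask)  # inferred as App[CustomTask]


@app.task
def add(x: int, y: int) -> int:
    return x + y


print(add.describe())  # type-checks, since add is a CustomTask
```

The tricky part shows up with the default case: `Celery()` called without `task_cls` would need the parameter to fall back to the base `Task`. That requires either overloads on `__init__` or a TypeVar default (PEP 696).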

NixBiks commented 1 year ago

Any workaround available? My current solution is to disable the type check on the decorator:

```python
@app.task(name="handle_request", bind=True)  # type: ignore
def handle_request(self: SessionTask):
    ...
```
christianbundy commented 1 year ago

My workaround: define a CustomCelery subclass of Celery whose task() overloads use your CustomTask everywhere Task would otherwise be used.

```python
from __future__ import annotations

import datetime
from typing import Any, Callable, Concatenate, Literal, TypeVar, overload

from celery import Celery
from celery.utils.threads import _LocalStack

# CustomTask, Params and Return must also be imported from your own task module

# A string bound avoids subscripting CustomTask at runtime
_Task = TypeVar("_Task", bound="CustomTask[Any, Any]")

# Use only as a type annotation: these overloads have no runtime implementation
class CustomCelery(Celery):
    """
    HACK: Required until https://github.com/sbdchd/celery-types/issues/95 is resolved.
    """

    @overload  # type: ignore[override, no-overload-impl]
    def task(self, fun: Callable[Params, Return]) -> CustomTask[Params, Return]:
        ...

    @overload
    def task(
        self,
        *,
        name: str = ...,
        serializer: str = ...,
        bind: bool = ...,
        autoretry_for: tuple[type[Exception], ...] = ...,
        max_retries: int = ...,
        default_retry_delay: int = ...,
        acks_late: bool = ...,
        ignore_result: bool = ...,
        soft_time_limit: int = ...,
        time_limit: int = ...,
        base: type[_Task],
        retry_kwargs: dict[str, Any] = ...,
        retry_backoff: bool | int = ...,
        retry_backoff_max: int = ...,
        retry_jitter: bool = ...,
        typing: bool = ...,
        rate_limit: str | None = ...,
        trail: bool = ...,
        send_events: bool = ...,
        store_errors_even_if_ignored: bool = ...,
        autoregister: bool = ...,
        track_started: bool = ...,
        acks_on_failure_or_timeout: bool = ...,
        reject_on_worker_lost: bool = ...,
        throws: tuple[type[Exception], ...] = ...,
        expires: float | datetime.datetime | None = ...,
        priority: int | None = ...,
        resultrepr_maxsize: int = ...,
        request_stack: _LocalStack = ...,
        abstract: bool = ...,
        queue: str = ...,
    ) -> Callable[[Callable[..., Any]], _Task]:
        ...

    @overload
    def task(
        self,
        *,
        name: str = ...,
        serializer: str = ...,
        bind: Literal[False] = ...,
        autoretry_for: tuple[type[Exception], ...] = ...,
        max_retries: int = ...,
        default_retry_delay: int = ...,
        acks_late: bool = ...,
        ignore_result: bool = ...,
        soft_time_limit: int = ...,
        time_limit: int = ...,
        base: None = ...,
        retry_kwargs: dict[str, Any] = ...,
        retry_backoff: bool | int = ...,
        retry_backoff_max: int = ...,
        retry_jitter: bool = ...,
        typing: bool = ...,
        rate_limit: str | None = ...,
        trail: bool = ...,
        send_events: bool = ...,
        store_errors_even_if_ignored: bool = ...,
        autoregister: bool = ...,
        track_started: bool = ...,
        acks_on_failure_or_timeout: bool = ...,
        reject_on_worker_lost: bool = ...,
        throws: tuple[type[Exception], ...] = ...,
        expires: float | datetime.datetime | None = ...,
        priority: int | None = ...,
        resultrepr_maxsize: int = ...,
        request_stack: _LocalStack = ...,
        abstract: bool = ...,
        queue: str = ...,
    ) -> Callable[[Callable[Params, Return]], CustomTask[Params, Return]]:
        ...

    @overload
    def task(
        self,
        *,
        name: str = ...,
        serializer: str = ...,
        bind: Literal[True],
        autoretry_for: tuple[type[Exception], ...] = ...,
        max_retries: int = ...,
        default_retry_delay: int = ...,
        acks_late: bool = ...,
        ignore_result: bool = ...,
        soft_time_limit: int = ...,
        time_limit: int = ...,
        base: None = ...,
        retry_kwargs: dict[str, Any] = ...,
        retry_backoff: bool | int = ...,
        retry_backoff_max: int = ...,
        retry_jitter: bool = ...,
        typing: bool = ...,
        rate_limit: str | None = ...,
        trail: bool = ...,
        send_events: bool = ...,
        store_errors_even_if_ignored: bool = ...,
        autoregister: bool = ...,
        track_started: bool = ...,
        acks_on_failure_or_timeout: bool = ...,
        reject_on_worker_lost: bool = ...,
        throws: tuple[type[Exception], ...] = ...,
        expires: float | datetime.datetime | None = ...,
        priority: int | None = ...,
        resultrepr_maxsize: int = ...,
        request_stack: _LocalStack = ...,
        abstract: bool = ...,
        queue: str = ...,
    ) -> Callable[
        [Callable[Concatenate[CustomTask[Params, Return], Params], Return]],
        CustomTask[Params, Return],
    ]:
        ...
```

Then use CustomCelery as the annotation for your app:

```python
celery_app: CustomCelery = celery.Celery("foo", task_cls=CustomTask)  # type: ignore[assignment]
```

This is hacky and brittle and I'm unhappy with it, but it only requires two type ignores rather than one per task.
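Why this holds up at runtime: the `CustomCelery` overloads have no implementation, so they must never actually be called. The `celery_app: CustomCelery = celery.Celery(...)` line keeps the runtime object a plain `Celery` and only changes what the type checker believes. The stdlib sketch below uses hypothetical `Base` and `Narrowed` classes in place of `Celery` and `CustomCelery` to show both halves. The annotated object uses the base implementation, while an instance of the overload-only subclass fails as soon as the method is called (a `typing.overload` without an implementation raises `NotImplementedError`).

```python
from typing import overload


class Base:
    """Hypothetical stand-in for Celery: has a real implementation."""

    def task(self, x: object) -> object:
        return ("base implementation", x)


class Narrowed(Base):
    """Hypothetical stand-in for CustomCelery: overloads only, no implementation."""

    @overload  # type: ignore[override, no-overload-impl]
    def task(self, x: int) -> int: ...

    @overload
    def task(self, x: str) -> str: ...


# The annotation changes only what the type checker sees; the object is still a Base
app: Narrowed = Base()  # type: ignore[assignment]
print(app.task(1))  # ('base implementation', 1)

# Instantiating the overload-only subclass itself breaks when the method is called
try:
    Narrowed().task(1)
    overload_callable = True
except NotImplementedError:
    overload_callable = False
print(overload_callable)  # False
```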

tcrasset commented 1 year ago

@christianbundy What does your CustomTask definition look like?

Is Params a ParamSpec?

How did you type your decorated functions instead? Could you show an example?

EDIT: I had forgotten to patch celery.app.task.Task to accept generic parameters, as described in README.md.

I can't seem to make it both pass type checking AND work at runtime.

I get

```
_Task = TypeVar("_Task", bound=CeleryTask[Any, Any])
E   TypeError: 'type' object is not subscriptable
```

~at runtime when trying to pass in generics to CeleryTask.~

This is my current task:

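```python
from typing import Optional, ParamSpec, TypeVar
from uuid import UUID

import celery

# TaskProgress is a project-specific type (definition not shown).
# Subscripting celery.Task at runtime relies on the Task patch from README.md.
_Params = ParamSpec('_Params')
_Results = TypeVar('_Results')

class CeleryTask(celery.Task[_Params, _Results]):
    """Custom celery task to be able to log and store progress of the execution."""

    def __init__(self, *args: _Params.args, **kwargs: _Params.kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Assign initial values: a bare annotation alone doesn't create the attribute
        self.task_start: Optional[float] = None
        self.previous_progress: Optional[TaskProgress] = None
        self.job_id: Optional[UUID] = None
```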
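A related pitfall when adding attributes in a custom task's `__init__`: a bare annotation like `self.task_start: Optional[float]` only declares a type for the checker. It does not create the attribute, so reading it before it is assigned raises `AttributeError`. A minimal stdlib check, using a hypothetical `Progress` class:

```python
from typing import Optional


class Progress:
    """Hypothetical class mirroring the attribute declarations above."""

    def __init__(self) -> None:
        self.task_start: Optional[float]  # declaration only: nothing is assigned
        self.job_id: Optional[str] = None  # declaration plus assignment


p = Progress()
print(hasattr(p, "task_start"))  # False
print(hasattr(p, "job_id"))  # True
```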
christianbundy commented 1 year ago

Yep, same as yours. I have some custom init / methods / etc, but the important bits are the same:

```python
from typing import ParamSpec, TypeVar
from celery import Task

Params = ParamSpec("Params")
Return = TypeVar("Return")

class CustomTask(Task[Params, Return]):
    ...
```