python / typing

Python static typing home. Hosts the documentation and a user help forum.
https://typing.readthedocs.io/

Generic ParamSpec in subclass definitions #1405

Open alanhdu opened 1 year ago

alanhdu commented 1 year ago

Given some base class that is generic over a ParamSpec, I'd like to be able to bind that ParamSpec from a method implementation in a subclass. Something like:

from typing import Callable, Generic, ParamSpec, Tuple, TypeVar

T = TypeVar("T")
P = ParamSpec("P")

class Base(Generic[P, T]):
    func: Callable[P, T]
    # for instance, a method that forwards its arguments to func
    def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Tuple[T]:
        return (self.func(*args, **kwargs), )

class Subclass(Base):
    def func(self, x: int, *, y: str) -> bytes: ...

That is, the base class is generic over some function signature, and the subclass implements that function as a method. I'm not sure how common this pattern is, but I see it a fair amount when the class gets more complicated and you can't just use a decorator (torch.nn.Module.forward is a prominent example). Is there some other way of specifying "infer the parameters from this method implementation"? Would this require an update to the specification, or is it "just" a feature request for type checkers to treat the method implementation "like" an assignment that binds P and T?
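For concreteness, the standard library's inspect module can show what a ParamSpec would have to capture from the method in the example above. This is only a runtime sketch (type checkers do not read signatures this way), the class name Subclass simply mirrors the example, and the method body is a hypothetical placeholder.

```python
import inspect


class Subclass:
    # Same parameters as the example: x may be passed positionally or by
    # keyword, while y is keyword-only.
    def func(self, x: int, *, y: str) -> bytes:
        return f"{x}:{y}".encode()  # hypothetical body, for illustration only


# Signature of the bound method, so `self` is not listed.
sig = inspect.signature(Subclass().func)
for name, param in sig.parameters.items():
    print(name, param.kind.name)
# x POSITIONAL_OR_KEYWORD
# y KEYWORD_ONLY
```

A plain list of types has no slot for the KEYWORD_ONLY kind of y (or for parameter names and defaults), which is the information an inferred ParamSpec would need to carry.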

erictraut commented 1 year ago

As with any type parameter, you can explicitly specify the type arguments for P and T in the Subclass class definition. However, the ParamSpec specialization syntax is just a list of types, which describes positional-only parameters.

from typing import Generic, ParamSpec, TypeVar

T = TypeVar("T")
P = ParamSpec("P")

class Base(Generic[P, T]):
    def func(self, *args: P.args, **kwargs: P.kwargs) -> T:
        ...

    def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> tuple[T]:
        return (self.func(*args, **kwargs),)

class Subclass(Base[[int, str], bytes]):
    def func(self, x: int, y: str, /) -> bytes:
        ...

There isn't a way to specialize the ParamSpec with a signature that includes keyword arguments, as in your code sample. You could instead specify ... (the ParamSpec equivalent of Any), at the cost of losing argument checking for calls that go through P, such as wrapper.

class Subclass(Base[..., bytes]):
    def func(self, x: int, *, y: str) -> bytes:
        ...

tmke8 commented 1 year ago

Quoting erictraut: "There isn't a way to specialize the ParamSpec with a signature that includes keyword arguments."

I've also often wished for a way to specify keyword arguments in a concrete ParamSpec, again for torch.nn.Module.forward:

from abc import abstractmethod
from typing import Generic, Optional, TypeVar, final
from typing_extensions import ParamSpec

import torch
from torch import Tensor, nn

P = ParamSpec("P")
T = TypeVar("T", covariant=True)

class BaseModule(nn.Module, Generic[P, T]):
    @abstractmethod
    def forward(self, *args: P.args, **kwargs: P.kwargs) -> T:
        raise NotImplementedError()

    @final
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # do other stuff
        return self.forward(*args, **kwargs)

class Linear(BaseModule[[Tensor], Tensor]):
    """Layer with simple signature."""

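    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        # learnable parameters used by forward()
        self.weights = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, input: Tensor) -> Tensor:
        return input @ self.weights + self.bias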

class Transformer(BaseModule[[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tensor]):
    """Complicated signature with optional arguments."""

    def forward(
        self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None
    ) -> Tensor:
        ...

tf = Transformer()
tf(torch.tensor([1, 1]), torch.tensor([1, 1]))  # type error: P requires all four positional arguments; the defaults are lost

Ideally, I could just paste the signature in place of the ParamSpec (hypothetical syntax, not valid Python today):

class Transformer(BaseModule[(src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = ..., tgt_mask: Optional[Tensor] = ...), Tensor]): ...
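For comparison, the current type system can already spell such a signature (names, keyword arguments and defaults included) as a callback protocol (PEP 544). The catch, and the reason this does not resolve the request by itself, is that a protocol cannot be passed as the ParamSpec argument in BaseModule[...]. The sketch below uses float in place of Tensor to stay self-contained; TransformerForward and fake_forward are hypothetical names.

```python
from typing import Optional, Protocol


class TransformerForward(Protocol):
    # Same shape as Transformer.forward, including the defaults that a
    # bracketed ParamSpec specialization cannot express.
    def __call__(
        self,
        src: float,
        tgt: float,
        src_mask: Optional[float] = None,
        tgt_mask: Optional[float] = None,
    ) -> float: ...


def fake_forward(
    src: float,
    tgt: float,
    src_mask: Optional[float] = None,
    tgt_mask: Optional[float] = None,
) -> float:
    # Toy body: ignore the masks and combine the two inputs.
    return src + tgt


f: TransformerForward = fake_forward  # a checker verifies the signatures match
print(f(1.0, 2.0))                # 3.0
print(f(1.0, 2.0, tgt_mask=0.5))  # 3.0
```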