alanhdu opened this issue 1 year ago
As with any type parameter, you can explicitly specify the type arguments for `P` and `T` in the `Subclass` class definition. However, the syntax for `ParamSpec` allows only positional parameters in the specialization.
```python
from typing import Generic, ParamSpec, TypeVar

T = TypeVar("T")
P = ParamSpec("P")


class Base(Generic[P, T]):
    def func(self, *args: P.args, **kwargs: P.kwargs) -> T:
        ...

    def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> tuple[T]:
        return (self.func(*args, **kwargs),)


class Subclass(Base[[int, str], bytes]):
    def func(self, x: int, y: str, /) -> bytes:
        ...
```
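For what it's worth, with that specialization the inherited `wrapper` is fully checked. A quick illustration (the `reveal_type` result is roughly what pyright reports; `sub` is just an illustrative name):

```python
sub = Subclass()

reveal_type(sub.wrapper(1, "a"))  # tuple[bytes]
sub.wrapper(1, 2)                 # error: the second argument must be str
sub.wrapper(x=1, y="a")           # error: the [int, str] parameters are positional-only
```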
There isn't a way to specialize the `ParamSpec` with a signature that includes keyword arguments as shown in your code sample. You could specify `...` (which is the `ParamSpec` equivalent of `Any`).
```python
class Subclass(Base[..., bytes]):
    def func(self, x: int, *, y: str) -> bytes:
        ...
```
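The trade-off is that `...` behaves like `Any`: argument checking for the inherited methods is effectively disabled, so calls through `wrapper` are no longer validated against the override's signature. For example (the values below are arbitrary):

```python
sub = Subclass()

# With P specialized to ..., any call shape is accepted:
sub.wrapper(1, y="hello")         # accepted
sub.wrapper("wrong", z=object())  # also accepted; no error is reported
```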
> There isn't a way to specialize the `ParamSpec` with a signature that includes keyword arguments
I've also often wished for a way to specify keyword arguments in a concrete `ParamSpec`, specifically also for `torch.nn.Module.forward`:
```python
from abc import abstractmethod
from typing import Generic, Optional, TypeVar, final

from typing_extensions import ParamSpec

import torch
from torch import Tensor, nn

P = ParamSpec("P")
T = TypeVar("T", covariant=True)


class BaseModule(nn.Module, Generic[P, T]):
    @abstractmethod
    def forward(self, *args: P.args, **kwargs: P.kwargs) -> T:
        raise NotImplementedError()

    @final
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # do other stuff
        return self.forward(*args, **kwargs)


class Linear(BaseModule[[Tensor], Tensor]):
    """Layer with simple signature."""

    def forward(self, input: Tensor) -> Tensor:
        return input @ self.weights + self.bias


class Transformer(BaseModule[[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tensor]):
    """Complicated signature with optional arguments."""

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        src_mask: Optional[Tensor] = None,
        tgt_mask: Optional[Tensor] = None,
    ) -> Tensor:
        ...


tf = Transformer()
tf(torch.tensor([1, 1]), torch.tensor([1, 1]))  # type error
```
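As far as I can tell, the bracketed specialization can only describe required, positional-only parameters, so the only call shape the checker accepts through `__call__` is all four arguments passed positionally:

```python
src, tgt = torch.tensor([1, 1]), torch.tensor([1, 1])

tf(src, tgt, None, None)  # OK: all four arguments, positionally
tf(src, tgt)              # error: the defaults cannot be expressed in the specialization
tf(src=src, tgt=tgt)      # error: the parameter names are not part of the specialization
```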
Ideally, I could just paste the signature in place of the `ParamSpec`:
```python
class Transformer(BaseModule[(src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = ..., tgt_mask: Optional[Tensor] = ...), Tensor]): ...
```
Given some base class that is generic over a param-spec, I'd like to be able to define the param-spec using a subclass method implementation. Something like:
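A minimal sketch of the shape I have in mind (illustrative only; note that no explicit `Base[...]` specialization is written, and the checker would be expected to infer it from the override):

```python
from typing import Generic, ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")


class Base(Generic[P, T]):
    def func(self, *args: P.args, **kwargs: P.kwargs) -> T:
        ...

    def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> tuple[T]:
        return (self.func(*args, **kwargs),)


# Desired: P and T are inferred from the override, so that Subclass.wrapper
# is understood as (x: int, *, y: str) -> tuple[bytes].
class Subclass(Base):
    def func(self, x: int, *, y: str) -> bytes:
        ...
```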
That is, the base class is generic over some function definition, and the subclass implements that function as a method. I'm not sure how common this pattern is, but I see it a fair amount when the class gets more complicated and you can't just use a decorator (`torch.nn.Module.forward` is a big example of this). Is there some other way of specifying "infer the parameters from this method implementation"? Is this something that would require an update to the typing specification, or is it "just" a feature request for type checkers to treat the method implementation like an assignment statement?