[Feature Request] Generic typing for scale kernels #2491

Open · TobyBoyne opened this issue 3 months ago

🚀 Feature Request

Improve type hinting - using generic types - for kernels that operate on other kernels (such as the ScaleKernel).

Motivation

Is your feature request related to a problem? Please describe.

Currently, after building a model following the Simple GP Regression tutorial, the type hint for the base kernel is not fully expressive.
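For context, the ExactGPModel built in that tutorial looks roughly like this (paraphrased from the tutorial, with details omitted):

import gpytorch

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        # an RBFKernel wrapped in a ScaleKernel
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)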

model = ExactGPModel(train_x, train_y, likelihood)
model.covar_module.base_kernel  
# Pylance tells me this has type "Kernel"
# I know the type is actually "RBFKernel"

This isn't too much of a problem with the RBFKernel, which uses the lengthscale parameter that belongs to the base class. However, when you start to work with non-standard kernels that have different attributes, it can become annoying to lose the type hints.
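For example, with a kernel that defines its own parameters (PeriodicKernel and its period_length are used here purely as an illustration), the attribute neither autocompletes nor type-checks, even though it works at runtime:

from gpytorch.kernels import PeriodicKernel, ScaleKernel

covar_module = ScaleKernel(PeriodicKernel())
# Runs fine, but Pylance only sees base_kernel as "Kernel",
# so period_length is not suggested and may be flagged.
covar_module.base_kernel.period_length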

Pitch

Describe the solution you'd like
Use Generic types to implement this. It would look something like:

from typing import TypeVar, Generic

from gpytorch.kernels import Kernel, RBFKernel

KernelType = TypeVar("KernelType", bound=Kernel)

class ScaleKernel(Kernel, Generic[KernelType]):
    base_kernel: KernelType  # annotated so type checkers can narrow the attribute

    def __init__(self, base_kernel: KernelType, **kwargs):  # remaining arguments elided
        ...

kernel = ScaleKernel(RBFKernel())
kernel.base_kernel
# Pylance tells me this is an "RBFKernel"!
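As a self-contained sanity check of the pattern (the toy Kernel/RBFKernel classes below only shadow the gpytorch names for illustration), Pyright and mypy both narrow the attribute type as hoped:

from typing import Generic, TypeVar

class Kernel:
    pass

class RBFKernel(Kernel):
    lengthscale: float = 1.0

K = TypeVar("K", bound=Kernel)

class Scale(Kernel, Generic[K]):
    def __init__(self, base_kernel: K) -> None:
        self.base_kernel = base_kernel

k = Scale(RBFKernel())
# Inferred as Scale[RBFKernel], so the attribute resolves without a cast:
print(k.base_kernel.lengthscale)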

Describe alternatives you've considered
I'm not sure there are any alternatives. One might argue that type hinting isn't that important, but the modern dev experience relies heavily on autocomplete (for me at least), and this would make that a lot smoother.
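For completeness, the closest workaround available today is probably narrowing the type manually with typing.cast at each call site, which works but pushes the burden onto every user rather than fixing it once in the kernel classes:

from typing import cast

from gpytorch.kernels import RBFKernel

# `model` is the ExactGPModel instance from the snippet above
base = cast(RBFKernel, model.covar_module.base_kernel)
base.lengthscale  # Pylance now treats `base` as an RBFKernel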

Are you willing to open a pull request?
Yes :)