microsoft / pyright

Static Type Checker for Python

Higher-order function type variable regression in 1.1.376 #8852

Closed colehaus closed 1 month ago

colehaus commented 2 months ago
from __future__ import annotations

from collections.abc import Callable
from typing import Generic, TypeVar, TypeVarTuple

DType = TypeVar("DType")
DType2 = TypeVar("DType2")
DType3 = TypeVar("DType3")
Shape = TypeVarTuple("Shape")
Shape2 = TypeVarTuple("Shape2")
Shape3 = TypeVarTuple("Shape3")
Dim1 = TypeVar("Dim1")
Dim2 = TypeVar("Dim2")
SeqLen = TypeVar("SeqLen")
BatchLen = TypeVar("BatchLen")
EmbedDim = TypeVar("EmbedDim")
Float = TypeVar("Float")

class ndarray(Generic[*Shape, DType]): ...  # noqa: N801

def vmap(
    fun: Callable[[ndarray[*Shape, DType], ndarray[*Shape2, DType2]], ndarray[*Shape3, DType3]],
) -> Callable[[ndarray[Dim1, *Shape, DType], ndarray[Dim1, *Shape2, DType2]], ndarray[Dim1, *Shape3, DType3]]: ...

# Specialized version of `vmap` just for diagnostic purposes. Still errors.
def vmap2(
    fun: Callable[[ndarray[Dim2, DType], ndarray[Dim2, DType]], ndarray[DType]],
) -> Callable[[ndarray[Dim1, Dim2, DType], ndarray[Dim1, Dim2, DType]], ndarray[Dim1, Dim2, DType]]: ...

def foo(x: ndarray[SeqLen, EmbedDim, Float], y: ndarray[SeqLen, EmbedDim, Float]):
    _ = vmap(cosine_similarity)(x, y)

def foo2(x: ndarray[SeqLen, EmbedDim, Float], y: ndarray[SeqLen, EmbedDim, Float]):
    _ = vmap2(cosine_similarity)(x, y)

def cosine_similarity(
    predictions: ndarray[Dim1, Float], targets: ndarray[Dim1, Float], epsilon: Float | ndarray[Float] = ...
) -> ndarray[Float]: ...

The preceding code type-checks under pyright 1.1.375 and under mypy. Under pyright 1.1.376 it reports the following errors:

pyright.py:50:33 - error: Argument of type "ndarray[SeqLen@foo, EmbedDim@foo, Float@foo]" cannot be assigned to parameter of type "ndarray[Dim1@vmap, Dim1@cosine_similarity, Any]"
    "ndarray[SeqLen@foo, EmbedDim@foo, Float@foo]" is incompatible with "ndarray[SeqLen@foo, SeqLen@foo, Any]"
      Type parameter "Shape@ndarray" is invariant, but "*tuple[SeqLen@foo, EmbedDim@foo]" is not the same as "*tuple[SeqLen@foo, SeqLen@foo]" (reportArgumentType)
pyright.py:50:36 - error: Argument of type "ndarray[SeqLen@foo, EmbedDim@foo, Float@foo]" cannot be assigned to parameter of type "ndarray[Dim1@vmap, Dim1@cosine_similarity, Any]"
    "ndarray[SeqLen@foo, EmbedDim@foo, Float@foo]" is incompatible with "ndarray[SeqLen@foo, SeqLen@foo, Any]"
      Type parameter "Shape@ndarray" is invariant, but "*tuple[SeqLen@foo, EmbedDim@foo]" is not the same as "*tuple[SeqLen@foo, SeqLen@foo]" (reportArgumentType)
pyright.py:54:34 - error: Argument of type "ndarray[SeqLen@foo2, EmbedDim@foo2, Float@foo2]" cannot be assigned to parameter of type "ndarray[Dim1@vmap2, Dim1@cosine_similarity, Any]"
    "ndarray[SeqLen@foo2, EmbedDim@foo2, Float@foo2]" is incompatible with "ndarray[SeqLen@foo2, SeqLen@foo2, Any]"
      Type parameter "Shape@ndarray" is invariant, but "*tuple[SeqLen@foo2, EmbedDim@foo2]" is not the same as "*tuple[SeqLen@foo2, SeqLen@foo2]" (reportArgumentType)
pyright.py:54:37 - error: Argument of type "ndarray[SeqLen@foo2, EmbedDim@foo2, Float@foo2]" cannot be assigned to parameter of type "ndarray[Dim1@vmap2, Dim1@cosine_similarity, Any]"
    "ndarray[SeqLen@foo2, EmbedDim@foo2, Float@foo2]" is incompatible with "ndarray[SeqLen@foo2, SeqLen@foo2, Any]"
      Type parameter "Shape@ndarray" is invariant, but "*tuple[SeqLen@foo2, EmbedDim@foo2]" is not the same as "*tuple[SeqLen@foo2, SeqLen@foo2]" (reportArgumentType)
4 errors, 0 warnings, 0 informations
erictraut commented 2 months ago

Thanks for the bug report. I'm able to repro, and I agree it's a bug. I'll investigate further.

As a quick note to myself, here's a stripped-down sample that exhibits the problem. If I rename Dim1 to something else, the problem goes away. Likewise, if I swap out the definition of N for the commented-out version (which should be equivalent, in theory), the problem goes away.

from typing import Any, Generic, TypeVar, TypeVarTuple, Callable

D = TypeVar("D")
S = TypeVarTuple("S")

class N(Generic[*S, D]): ...

# class N[*S, D]:
#     x: D

def func1[*S1, D1, *S2, D2, Dim1](
    c: Callable[[N[*S1, D1], N[*S2, D2]], Any],
) -> Callable[[N[Dim1, *S1, D1], N[Dim1, *S2, D2]], Any]: ...

def func2[X, Y, Z](x: N[X, Y, Z], y: N[X, Y, Z]):
    _ = func1(func3)(x, y)

def func3[Dim1, T](x: N[Dim1, T], y: N[Dim1, T]) -> N[T]: ...
erictraut commented 1 month ago

This is addressed in pyright 1.1.380.