microsoft / pyright

Static Type Checker for Python

Quasi-regression around `Unknown` and lambdas #8690

Closed: colehaus closed this issue 2 months ago

colehaus commented 2 months ago
from __future__ import annotations

from collections.abc import Callable
from typing import Generic, NamedTuple, TypeVar

DType = TypeVar("DType")
DType2 = TypeVar("DType2")

class Wrapped(NamedTuple, Generic[DType]):
    inner: DType

def fn1(input_: int) -> Wrapped[int]: ...
def fn2(input_: int) -> int: ...

def vmap(fun: Callable[[DType], DType2]) -> None: ...

vmap(lambda t: fn1(reveal_type(t)).inner)
vmap(lambda t: fn2(reveal_type(t)))
vmap(fn2)

In pyright 1.1.373 and 1.1.374 we get output like the following:

pyright.py:21:32 - information: Type of "t" is "Unknown"
pyright.py:22:20 - error: Argument of type "DType@vmap" cannot be assigned to parameter "input_" of type "int" in function "fn2"
    "object*" is incompatible with "int" (reportArgumentType)
pyright.py:22:32 - information: Type of "t" is "DType@vmap"
1 error, 0 warnings, 2 informations

In pyright 1.1.375, we get output like the following:

pyright.py:21:20 - error: Argument of type "DType@vmap" cannot be assigned to parameter "input_" of type "int" in function "fn1"
    "object*" is incompatible with "int" (reportArgumentType)
pyright.py:21:32 - information: Type of "t" is "DType@vmap"
pyright.py:22:20 - error: Argument of type "DType@vmap" cannot be assigned to parameter "input_" of type "int" in function "fn2"
    "object*" is incompatible with "int" (reportArgumentType)
pyright.py:22:32 - information: Type of "t" is "DType@vmap"
2 errors, 0 warnings, 2 informations

Note that the inferred type of t in the fn1 lambda has changed from Unknown to DType@vmap, and this is what causes the new error.

I call this a quasi-regression because the new behavior treats the two cases more consistently, and it seems they ought to be treated consistently. But IMO the behavior is now consistently bad. In all pyright versions, the final, non-lambda call type-checks, and it doesn't seem like eta expansion ought to affect the type checker. But I recognize this may not be compatible with the kind of type inference that Python type checkers generally perform. (Mypy also flags both the fn1 and fn2 lines as errors.)
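One way to see the asymmetry: a named function with an annotated parameter hands the checker its parameter type directly, as vmap(fn2) does, whereas a lambda's parameter type has to be inferred from vmap's Callable[[DType], DType2]. A minimal sketch of eta-reducing the fn1 lambda by hand into a def, assuming a hypothetical helper named fn1_inner and a placeholder body for fn1; a dataclass stands in for the generic NamedTuple so the sketch also runs on Python versions before 3.11:

```python
from collections.abc import Callable
from dataclasses import dataclass
from typing import Generic, TypeVar

DType = TypeVar("DType")
DType2 = TypeVar("DType2")


# Stand-in for the generic NamedTuple in the repro (generic NamedTuple needs 3.11+).
@dataclass
class Wrapped(Generic[DType]):
    inner: DType


def fn1(input_: int) -> Wrapped[int]:
    # Placeholder body; the repro only declares the signature.
    return Wrapped(input_)


def vmap(fun: Callable[[DType], DType2]) -> None: ...


# Hypothetical named replacement for `lambda t: fn1(t).inner`.
# Because `t` is annotated, the checker can solve DType = int from this
# signature, as it does for vmap(fn2), instead of giving an unannotated
# lambda parameter the type DType@vmap.
def fn1_inner(t: int) -> int:
    return fn1(t).inner


vmap(fn1_inner)
```

This trades the convenience of an inline lambda for an explicit signature; it sidesteps the inference question rather than answering whether the lambda form should be accepted.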

erictraut commented 2 months ago

Looking at this more closely, I think the new behavior is correct. This was previously a false negative. The problem is that DType has no specified upper bound, so its upper bound defaults to object. Pyright's error message mentions the type object*, which is effectively an intersection of object and DType.
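At runtime, an unbounded TypeVar simply records no bound; type checkers treat that as an implicit upper bound of object, which is where the object* in the message comes from. A small sketch contrasting the repro's DType with a hypothetical IntDType declared with bound=int. Under such a bound, a value typed as the TypeVar should be assignable to an int parameter; whether pyright would then accept the lambda in this exact repro is an assumption, not something verified here:

```python
from typing import TypeVar

# As in the repro: no bound, so checkers use an implicit upper bound of `object`.
DType = TypeVar("DType")

# Hypothetical alternative with an explicit upper bound.
IntDType = TypeVar("IntDType", bound=int)

# The runtime objects record the declared bound (or None when there is none).
print(DType.__bound__)     # None
print(IntDType.__bound__)  # <class 'int'>
```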

For comparison, pyright's new error is consistent with mypy's in this case.

Your code sample looks suspect to me, but perhaps this is because it's simplified from real-world code. In particular, vmap uses two type variables, and each one appears only once in its signature. That means you could replace them both with Any.
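A minimal sketch of that simplification, assuming the simplified vmap has no other use for the two type variables; fn2's body is a placeholder. With this signature, t should be inferred as Any inside the lambda, so passing it to fn2 should not trigger reportArgumentType (an expectation, not a verified pyright run):

```python
from collections.abc import Callable
from typing import Any


def fn2(input_: int) -> int:
    # Placeholder body; the repro only declares the signature.
    return input_ + 1


# DType and DType2 each appeared only once, so they related no type to any
# other; Any accepts the same one-argument callables.
def vmap(fun: Callable[[Any], Any]) -> None: ...


vmap(lambda t: fn2(t))  # t is inferred as Any, so fn2 accepts it
vmap(fn2)
```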

colehaus commented 2 months ago

Yes, this is a much-simplified version of real, useful code. I will try to come up with a case that's a little more realistic but still tolerably simple and isolated.

Do you have thoughts on the part about eta expansion? Do you think pyright should also be flagging the final line as an error?

erictraut commented 2 months ago

I'm not familiar with the term "eta expansion".

No, I don't think pyright should flag the final line as an error. Mypy and pyright currently agree.

colehaus commented 2 months ago

Ah, sorry, it's a term more commonly used in functional programming communities. It refers to transformations like vmap(fn) → vmap(lambda x: fn(x)) (vs. eta reduction, which is vmap(lambda x: fn(x)) → vmap(fn)). These sorts of transformations should always be meaning-preserving, so it seems notable that pyright treats lines that are just eta transformations of each other differently.
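At runtime the two forms are indeed interchangeable. A small sketch using a hypothetical list-based map_all in place of vmap (its type variables appear in both the parameter and the return type, like the fuller vmap later in this thread) and a placeholder body for fn2:

```python
from collections.abc import Callable
from typing import TypeVar

A = TypeVar("A")
B = TypeVar("B")


def fn2(input_: int) -> int:
    # Placeholder body; the repro only declares the signature.
    return input_ * 10


# Hypothetical runtime stand-in for vmap: applies `fun` to every element.
# A and B each appear twice, so they tie input and output element types.
def map_all(fun: Callable[[A], B]) -> Callable[[list[A]], list[B]]:
    def mapped(xs: list[A]) -> list[B]:
        return [fun(x) for x in xs]

    return mapped


eta_reduced = map_all(fn2)                # vmap(fn)
eta_expanded = map_all(lambda x: fn2(x))  # vmap(lambda x: fn(x))

print(eta_reduced([1, 2, 3]))   # [10, 20, 30]
print(eta_expanded([1, 2, 3]))  # [10, 20, 30]
```

Both calls compute the same thing; the difference discussed in this issue lies only in how a checker infers the lambda parameter's type.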

colehaus commented 2 months ago

Here is a less radically simplified demonstration:

from __future__ import annotations

from collections.abc import Callable
from typing import Generic, NamedTuple, TypeVar, TypeVarTuple

DType = TypeVar("DType")
DType2 = TypeVar("DType2")
Shape = TypeVarTuple("Shape")
Shape2 = TypeVarTuple("Shape2")
Dim1 = TypeVar("Dim1")
SeqLen = TypeVar("SeqLen")
BatchLen = TypeVar("BatchLen")

class ndarray(Generic[*Shape, DType]): ...  # noqa: N801

def vmap(
    fun: Callable[[ndarray[*Shape, DType]], ndarray[*Shape2, DType2]],
) -> Callable[[ndarray[Dim1, *Shape, DType]], ndarray[Dim1, *Shape2, DType2]]: ...

class Output(NamedTuple, Generic[SeqLen]):
    out: ndarray[SeqLen, float]

def fn(tkns: ndarray[BatchLen, SeqLen, float]):
    _ = vmap(lambda x: call(x).out)(tkns)

def call(input_: ndarray[SeqLen, float]) -> Output[SeqLen]: ...

No errors in 1.1.374 or earlier, but this error in 1.1.375:

pyright.py:28:29 - error: Argument of type "ndarray[*Shape@vmap, DType@vmap]" cannot be assigned to parameter "input_" of type "ndarray[SeqLen@call, float]" in function "call"
    "ndarray[*Shape@vmap, DType@vmap]" is incompatible with "ndarray[SeqLen@call, float]"
      Type parameter "Shape@ndarray" is invariant, but "*Shape@vmap" is not the same as "*tuple[SeqLen@call]"
      Type parameter "DType@ndarray" is invariant, but "DType@vmap" is not the same as "float" (reportArgumentType)

colehaus commented 2 months ago

I don't actually know what's going on, but I will speculate wildly in the hope that it clarifies what I'm getting at. It seems like when a fixed function with a known type signature is supplied to a higher-order function like vmap, pyright has all the information it needs to unify the higher-order function's type variables with the types provided. But when a lambda is supplied, the type information from the wrapped function is "hidden". Then pyright tries to unify the higher-order function's type variables with Unknown or something, and ends up rigidly fixing them to their upper bound. But values that have been assigned this rigid type aren't really usable by the actual function wrapped in the lambda.

I don't know if any of that makes sense or clarifies. Feel free to totally ignore this point if it doesn't help.