Closed colehaus closed 2 months ago
Looking at this more closely, I think the new behavior is correct. This was previously a false negative. The problem is that DType has no specified upper bound, so its upper bound defaults to object. Pyright's error message mentions the type object*, which is effectively an intersection between object and DType.
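As a rough illustration of the point about the implicit upper bound, here is a minimal sketch. The names identity and takes_float are hypothetical and not from the original sample; the comments describe what a type checker is expected to report, not verified output.

```python
from typing import TypeVar

# No bound is given, so type checkers treat the upper bound as object.
DType = TypeVar("DType")


def takes_float(x: float) -> float:
    """Hypothetical helper that needs a concrete float."""
    return x * 2.0


def identity(x: DType) -> DType:
    # Inside this body, x is only known to be "some DType", which is
    # no more specific than object. A call such as takes_float(x) here
    # would be expected to be rejected by a type checker, because DType
    # is not known to be float.
    return x


# At the call site DType is solved to float, so this is fine.
result = takes_float(identity(1.5))
```

At runtime the missing bound is visible as DType.__bound__ being None; the "defaults to object" behavior is something the static checker applies.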
For comparison, pyright's new error is consistent with mypy's in this case.
Your code sample looks suspect to me, but perhaps this is because it's simplified from real-world code. In particular, vmap is using two type variables and each one appears only once. That means you could replace them both with Any.
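To make the "appears only once" remark concrete, here is a small sketch. The functions apply_generic and apply_any are hypothetical and are not the signature from the original sample. A type variable that occurs once in a signature does not link any two positions together, so for callers the two versions should behave much the same.

```python
from collections.abc import Callable
from typing import Any, TypeVar

A = TypeVar("A")
B = TypeVar("B")


def apply_generic(fun: Callable[[A], B]) -> None:
    # A and B each appear exactly once, so nothing in the signature
    # relates the argument type or the return type to anything else.
    assert callable(fun)


def apply_any(fun: Callable[[Any], Any]) -> None:
    # Roughly equivalent signature with the single-use type variables
    # replaced by Any.
    assert callable(fun)


apply_generic(str)
apply_any(str)
```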
Yes, this is a much-simplified version of real, useful code. I will try to come up with a case that's a little more realistic but still tolerably simple and isolated.
Do you have thoughts on the part about eta expansion? Do you think pyright should also be flagging the final line as an error?
I'm not familiar with the term "eta expansion".
No, I don't think pyright should flag the final line as an error. Mypy and pyright currently agree.
Ah, sorry, it's a term more commonly used in functional programming communities. It refers to transformations like vmap(fn) → vmap(lambda x: fn(x)) (vs eta reduction, which is vmap(lambda x: fn(x)) → vmap(fn)). These sorts of transformations should always be meaning-preserving, so it seems notable that pyright is treating lines which are just eta transformations of each other differently.
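A minimal sketch of what I mean by meaning-preserving, using a hypothetical map_list as a stand-in for vmap (it is not the real function). At runtime the direct and eta-expanded forms behave identically; the question in this thread is only about how the type checker infers each one.

```python
from collections.abc import Callable
from typing import TypeVar

A = TypeVar("A")
B = TypeVar("B")


def map_list(fun: Callable[[A], B]) -> Callable[[list[A]], list[B]]:
    """Hypothetical stand-in for vmap: lift fun to operate on lists."""

    def mapped(xs: list[A]) -> list[B]:
        return [fun(x) for x in xs]

    return mapped


def double(x: int) -> int:
    return x * 2


# Direct form: the checker sees double's full signature.
direct = map_list(double)

# Eta-expanded form: same runtime behavior, but the lambda's parameter
# type has to be inferred from map_list's signature.
eta_expanded = map_list(lambda x: double(x))

assert direct([1, 2, 3]) == eta_expanded([1, 2, 3])
```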
Here is a less radically simplified demonstration:
from __future__ import annotations

from collections.abc import Callable
from typing import Generic, NamedTuple, TypeVar, TypeVarTuple

DType = TypeVar("DType")
DType2 = TypeVar("DType2")
Shape = TypeVarTuple("Shape")
Shape2 = TypeVarTuple("Shape2")
Dim1 = TypeVar("Dim1")
SeqLen = TypeVar("SeqLen")
BatchLen = TypeVar("BatchLen")


class ndarray(Generic[*Shape, DType]): ...  # noqa: N801


def vmap(
    fun: Callable[[ndarray[*Shape, DType]], ndarray[*Shape2, DType2]],
) -> Callable[[ndarray[Dim1, *Shape, DType]], ndarray[Dim1, *Shape2, DType2]]: ...


class Output(NamedTuple, Generic[SeqLen]):
    out: ndarray[SeqLen, float]


def fn(tkns: ndarray[BatchLen, SeqLen, float]):
    _ = vmap(lambda x: call(x).out)(tkns)


def call(input_: ndarray[SeqLen, float]) -> Output[SeqLen]: ...
No errors in 1.1.374 or earlier but this error in 1.1.375:
pyright.py:28:29 - error: Argument of type "ndarray[*Shape@vmap, DType@vmap]" cannot be assigned to parameter "input_" of type "ndarray[SeqLen@call, float]" in function "call"
    "ndarray[*Shape@vmap, DType@vmap]" is incompatible with "ndarray[SeqLen@call, float]"
      Type parameter "Shape@ndarray" is invariant, but "*Shape@vmap" is not the same as "*tuple[SeqLen@call]"
      Type parameter "DType@ndarray" is invariant, but "DType@vmap" is not the same as "float" (reportArgumentType)
I don't actually know what's going on, but I will speculate wildly in the hopes that it clarifies what I'm getting at. It seems like when a fixed function with a known type signature is supplied to a higher-order function like vmap, pyright has all the information to unify the higher-order function's type variables with the types provided. But when a lambda is supplied, the type information from the wrapped function is "hidden". And then pyright tries to unify the higher-order function's type variables with Unknown or something and ends up rigidly fixing the higher-order function's type variables to their upper bound. But then values that have been assigned this rigid type aren't really usable by the actual function wrapped in the lambda.
I don't know if any of that makes sense or clarifies. Feel free to totally ignore this point if it doesn't help.
In pyright 1.1.373 and 1.1.374 we get output like the following:
In pyright 1.1.375, we get output like the following:
Note that the inferred type for t has changed from Unknown to DType, and this causes the new error.

I call this a quasi-regression because the new behavior treats these two cases more consistently, and it seems they ought to be treated consistently. But IMO the behavior is now consistently bad. In all pyright versions, the final non-lambda version type checks, and it doesn't seem like eta expansion ought to affect the type checker. But I recognize this is maybe not compatible with the general sort of type inference Python typing tools do. (Mypy also flags both the fn1 and fn2 lines as errors.)