microsoft / pyright

Static Type Checker for Python

Type inference sometimes fails with operators #8334

Closed: robert-bn closed this issue 1 month ago

robert-bn commented 1 month ago

Consider the following code:

from collections.abc import Callable

class A[T]:
    def __rshift__[U](self, f: Callable[[T], U]) -> U:
        ...

    def seq[U](self, f: Callable[[T], U]) -> U:
        ...

def example(a: A[int]):
    reveal_type(a.seq(lambda i: i))
    reveal_type(a >> (lambda i: i))

This produces the following output when analysed with pyright:

example.py:11:17 - information: Type of "a.seq(lambda i: i)" is "int"
example.py:12:17 - information: Type of "a >> lambda i: i" is "Unknown"

Since __rshift__ and seq have the same signature, I would expect type inference to behave the same way and infer that i is of type int in both cases.

I've seen the same behaviour as with __rshift__ for every operator I've tried, but I haven't tested them all.

Tested with the pyright CLI, version 1.1.370.

erictraut commented 1 month ago

Pyright is working correctly here, so this isn't a bug.

In Python, binary operators like >> are relatively complex. The runtime evaluates both the left and right operands and then determines whether to call __rshift__ on the left operand or its reflected counterpart (__rrshift__) on the right operand. It may attempt one and then the other if the first is missing or returns NotImplemented. For this reason, static type checkers must evaluate the types of the left and right operands independently, without the help of bidirectional type inference. Lambdas require bidirectional type inference for complete type evaluation because their parameter types cannot be annotated.
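To see why the right operand matters, here is a minimal runtime sketch. The classes Left and Right are hypothetical and not part of the issue. When the left operand's __rshift__ returns NotImplemented, Python falls back to the right operand's reflected method __rrshift__, so a type checker cannot assume up front which method will end up receiving the right-hand expression.

```python
class Left:
    """Hypothetical left operand that only handles ints on the right."""

    def __rshift__(self, other):
        if isinstance(other, int):
            return f"Left.__rshift__({other})"
        # Signal that this operand can't handle `other`; Python then
        # tries the right operand's reflected method instead.
        return NotImplemented


class Right:
    """Hypothetical right operand that supplies the reflected method."""

    def __rrshift__(self, other):
        return f"Right.__rrshift__({type(other).__name__})"


print(Left() >> 3)        # prints "Left.__rshift__(3)"
print(Left() >> Right())  # prints "Right.__rrshift__(Left)"
```

The same expression shape, Left() >> x, is served by two different methods depending on the runtime type of x, which is the ambiguity described above.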

When you call a.seq directly, bidirectional type inference can be used because the argument lambda i: i has an expected type of Callable[[T], U]. The same is true if you call __rshift__ manually (a.__rshift__(lambda i: i)). But if you use the >> operator, a static type checker cannot determine the types unambiguously.

If you want to use >> with a lambda expression, you can create a temporary variable with the needed type information:

def example(a: A[int]):
    temp: Callable[[int], int] = lambda i: i
    reveal_type(a >> temp)

or

def example(a: A[int]):
    def temp(i: int): return i
    reveal_type(a >> temp)
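A runnable sketch of these workarounds, assuming a concrete A whose methods simply apply f to a stored value. The original only declares the signatures with `...`, so the value attribute and the method bodies here are hypothetical. It uses typing.TypeVar and Generic instead of the class A[T] syntax so that it also runs on Python versions before 3.12. At runtime all forms produce the same value; the difference lies only in what pyright can infer statically.

```python
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")
U = TypeVar("U")


class A(Generic[T]):
    # Hypothetical concrete implementation: the issue only gives signatures.
    def __init__(self, value: T) -> None:
        self.value = value

    def __rshift__(self, f: Callable[[T], U]) -> U:
        return f(self.value)

    def seq(self, f: Callable[[T], U]) -> U:
        return f(self.value)


def example(a: A[int]) -> list[int]:
    # Explicit dunder call: the lambda has an expected type, as with seq.
    via_dunder = a.__rshift__(lambda i: i)

    # Workaround 1: annotate a temporary variable holding the lambda.
    temp: Callable[[int], int] = lambda i: i
    via_annotated_lambda = a >> temp

    # Workaround 2: a nested function with an annotated parameter.
    def temp_def(i: int) -> int:
        return i

    via_def = a >> temp_def

    return [a.seq(lambda i: i), via_dunder, via_annotated_lambda, via_def]


print(example(A(42)))  # prints [42, 42, 42, 42]
```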

robert-bn commented 1 month ago

Ah okay. Thank you for the detailed explanation.