python / mypy

Optional static typing for Python
https://www.mypy-lang.org/

Float multiplication of numpy.ndarray in lambda is incorrectly analysed #8001

Open jamesohortle opened 5 years ago

jamesohortle commented 5 years ago

Bug.

from typing import NewType, Tuple, Deque

import numpy as np

class Point(np.ndarray):
    def __new__(cls, x: float, y: float) -> np.ndarray:
        return np.array((x, y), dtype=np.float32)

Nose = NewType("Nose", Tuple[Point, ...])

def average_position(positions: Deque) -> Nose:
    factor = 1.0 / len(positions)
    point_sum = [Point(0.0, 0.0) for _ in range(10)]
    for nose in positions:
        for i, point in enumerate(nose):
            point_sum[i] += point
    avg_pos = Nose(tuple(map(lambda p: factor * p, point_sum)))
    return avg_pos

def average_position_with_multiply(positions: Deque) -> Nose:
    factor = 1.0 / len(positions)
    point_sum = [Point(0.0, 0.0) for _ in range(10)]
    for nose in positions:
        for i, point in enumerate(nose):
            point_sum[i] += point
    avg_pos = Nose(tuple(map(lambda p: np.multiply(factor, p), point_sum)))
    return avg_pos

Mypy gives the errors below:

bug.py:20: error: Argument 1 to "map" has incompatible type "Callable[[Point], float]"; expected "Callable[[Point], Point]"
bug.py:20: error: Incompatible return value type (got "float", expected "Point")
Found 2 errors in 1 file (checked 1 source file)

For the lambda in average_position(), mypy infers the type Callable[[Point], float], which looks incorrect (?), whereas the equivalent call through the NumPy function np.multiply() produces no error. Both functions return the same correct value at runtime.

mypy==0.740
Python 3.8.0

mypy configuration:
[mypy]
python_version = 3.8

[mypy-cv2,dlib,numpy]
ignore_missing_imports = True

I am unsure if this is actually an error, or if I've done something wrong somewhere, but both functions run and output correctly.
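
For context on why both versions produce the same value at runtime, here is a minimal sketch of Python's reflected-operator dispatch. The Vec class is a hypothetical stand-in for an ndarray subclass (it is not part of NumPy), used only so the snippet runs without extra dependencies. When float.__mul__ does not recognise the right operand it returns NotImplemented, and Python then tries the right operand's __rmul__, so the result takes the right operand's type rather than float.

```python
from dataclasses import dataclass


@dataclass
class Vec:
    """Hypothetical stand-in for an ndarray subclass; not a NumPy type."""

    x: float
    y: float

    def __rmul__(self, scalar: float) -> "Vec":
        # Called for `scalar * vec` once float.__mul__ has declined.
        return Vec(scalar * self.x, scalar * self.y)


# float.__mul__ does not know how to handle Vec, so it declines.
declined = (0.5).__mul__(Vec(2.0, 4.0))

# The operator falls back to Vec.__rmul__, so the result is a Vec.
scaled = 0.5 * Vec(2.0, 4.0)
```

A type checker that only consults float.__mul__ for `factor * p` would report float here, which matches the error messages above.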

JukkaL commented 5 years ago

The root of the problem is that mypy incorrectly infers the type of multiplying Point by a float. Here's a shorter example:

from typing import Any

C: Any
class D(C):
    pass

reveal_type(0.5 * D())   # float, but should be Any

As a workaround, you can try installing stubs for numpy (https://github.com/numpy/numpy-stubs).
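
If installing stubs is not an option, another possible workaround is to state the result type explicitly with typing.cast inside a small typed helper. The sketch below is an assumption about how that could look, not a confirmed fix: _RuntimeArray, UntypedArray, and scale are hypothetical names, and the Any-typed base only imitates how mypy should see np.ndarray when numpy is imported with ignore_missing_imports = True.

```python
from typing import Any, List, Tuple, cast


class _RuntimeArray:
    """Hypothetical stand-in for an untyped array class (not NumPy)."""

    def __init__(self, x: float, y: float) -> None:
        self.x = x
        self.y = y

    def __rmul__(self, scalar: float) -> "_RuntimeArray":
        # Build the result with the caller's own class, so Point stays Point.
        return type(self)(scalar * self.x, scalar * self.y)


# Declared as Any so that, for the type checker, the base class is opaque,
# mirroring an import that mypy could not resolve.
UntypedArray: Any = _RuntimeArray


class Point(UntypedArray):
    pass


def scale(factor: float, p: Point) -> Point:
    # cast() only informs the type checker; it does nothing at runtime.
    return cast(Point, factor * p)


points: List[Point] = [Point(2.0, 4.0), Point(6.0, 8.0)]
averaged: Tuple[Point, ...] = tuple(map(lambda p: scale(0.5, p), points))
```

The cast keeps the original operand order, so runtime behaviour is unchanged and only the inferred type differs.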

ilevkivskyi commented 2 years ago

It looks like #8019 didn't actually fix this issue: at least Jukka's simple example above still fails. It passes in the tests only by accident (see testOpWithInheritedFromAny), because float doesn't have __add__() in the builtins fixtures. I am going to mark it xfail in https://github.com/python/mypy/pull/14077, since my tests need the __add__() that is actually present in the real builtins stubs.

ilevkivskyi commented 2 years ago

cc @TH3CHARLie @msullivan