inducer / pytato

Lazily evaluated arrays in Python
Other
8 stars 16 forks source link

Fixes for mypy 1.7 #472

Closed inducer closed 7 months ago

matthiasdiener commented 7 months ago

Here is a diff for the remaining mypy 1.7 errors:

```diff
diff --git a/pytato/array.py b/pytato/array.py
index 9366dc8..485c953 100644
--- a/pytato/array.py
+++ b/pytato/array.py
@@ -2530,7 +2530,7 @@ def dot(a: ArrayOrScalar, b: ArrayOrScalar) -> ArrayOrScalar:
     elif a.ndim == b.ndim == 2:
         return a @ b
     elif a.ndim == 0 or b.ndim == 0:
-        return a * b
+        return cast(Array, a * b)
     elif b.ndim == 1:
         return pt.sum(a * b, axis=(a.ndim - 1))
     else:
diff --git a/pytato/pad.py b/pytato/pad.py
index 0e15f0d..f2f8c88 100644
--- a/pytato/pad.py
+++ b/pytato/pad.py
@@ -86,7 +86,8 @@ def _normalize_pad_width(
         and isinstance(pad_width[0], INT_CLASSES)
         and isinstance(pad_width[1], INT_CLASSES)
     ):
-        processed_pad_widths = [pad_width for _ in range(array.ndim)]
+        processed_pad_widths = [(pad_width[0], pad_width[1])
+                                for _ in range(array.ndim)]
     elif isinstance(pad_width, abc.Sequence):
         if len(pad_width) != array.ndim:
             raise ValueError(f"Number of pad widths != {array.ndim}"
diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py
index 5994841..9cd1ed1 100644
--- a/pytato/transform/__init__.py
+++ b/pytato/transform/__init__.py
@@ -33,7 +33,7 @@ import numpy as np
 from immutabledict import immutabledict
 from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic,
                     List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING,
-                    Hashable)
+                    Hashable, cast)
 
 from pytato.array import (
         Array, IndexLambda, Placeholder, Stack, Roll,
@@ -238,7 +238,7 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
     """
     if TYPE_CHECKING:
         def rec(self, expr: CopyMapperResultT) -> CopyMapperResultT:
-            return super().rec(expr)
+            return cast(CopyMapperResultT, super().rec(expr))
 
     def __call__(self, expr: CopyMapperResultT) -> CopyMapperResultT:
         return self.rec(expr)
diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py
index 9098ee9..ce5f352 100644
--- a/pytato/transform/einsum_distributive_law.py
+++ b/pytato/transform/einsum_distributive_law.py
@@ -32,7 +32,7 @@ THE SOFTWARE.
 """
 
 
-from typing import Callable, Dict, Tuple, Optional, FrozenSet, Mapping
+from typing import cast, Callable, Dict, Tuple, Optional, FrozenSet, Mapping
 import attrs
 from pytato.transform import ArrayOrNames, Mapper, MappedT
 from pytato.array import (Array, AxesT, Einsum, IndexLambda,
@@ -195,25 +195,25 @@ class EinsumDistributiveLawMapper(Mapper):
                 assert (isinstance(hlo.x1, Array)
                         and isinstance(hlo.x2, Array)
                         and are_shapes_equal(hlo.x1.shape, hlo.x2.shape))
-                return self.rec(hlo.x1, ctx) + self.rec(hlo.x2, ctx)
+                return cast(Array, self.rec(hlo.x1, ctx) + self.rec(hlo.x2, ctx))
             elif hlo.binary_op == BinaryOpType.SUB:
                 assert (isinstance(hlo.x1, Array)
                         and isinstance(hlo.x2, Array)
                         and are_shapes_equal(hlo.x1.shape, hlo.x2.shape))
                 assert are_shapes_equal(hlo.x1.shape, hlo.x2.shape)
-                return self.rec(hlo.x1, ctx) - self.rec(hlo.x2, ctx)
+                return cast(Array, self.rec(hlo.x1, ctx) - self.rec(hlo.x2, ctx))
             elif hlo.binary_op == BinaryOpType.MULT:
                 if isinstance(hlo.x1, Array) and np.isscalar(hlo.x2):
-                    return self.rec(hlo.x1, ctx) * hlo.x2
+                    return cast(Array, self.rec(hlo.x1, ctx) * hlo.x2)
                 else:
                     assert isinstance(hlo.x2, Array) and np.isscalar(hlo.x1)
-                    return hlo.x1 * self.rec(hlo.x2, ctx)
+                    return cast(Array, hlo.x1 * self.rec(hlo.x2, ctx))
             elif hlo.binary_op == BinaryOpType.TRUEDIV:
                 if isinstance(hlo.x1, Array) and np.isscalar(hlo.x2):
-                    return self.rec(hlo.x1, ctx) / hlo.x2
+                    return cast(Array, self.rec(hlo.x1, ctx) / hlo.x2)
                 else:
                     assert isinstance(hlo.x2, Array) and np.isscalar(hlo.x1)
-                    return hlo.x1 / self.rec(hlo.x2, ctx)
+                    return cast(Array, hlo.x1 / self.rec(hlo.x2, ctx))
             else:
                 raise NotImplementedError(hlo)
         else:
```
inducer commented 7 months ago

Thanks for the patch! The reason I got stuck here is that I suspect some of these are honest bugs in mypy which should be reported rather than "papered over" with cast (because they're somewhat likely to bite us again, in a different form).

Here's one example that I filed as I was working on this: https://github.com/python/mypy/issues/16468. (This implies that bugs aren't out of the question.) Fortunately, this one was straightforward to work around on our end.

inducer commented 7 months ago

Here's the next one that I think is wrong: https://github.com/python/mypy/issues/16499. At least considering the message, this covers quite a few of the newly-flagged problems.

inducer commented 7 months ago

OK, looks like this will perhaps all take a bit longer on the mypy end. I've worked through the failures here, so this should be all good to go. @matthiasdiener Thanks for the patch, though I've decided to go with type-ignores over casts so that mypy gets a chance to tell us when (hopefully/finally) they won't be needed any longer.