arogozhnikov / einops

Flexible and powerful tensor operations for readable and reliable code (for pytorch, jax, TF and others)
https://einops.rocks
MIT License
8.54k stars, 352 forks

[BUG] einops.repeat returns value with type Never #299

Closed adamjstewart closed 11 months ago

adamjstewart commented 11 months ago

Describe the bug

This is a weird one. When I pass a PyTorch Tensor into einops.repeat, mypy 1.7+ with strict mode enabled tries to tell me that the return type is typing.Never.

Reproduction steps

Create the following test.py file:

import einops
import torch

x = torch.rand(2, 2)
reveal_type(x)
x = einops.repeat(x, "h w -> h w c", c=3)
reveal_type(x)

Then run:

> mypy --strict test.py
test.py:5: note: Revealed type is "torch._tensor.Tensor"
test.py:7: note: Revealed type is "Never"

Expected behavior

I would expect the output to have the same type as the input, a PyTorch Tensor. This works fine for:

I haven't checked any other einops functions.

Your platform

Version of einops, Python, and DL package that you used

I'm trying to figure out if this is a bug in einops, PyTorch, or mypy. Any help you can give me would be appreciated. I've noticed several PyTorch-specific issues with mypy 1.7+, so I doubt this is just an einops issue.

arogozhnikov commented 11 months ago

Hi Adam,

The typing logic for einops.repeat is literally in this one line: https://github.com/arogozhnikov/einops/blob/d0c7feef31eed7adceff07175300ebfc8bdee2b2/einops/einops.py#L459

and it hasn't changed in a long time. It is also framework-agnostic, so I assume this is something on the mypy side.
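
For readers who do not want to open the link, here is a minimal sketch of what a framework-agnostic, TypeVar-based signature looks like. It is not the einops source: `repeat_like` and `T` are hypothetical names, and the body is a toy that does no real repeating. It only illustrates why a type checker would normally be expected to hand back the same type that was passed in, rather than `Never`.

```python
from typing import List, TypeVar, Union

# Hypothetical stand-in for a framework-agnostic tensor type variable.
T = TypeVar("T")


def repeat_like(tensor: Union[T, List[T]], pattern: str, **axes_lengths: int) -> T:
    """Toy function with a TypeVar-based signature (assumption, not einops code)."""
    if isinstance(tensor, list):
        # A list of T collapses to a single T; here we simply take the first item.
        return tensor[0]
    return tensor


# With a plain value, a checker would normally bind T to the argument's type.
value = repeat_like(1.5, "h w -> h w c", c=3)
first = repeat_like(["a", "b"], "h w -> h w c", c=3)
```

Nothing in a signature of this shape mentions any particular framework, so the declared return type depends entirely on how the checker solves for `T` given the argument's own type annotations.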

adamjstewart commented 10 months ago

Can confirm that this issue was fixed in PyTorch 2.2!

arogozhnikov commented 10 months ago

that's great, thanks for leaving this note