arogozhnikov / einops

Flexible and powerful tensor operations for readable and reliable code (for pytorch, jax, TF and others)
https://einops.rocks
MIT License
8.55k stars 352 forks

Passing a float in `repeat` as a dimension size prevents correct usage afterwards #309

Open Maykeye opened 9 months ago

Maykeye commented 9 months ago

Consider the scenario:

0) Working environment is a Jupyter notebook (that's where this bug affects code in the wild).
1) I call repeat(image, 'h -> (h H)', H=H) with H=2.0 (float, not integer). It fails complaining about types. OK, fair (in real code I used height/factor where both height and factor are integers; in this simplified scenario 2.0 will suffice).
2) I call it again with H=2 (int). It fails. What? But it fails if and only if I did step (1). If I don't, it works fine.
3) I did some digging and saw that the functions use an LRU cache.
4) So the working theory is: if you call repeat improperly, you will taint the cache.
5) I call repeat enough times with different arguments to flush the LRU cache.
6) I call it again with H=2 (int). Everything works fine.

Here's the full code to reproduce.

import torch
from einops import repeat

image = torch.zeros(1)

def run_repeat(H):
    return repeat(image, 'h -> (h H)', H=H)

# Step 1: call with a float -- fails, as expected
try:
    fail_of_course = run_repeat(2.0)
except:
    print("mistakes were made as expected")

# Step 2: call with an int -- unexpectedly fails too
try:
    should_not_fail = run_repeat(2)
except:
    print("we still failed")

# Step 5: call repeat with many different arguments to flush the LRU cache
for i in range(3, 2048):
    try:
        cache_away = run_repeat(i)
    except:
        print("fail with i=", i)

# Step 6: the same int call now succeeds
try:
    should_not_fail = run_repeat(2)
    print("but this time we are fine")
except:
    print("we still failed #2")

The output is

mistakes were made as expected
we still failed
but this time we are fine

If we don't hide the exception in should_not_fail, the full unexpected error is

  File "/tmp/a.py", line 13, in <module>
    should_not_fail = run_repeat(2)
                      ^^^^^^^^^^^^^
  File "/tmp/a.py", line 6, in run_repeat
    return repeat(image, 'h -> (h H)', H=H)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/einops/einops.py", line 641, in repeat
    return reduce(tensor, pattern, reduction="repeat", **axes_lengths)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/einops/einops.py", line 523, in reduce
    return _apply_recipe(
           ^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/einops/einops.py", line 248, in _apply_recipe
    tensor = backend.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/einops/_backends.py", line 267, in add_axes
    return x.expand(repeats)
           ^^^^^^^^^^^^^^^^^
TypeError: expand(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got float"

and if repeat is never called with a float, only

but this time we are fine

is displayed

Einops 0.7.0, Torch 2.2.0+cu121, Python 3.11.7 (main, Jan 29 2024, 16:03:57) [GCC 13.2.1 20230801]

ETA: yep, replacing the call to _reconstruct_from_shape in _apply_recipe with _reconstruct_from_shape_uncached gets rid of the error. So the bug happens because hash(2.0) == hash(2).

(Also, functools.lru_cache(1024, typed=True) doesn't help as a band-aid solution, since from lru_cache's point of view the function always receives an argument of the same type for axes_lengths: HashableAxesLengths.)
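As a quick illustration of why typed=True can't help here (a toy sketch, not einops code): lru_cache's type check applies only to the top-level arguments, and here the top-level argument is always a tuple; the int/float difference is buried inside it.

import functools

@functools.lru_cache(1024, typed=True)
def toy_recipe(axes_lengths):
    # axes_lengths is a tuple of (name, length) pairs, mimicking HashableAxesLengths
    return {name: type(length).__name__ for name, length in axes_lengths}

print(toy_recipe((("H", 2.0),)))  # {'H': 'float'} -- first call populates the cache
print(toy_recipe((("H", 2),)))    # {'H': 'float'} -- cache hit: both arguments are tuples
                                  # that compare equal and hash equal, so typed=True
                                  # never sees the inner int vs float difference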

arogozhnikov commented 9 months ago

Hi @Maykeye. You're correct, the reason is in lru_cache. More precisely, that's because hash(2.0) == hash(2) and 2 == 2.0, so map[2] == map[2.0].
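For illustration, the same key collision with a plain dict, which is essentially what the cache does:

cache = {}
cache[2.0] = "recipe built for a float axis length"

# 2 == 2.0 and hash(2) == hash(2.0), so the int lookup hits the float entry
print(cache[2])                        # recipe built for a float axis length
print(2 == 2.0, hash(2) == hash(2.0))  # True True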

Switching from plain LRU to typed LRU would solve this problem, but it incurs a slow-down, so I'll just accept the current behavior and treat this situation as a programming error (i.e. the user should fix it).

As a recovery: _prepare_transformation_recipe.cache_clear(), or restart the kernel (or just overflow the cache, which also works).
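A sketch of that recovery path, assuming _prepare_transformation_recipe is importable from einops.einops (it is an internal, undocumented detail and may change between versions):

import torch
from einops import repeat
from einops.einops import _prepare_transformation_recipe  # internal, undocumented

image = torch.zeros(1)

try:
    repeat(image, 'h -> (h H)', H=2.0)  # the float call poisons the recipe cache
except Exception:
    pass

_prepare_transformation_recipe.cache_clear()   # drop the poisoned entry
print(repeat(image, 'h -> (h H)', H=2).shape)  # works again: torch.Size([2])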

Maykeye commented 9 months ago

(i.e. user should fix it).

The fix that is appropriate for the user is changing repeat(image, 'h -> (h H)', H=foo/bar) (incorrect call) to repeat(image, 'h -> (h H)', H=foo//bar) (correct call) in the Jupyter cell and rerunning it, without seeing an error message about float once again.
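To spell out the distinction being fixed (plain Python, no einops involved): true division of two ints always yields a float, while floor division keeps it an int, which is what repeat expects as an axis length.

height, factor = 512, 2

print(height / factor)    # 256.0  (float -- the value that poisons the cache)
print(height // factor)   # 256    (int   -- a valid axis length)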

_prepare_transformation_recipe.cache_clear is not documented.

The fact that it is named _prepare_transformation_recipe rather than prepare_transformation_recipe even suggests that the user should not know about its very existence.

Maybe just add an assert in _reconstruct_from_shape_uncached where it iterates over dimensions: assert not isinstance(dim, float), "dim can't be float!". This way the cache will not be filled with floats to begin with, and since the result is cached anyway, one assert call will not slow the world down (besides, asserts can even be disabled with python -O).
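Not the actual einops patch, just a minimal sketch of the idea as a user-side guard; safe_repeat and the wrapper are hypothetical, and the suggested fix would live inside _reconstruct_from_shape_uncached instead:

import torch
from einops import repeat

def safe_repeat(tensor, pattern, **axes_lengths):
    # Hypothetical wrapper: reject float axis lengths before they reach
    # einops' cached recipe machinery.
    for name, length in axes_lengths.items():
        assert not isinstance(length, float), f"axis {name!r} can't be float!"
    return repeat(tensor, pattern, **axes_lengths)

image = torch.zeros(1)
print(safe_repeat(image, 'h -> (h H)', H=2).shape)  # torch.Size([2])
# safe_repeat(image, 'h -> (h H)', H=2.0)  # AssertionError, cache never touched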