Maykeye opened this issue 9 months ago
Hi @Maykeye
You're correct, the reason is the lru_cache.
More precisely, it is because `hash(2.0) == hash(2)` and `2 == 2.0`, so `map[2] == map[2.0]`.
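For illustration, the same collision can be reproduced with a plain dict (ordinary Python behaviour, no einops involved):

```python
# hash(2) and hash(2.0) collide and the values compare equal,
# so an int key and a float key address the same dict/cache slot
assert hash(2) == hash(2.0)
assert 2 == 2.0

cache = {2.0: "recipe built with a float"}
print(cache[2])  # prints "recipe built with a float": the int key hits the float entry
```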
Switching from a plain LRU to a typed LRU would solve this problem, but it incurs a slow-down, so I'll just accept the current behavior and treat this situation as a programming error (i.e. user should fix it).
As a recovery: `_prepare_transformation_recipe.cache_clear()`, or restart the kernel (or just overflow the cache, which also works).
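For example, assuming the cached helper still lives in `einops.einops` (a private detail that may change between versions), recovery in a notebook cell looks like:

```python
# Clears the tainted LRU cache so subsequent repeat() calls rebuild their recipes
from einops.einops import _prepare_transformation_recipe

_prepare_transformation_recipe.cache_clear()
```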
> (i.e. user should fix it)

The fix that is appropriate for the user is changing `repeat(image, 'h -> (h H)', H=foo/bar)` (incorrect call) to `repeat(image, 'h -> (h H)', H=foo//bar)` (correct call) in the Jupyter cell and rerunning it, without seeing an error message about a float once again.
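For context, the reason `foo/bar` sneaks a float in is that true division in Python 3 always yields a float, even when the result is a whole number:

```python
foo, bar = 8, 4
print(foo / bar, type(foo / bar))    # 2.0 <class 'float'> -> taints the recipe cache
print(foo // bar, type(foo // bar))  # 2 <class 'int'>     -> what repeat expects
```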
`_prepare_transformation_recipe.cache_clear` is not documented. The fact that it is named `_prepare_transformation_recipe` rather than `prepare_transformation_recipe` even suggests that the user should not know about its very existence.
Maybe just add an assert in `_reconstruct_from_shape_uncached` where it iterates over dimensions: `assert not isinstance(dim, float), "dim can't be float!"`. This way the cache will not be filled with floats to begin with, and since the result is cached anyway, one extra assert will not slow the world down (besides, asserts can even be disabled with `python -O`).
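A sketch of what such a check could look like (a hypothetical helper with made-up names, not einops' actual code; the real assert would sit inside `_reconstruct_from_shape_uncached`):

```python
def _validate_axes_lengths(axes_lengths):
    # Hypothetical: reject float axis lengths before they become cache keys
    for axis, length in axes_lengths:
        assert not isinstance(length, float), (
            f"Length of axis {axis!r} is a float ({length}); axis lengths must be int"
        )

_validate_axes_lengths((('H', 2),))      # passes silently
# _validate_axes_lengths((('H', 2.0),))  # AssertionError: Length of axis 'H' is a float (2.0) ...
```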
Consider the scenario:

0) The working environment is a Jupyter notebook (that's where this bug affects code in the wild).
1) I call `repeat(image, 'h -> (h H)', H=H)` with `H=2.0` (float, not integer). It fails complaining about types. OK, fair (in real code I used `height/factor` where both `height` and `factor` are integers; in this simplified scenario 2.0 will suffice).
2) I call it again with `H=2` (int). It fails. What? But it fails if and only if I did step (1). If I don't, it works fine.
3) I did some digging and saw that the functions use an LRU cache.
4) So the working theory is: if you call repeat improperly, you taint the cache.
5) I call repeat enough times with different arguments to flush the LRU cache.
6) I call it again with `H=2` (int). Everything works fine.

Here's the full code to reproduce (sketched below).
The output is:

If we are not hiding the exception in `should_not_fail`, the full unexpected error is:

and if repeat is never called with a float, only

is displayed.
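A minimal sketch of the scenario (a hypothetical reconstruction with a dummy tensor, not the exact snippet from this report):

```python
import torch
from einops import repeat

image = torch.zeros(4)

def should_not_fail():
    # step (2): the call that should work, H is an honest int
    return repeat(image, 'h -> (h H)', H=2)

# step (1): the bad call with a float axis length, expected to fail
try:
    repeat(image, 'h -> (h H)', H=2.0)
except Exception as e:
    print("expected failure:", e)

# step (2): fails too, because the float-built result is served from the LRU cache
should_not_fail()
```

With a fresh kernel (or after the cache is cleared), the same `should_not_fail()` call succeeds.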
Einops 0.7.0. Torch 2.2.0+cu121. Python 3.11.7 (main, Jan 29 2024, 16:03:57) [GCC 13.2.1 20230801].
ETA: yep, replacing the call to `_reconstruct_from_shape` in `_apply_recipe` with `_reconstruct_from_shape_uncached` gets rid of the error. So the bug happens because `hash(2.0) == hash(2)`. (Also, `functools.lru_cache(1024, typed=True)` doesn't help as a band-aid solution, since from lru_cache's point of view the function always receives the same type for `axes_lengths: HashableAxesLengths`.)
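To illustrate why `typed=True` is powerless here, the same effect can be reproduced with a plain `lru_cache`: both calls pass a tuple, so the argument type is identical, and tuples whose elements compare equal also hash equal:

```python
import functools

@functools.lru_cache(maxsize=1024, typed=True)
def build(axes_lengths):
    # stand-in for the cached recipe builder; returns whatever it was first called with
    return dict(axes_lengths)

print(build((('H', 2.0),)))  # {'H': 2.0} -- cached under the float-containing key
print(build((('H', 2),)))    # {'H': 2.0} -- cache hit: equal tuples, same hash, same arg type
```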