Unbabel / COMET

A Neural Framework for MT Evaluation
https://unbabel.github.io/COMET/html/index.html
Apache License 2.0

tensor_lru_cache is limited to tensors with at least 2 dimensions #124

Closed alvations closed 1 year ago

alvations commented 1 year ago

🐛 Bug

The tensor_lru_cache assumes tensors have a diagonal (i.e., are at least 2-D) because of the x.diagonal().__repr__() call in _make_key. To support robust tensor caching, it's best to check the tensor's dimensionality before calling repr(x.diagonal()).
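For context, torch.diagonal is only defined for tensors with at least two dimensions (its dim1/dim2 arguments default to 0 and 1), so calling .diagonal() on a 1-D tensor raises the IndexError reproduced below. A minimal illustration:

    import torch

    x = torch.tensor([[0, 1], [2, 3]])
    print(x.diagonal())  # 2-D: works, prints tensor([0, 3])

    y = torch.tensor([0, 2])
    y.diagonal()         # 1-D: IndexError: Dimension out of range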

To Reproduce

This works as expected:

import torch
from comet.models.lru_cache import tensor_lru_cache

@tensor_lru_cache(None)  # unlimited cache size
def add(x, y):
    return x + y

for _ in range(10):
    tmp = add(torch.tensor([[0, 1, 2], [2, 3, 4]]), torch.tensor([[5, 6, 7], [8, 9, 10]]))

print(add.cache_info())

[out]:

CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)
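That is, the first call misses and populates the cache, and the remaining nine calls hit the cached result, so the tensor-valued arguments are being keyed correctly for 2-D inputs.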

This works too:

for _ in range(10):
    tmp = add(torch.tensor([[0, 2]]), torch.tensor([[5, 8]]))

print(add.cache_info())

[out]:

CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)

This fails because the tensors are 1-D:

for _ in range(10):
    tmp = add(torch.tensor([0, 2]), torch.tensor([5, 8]))

print(add.cache_info())

[out]:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-22-4f0c4b36b198> in <module>
      4 
      5 for _ in range(10):
----> 6   tmp = add(torch.tensor([0, 2]), torch.tensor([5, 8]))
      7 
      8 print(add.cache_info())

1 frames
<ipython-input-11-1aab7df411aa> in _make_key(args, kwds, typed, kwd_mark, fasttypes, tuple, type, len)
     36                 x.__repr__()
     37                 + "\n"
---> 38                 + x.diagonal().__repr__()
     39                 + "\n"
     40                 + x.shape.__repr__()

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

Expected behaviour

for _ in range(10):
    tmp = add(torch.tensor([0, 2]), torch.tensor([5, 8]))

print(add.cache_info())

[out]:

CacheInfo(hits=9, misses=1, maxsize=None, currsize=1)

Suggested solution

At https://github.com/Unbabel/COMET/blob/master/comet/models/lru_cache.py#L52, this change would fix the issue:

    new_args = []
    for x in args:
        if torch.is_tensor(x):
            assert len(x.size()) > 0, "Tensor needs to be at least 1-dimensional."
            # HACK: Tensor representations omit some tensor content.
            # Nonetheless, converting the tensor into a tuple is too slow.
            # The current solution is an approximation of the full tensor
            # representation. This can still lead to false cache hits!
            if len(x.size()) == 1:
                reprs = [repr(x), repr(x.shape)]
            else:
                reprs = [repr(x), repr(x.diagonal()), repr(x.shape)]
            new_args.append("\n".join(reprs))
        else:
            new_args.append(x)

Otherwise, asserting that the tensor is at least 2-D would work too, e.g.:

    new_args = []
    for x in args:
        if torch.is_tensor(x):
            assert len(x.size()) > 1, "Tensor needs to be at least 2-dimensional."
            # HACK: Tensor representations omit some tensor content.
            # Nonetheless, converting the tensor into a tuple is too slow.
            # The current solution is an approximation of the full tensor
            # representation. This can still lead to false cache hits!
            reprs = [repr(x), repr(x.diagonal()), repr(x.shape)]
            new_args.append("\n".join(reprs))
        else:
            new_args.append(x)
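To make the first option concrete, here is a self-contained sketch; the helper name tensor_key_part is hypothetical and only illustrates the guarded key construction, it is not COMET's actual code:

    import torch

    def tensor_key_part(x):
        # Hypothetical helper mirroring the guarded key construction above.
        assert len(x.size()) > 0, "Tensor needs to be at least 1-dimensional."
        if len(x.size()) == 1:
            # 1-D tensors have no diagonal, so skip it in the key.
            reprs = [repr(x), repr(x.shape)]
        else:
            reprs = [repr(x), repr(x.diagonal()), repr(x.shape)]
        return "\n".join(reprs)

    print(tensor_key_part(torch.tensor([0, 2])))            # 1-D input: no IndexError
    print(tensor_key_part(torch.tensor([[0, 1], [2, 3]])))  # 2-D input: key includes diagonal

Between the two options, the first is more general: it keeps caching usable for 1-D inputs at the cost of a slightly weaker key (no diagonal component) in that case, while the second simply makes the existing 2-D assumption explicit.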

Note: building the key with a single new_args.append(x.__repr__() + "\n" + ...) expression, as in the current code, works just as well for cache hits, but for brevity and readability in this issue I've used repr().

ricardorei commented 1 year ago

This makes sense @alvations, but what is the specific use case you are targeting here? The current implementation should work with the current models' implementations.

alvations commented 1 year ago

The bug shouldn't impact normal COMET users using the models off the shelf. It was found while stress-testing modules in the library.

But in any case, the cache should at least assert the 2-D assumption, in case users have made changes to the model or inputs.

ricardorei commented 1 year ago

Ok, got it. I'll work on it.

alvations commented 1 year ago

Thank you for the commit!!

ricardorei commented 1 year ago

I also released version 2.0.1, which already includes this fix. Enjoy!