Hi! At the moment, the SparseAttention class (dalle_pytorch/attention.py, line 366) inherits from Attention, but it does not support the cache mechanism. I guess that is the cause of the unexpected keyword argument bug. See dalle_pytorch/transformer.py, line 60:
class CachedAs(nn.Module):
    """
    A wrapper that defines a key for the inference cache.
    """
    def __init__(self, cache_key, fn):
        super().__init__()
        self.cache_key = cache_key
        self.fn = fn

    def forward(self, x, *, cache=None, **kwargs):
        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
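For context, here is a minimal usage sketch (the toy attention module and tensor shapes are my own illustration, not code from the repo; it assumes the CachedAs class quoted above): each wrapped module receives the shared cache dict and namespaces its state under its own cache_key.

import torch
from torch import nn

class TinyAttention(nn.Module):
    # toy stand-in whose forward accepts the cache kwargs,
    # as Attention.forward does in the library
    def forward(self, x, *, cache=None, cache_key=None):
        if cache is not None:
            prev = cache.get(cache_key)
            # append the new step to this layer's cached state
            cache[cache_key] = x if prev is None else torch.cat((prev, x), dim=1)
        return x

wrapped = CachedAs('attn_0', TinyAttention())
cache = {}
out = wrapped(torch.randn(1, 4, 8), cache=cache)  # cache now holds key 'attn_0'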
and dalle_pytorch/transformer.py, line 279:
if isinstance(attn, Attention):
    attn = CachedAs(f'attn_{ind}', attn)
else:
    # at the moment, other attention classes don't support cache
    attn = NonCached(attn)
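Because SparseAttention is a subclass of Attention, isinstance(attn, Attention) is True for it, so it gets wrapped in CachedAs rather than NonCached; CachedAs then forwards cache= and cache_key= into SparseAttention.forward, which does not declare those parameters, and Python raises TypeError: forward() got an unexpected keyword argument. NonCached is not shown above; presumably it swallows the cache kwargs before delegating, roughly like this, and one way to route SparseAttention around the cache path would be to exclude the subclass in the check (a sketch, not a tested patch):

class NonCached(nn.Module):
    # presumed counterpart to CachedAs: drop the cache kwargs so modules
    # without cache support never see them
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *, cache=None, cache_key=None, **kwargs):
        return self.fn(x, **kwargs)

# possible fix sketch: subclasses of Attention that lack cache support
# should not be wrapped in CachedAs
if isinstance(attn, Attention) and not isinstance(attn, SparseAttention):
    attn = CachedAs(f'attn_{ind}', attn)
else:
    attn = NonCached(attn)

Alternatively, SparseAttention.forward could simply accept and ignore the extra kwargs until sparse attention gains real cache support.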