francois-rozet opened this issue 1 month ago
After more troubleshooting, it seems that the conversion to a sparse matrix is incorrect.
>>> mask = torch.rand((M, N)).cuda() < 0.01
>>> sparse = xf.SparseCS(mask, device=mask.device)
>>> (sparse.to_dense() != mask).nonzero()
tensor([[   0, 1023,  418],
        [   0, 1023,  573],
        [   0, 1023,  583]], device='cuda:0')
The issue comes from _round_nnz, which drops non-zero elements from the mask when the number of non-zero elements is not a multiple of 4. Modifying _round_nnz so that it keeps a few zero elements (with value False) instead of dropping non-zero ones removes the discrepancy between sparse and mask. Note that the following implementation does not require a CPU-GPU synchronization.
import torch
import xformers.sparse.utils

def monkey_round_nnz(mask, divisible_by=4):
    # Flip just enough zero entries to True so that the total number of kept
    # entries is a multiple of `divisible_by`, instead of dropping non-zeros.
    nnz = torch.count_nonzero(mask)
    cunz = torch.cumsum(~mask.flatten(), dim=0)
    flip = cunz <= (-nnz) % divisible_by
    return torch.logical_or(flip.reshape_as(mask), mask)

xformers.sparse.utils._round_nnz = monkey_round_nnz
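As a quick sanity check (a sketch reusing the xf alias from the snippet above, whose import is not shown in the report; M and N are example sizes, not values from the report), the sparse matrix built after applying the patch should now match the mask exactly:

M = N = 1024  # example sizes
mask = torch.rand((M, N)).cuda() < 0.01
sparse = xf.SparseCS(mask, device=mask.device)
assert not (sparse.to_dense() != mask).any()  # no discrepancy once the patch is applied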
However, SparseCSRTensor does not take the values of the mask into account when performing the masked matmul: the extra entries kept by the patched _round_nnz are stored with value False, but they are still treated as attended positions, which results in incorrect attention values. Taking the values of mask into account in _masked_matmul solves the issue.
@classmethod
def _masked_matmul(cls, a, b, mask):
    if not (type(a) is torch.Tensor and type(b) is torch.Tensor):
        return NotImplemented
    assert mask.shape[1] == a.shape[1]
    assert mask.shape[2] == b.shape[2]
    values = mask.__values
    row_indices = mask.__row_indices
    row_offsets = mask.__row_offsets
    column_indices = mask.__column_indices
    tansp_info = mask.__transp_info
    out = _csr_ops._sddmm.apply(
        a.contiguous(),
        b.transpose(-2, -1).contiguous(),
        row_indices,
        row_offsets,
        column_indices,
        tansp_info,
    )
    # Positions whose stored value is False (e.g. the padding entries kept by
    # the patched _round_nnz) are sent to -inf so they get zero attention weight.
    out = torch.where(values, out, float("-inf"))
    return cls._wrap(
        mask.shape,
        out,
        row_indices,
        row_offsets,
        column_indices,
        tansp_info,
    )
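To see the effect of the added torch.where line in isolation, here is a standalone dense sketch (plain PyTorch, not xformers code):

import torch

# Scores at positions whose mask value is False must be sent to -inf before
# the softmax so that they receive zero attention weight.
scores = torch.tensor([[1.0, 2.0, 3.0]])
values = torch.tensor([[True, False, True]])  # False = e.g. a padding entry
masked = torch.where(values, scores, torch.tensor(float("-inf")))
print(torch.softmax(masked, dim=-1))  # the False position gets weight 0.0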
🐛 Bug
The output of scaled_dot_product_attention is wrong when the mask is a SparseCS matrix. In particular, the last element of the sequence is incorrect, while the others are correct.

To Reproduce
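The reproduction script is not included above. Below is a minimal sketch of the comparison described in the report; the xf import path, the tensor shapes, and the diagonal trick are assumptions rather than details from the original issue:

import torch
import torch.nn.functional as F
import xformers.components.attention.core as xf  # assumed to be the `xf` used above

B, S, D = 1, 1024, 64
q, k, v = (torch.randn(B, S, D).cuda() for _ in range(3))

# Random sparse boolean mask; the diagonal is forced to True so every query
# attends to at least one key (avoids rows that are entirely -inf).
mask = torch.rand(S, S).cuda() < 0.01
mask |= torch.eye(S, dtype=torch.bool, device=mask.device)

ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out = xf.scaled_dot_product_attention(q, k, v, xf.SparseCS(mask, device=mask.device))

# Rows that differ beyond tolerance; the report says only the last one does.
print((~torch.isclose(ref, out, atol=1e-5)).any(dim=-1).nonzero())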
Expected behavior
The output of torch.nn.functional.scaled_dot_product_attention and xf.scaled_dot_product_attention should be the same (up to some tolerance).

Environment