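I found an issue in the code where the assert statement on line 257 incorrectly checks whether query and key are the same. Since query and key are tensors, the expression query is key tests object identity, not equality. It passes only when both names refer to the same Python object and fails for two distinct tensors that hold identical values.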
To compare the values of the two tensors, torch.equal should be used instead. torch.equal(query, key) returns True only if both tensors have the same size and the same elements.
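A short example makes the difference concrete. The tensor shape below is an arbitrary assumption, and key is built with clone() to stand in for a tensor that is a separate object holding the same values.

```python
import torch

# Two distinct tensor objects that hold identical values
query = torch.randn(2, 4, 8)
key = query.clone()

print(query is key)             # False: different Python objects
print(torch.equal(query, key))  # True: same size and same elements

# Passing the very same tensor twice, as in plain self-attention
key_same = query
print(query is key_same)        # True: same object
```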
Proposed Change
Modify line 257 from:
assert query is key, 'Only Support Self-Attention Currently'
to:
assert torch.equal(query, key), 'Only Support Self-Attention Currently'
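A minimal sketch of how the proposed assertion behaves. The helper self_attention_check is hypothetical and only reproduces the line 257 check, not the surrounding module.

```python
import torch


def self_attention_check(query: torch.Tensor, key: torch.Tensor) -> None:
    # Hypothetical helper mirroring the proposed assertion on line 257
    assert torch.equal(query, key), 'Only Support Self-Attention Currently'


x = torch.randn(2, 4, 8)

self_attention_check(x, x)          # same object: passes
self_attention_check(x, x.clone())  # distinct object, same values: passes

try:
    self_attention_check(x, x + 1)  # different values: fails
except AssertionError as err:
    print(err)  # Only Support Self-Attention Currently
```

Unlike the identity check, torch.equal inspects every element, so its cost grows with the size of the tensors.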