[Closed] ghost closed this issue 2 years ago
Because (q * k).sum() is equivalent to q.T @ k
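To spell out the equivalence the reply is pointing at: for 1-D vectors, multiplying elementwise and then summing is exactly the dot product. A minimal sketch (in NumPy for illustration; PyTorch's `*` and `.sum()` behave the same way on 1-D tensors):

```python
import numpy as np

# For 1-D vectors q and k, "Hadamard product then sum" and the
# dot product are the same contraction over the feature axis:
q = np.array([1.0, 2.0, 3.0])
k = np.array([4.0, 5.0, 6.0])

assert (q * k).sum() == q @ k          # both equal 32.0
assert (q * k).sum() == np.dot(q, k)   # same thing, spelled out
```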
Oh, I get it. Thank you so much for your reply!
> Because (q * k).sum() is equivalent to q.T @ k

Hi Dr. Zhou, thanks for your great work, but I do not understand why (q * k).sum() is equivalent to q.T @ k. Here is an example:
```python
>>> q = torch.randn(2, 3)
>>> q
tensor([[-1.4198, -1.4788, -0.8260],
        [-0.0783,  1.2059,  0.5165]])
>>> k = torch.randn(2, 3)
>>> k
tensor([[-0.9287, -0.4349,  1.5053],
        [ 1.0446, -1.4643,  0.6810]])
>>> (q * k).sum(dim=-1)
tensor([ 0.7185, -1.4959])
>>> q.T @ k
tensor([[ 1.2368,  0.7322, -2.1906],
        [ 2.6331, -1.1226, -1.4048],
        [ 1.3067, -0.3970, -0.8916]])
```
And obviously, they are not equal to each other.
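The mismatch comes down to which axis each expression contracts over. `q.T @ k` contracts over the batch dimension (size 2) and produces a 3x3 feature-by-feature matrix, while `(q * k).sum(dim=-1)` contracts over the feature dimension and produces one scalar per row, i.e. the dot product of row i of q with row i of k. A small sketch of this shape difference (NumPy for illustration; the PyTorch shapes are identical):

```python
import numpy as np

q = np.random.randn(2, 3)
k = np.random.randn(2, 3)

# q.T @ k contracts over the BATCH axis (size 2):
assert (q.T @ k).shape == (3, 3)

# (q * k).sum(-1) contracts over the FEATURE axis (size 3),
# giving one score per row -- the diagonal of q @ k.T:
scores = (q * k).sum(axis=-1)
assert scores.shape == (2,)
assert np.allclose(scores, np.diag(q @ k.T))
```

So the two expressions only coincide when q and k are 1-D vectors, where there is a single axis to contract over.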
Hi,
Could you tell me why the attention calculation in your code is done by taking the Hadamard product and summing over the last dimension, instead of using a dot-product operation?
Thank you so much!