YeolJ00 (issue, closed 2 years ago):
The scaled cosine attention part of the implementation in model_parts.py seems wrong:
```python
attention_map: torch.Tensor = torch.einsum("bhqd, bhkd -> bhqk", query, key) \
    / torch.maximum(torch.norm(query, dim=-1, keepdim=True) * torch.norm(key, dim=-1, keepdim=True),
                    torch.tensor(1e-06, device=query.device, dtype=query.dtype))
```
It should be corrected to:
```python
attention_map: torch.Tensor = torch.einsum("bhqd, bhkd -> bhqk", query, key) \
    / torch.maximum(torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1),
                    torch.tensor(1e-06, device=query.device, dtype=query.dtype))
```
The equation normalizes the attention logit for each query-key pair, so the normalizer must be the product of the corresponding query and key norms. The original code multiplies two norm tensors of shape (B, H, N, 1) elementwise, which yields another (B, H, N, 1) vector and broadcasts the wrong normalizer across all keys in each row. The norm matrix we actually need is the outer product of the query and key norms, with shape (B, H, N, N) to match the attention map.
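To make the shape mismatch concrete, here is a minimal sketch with hypothetical tensor sizes (B, H, N, D are made up for illustration); it checks the corrected expression against PyTorch's built-in cosine similarity:

```python
import torch

B, H, N, D = 2, 4, 8, 16  # hypothetical batch, heads, tokens, head dim
query = torch.randn(B, H, N, D)
key = torch.randn(B, H, N, D)

norm_q = torch.norm(query, dim=-1, keepdim=True)  # (B, H, N, 1)
norm_k = torch.norm(key, dim=-1, keepdim=True)    # (B, H, N, 1)

# Elementwise product: still (B, H, N, 1) -- pairs the i-th query norm with
# the i-th key norm only, then broadcasts it across all keys in row i.
wrong = norm_q * norm_k
print(wrong.shape)  # torch.Size([2, 4, 8, 1])

# Outer product via matmul: (B, H, N, 1) @ (B, H, 1, N) -> (B, H, N, N),
# so entry (q, k) holds ||query_q|| * ||key_k||, matching the attention map.
right = norm_q @ norm_k.transpose(-2, -1)
print(right.shape)  # torch.Size([2, 4, 8, 8])

# With the outer product, each logit becomes the cosine similarity of its
# query/key pair, which we can verify against the built-in implementation.
logits = torch.einsum("bhqd, bhkd -> bhqk", query, key)
cos = logits / torch.maximum(right, torch.tensor(1e-06))
reference = torch.nn.functional.cosine_similarity(
    query.unsqueeze(3), key.unsqueeze(2), dim=-1
)
assert torch.allclose(cos, reference, atol=1e-5)
```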
Hi @YeolJ00,
yes, you are completely right, thanks for pointing out this issue! The norm is now fixed.
Cheers, Christoph