ChristophReich1996 / Swin-Transformer-V2

PyTorch reimplementation of the paper "Swin Transformer V2: Scaling Up Capacity and Resolution" [CVPR 2022].
https://arxiv.org/abs/2111.09883
MIT License

Scaled cosine attention error #1

Closed YeolJ00 closed 2 years ago

YeolJ00 commented 2 years ago

The scaled cosine attention part of the implementation in model_parts.py seems wrong:

attention_map: torch.Tensor = torch.einsum("bhqd, bhkd -> bhqk", query, key) \
                                      / torch.maximum(torch.norm(query, dim=-1, keepdim=True)
                                                      * torch.norm(key, dim=-1, keepdim=True),
                                                      torch.tensor(1e-06, device=query.device, dtype=query.dtype))

should be corrected to

attention_map: torch.Tensor = torch.einsum("bhqd, bhkd -> bhqk", query, key) \
                                      / torch.maximum(torch.norm(query, dim=-1, keepdim=True)
                                                      @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1),
                                                      torch.tensor(1e-06, device=query.device, dtype=query.dtype))

since the equation normalizes the attention value of each query-key pair by the product of their norms. The original code multiplies two (B, H, N, 1) norm tensors element-wise, which yields another (B, H, N, 1) tensor, so every row of the attention map is divided only by ||q_i|| * ||k_i||. The denominator we actually need is the (B, H, N, N) matrix containing ||q_i|| * ||k_j|| for every pair, which the outer product of the two norm vectors produces.
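
As a quick sanity check (a minimal sketch, not taken from the repository; the tensor shapes and the 1e-06 clamp mirror the snippets above, all other names are illustrative), both denominators can be compared against torch.nn.functional.cosine_similarity:

import torch

B, H, N, D = 2, 3, 4, 8
query = torch.randn(B, H, N, D)
key = torch.randn(B, H, N, D)

eps = torch.tensor(1e-06, dtype=query.dtype)
logits = torch.einsum("bhqd, bhkd -> bhqk", query, key)

# Original denominator: element-wise product of two (B, H, N, 1) norm tensors
# stays (B, H, N, 1), so each attention row is scaled only by ||q_i|| * ||k_i||.
denominator_original = torch.maximum(
    torch.norm(query, dim=-1, keepdim=True) * torch.norm(key, dim=-1, keepdim=True),
    eps,
)

# Corrected denominator: outer product of the norm vectors gives the full
# (B, H, N, N) matrix of ||q_i|| * ||k_j|| values.
denominator_corrected = torch.maximum(
    torch.norm(query, dim=-1, keepdim=True)
    @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1),
    eps,
)

# Reference: pairwise cosine similarity between every query and every key.
reference = torch.nn.functional.cosine_similarity(
    query.unsqueeze(-2), key.unsqueeze(-3), dim=-1
)

print(torch.allclose(logits / denominator_corrected, reference, atol=1e-5))  # True
print(torch.allclose(logits / denominator_original, reference, atol=1e-5))   # False in general

The corrected version reproduces the pairwise cosine similarities, while the original version only matches them on the diagonal.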

ChristophReich1996 commented 2 years ago

Hi @YeolJ00,

Yes, you are completely right, thanks for pointing out this issue! The norm computation is now fixed.

Cheers, Christoph