jrzaurin / pytorch-widedeep

A flexible package for multimodal-deep-learning to combine tabular data with text and images using Wide and Deep models in Pytorch
Apache License 2.0

Dropout layer being created on forward pass (in MultiHeadedAttention) #189

Closed: BrunoBelucci closed this issue 1 year ago

BrunoBelucci commented 1 year ago

A dropout layer is being created on every forward pass in the MultiHeadedAttention class (pytorch_widedeep/models/tabular/transformers/_attention_layers.py):

class MultiHeadedAttention(nn.Module):
    ...
    def forward(self, X_Q: Tensor, X_KV: Optional[Tensor] = None) -> Tensor:
        ...
        self.attn_weights, attn_output = self._standard_attention(q, k, v)
        ...

    def _standard_attention(self, q: Tensor, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
        ...
        attn_output = einsum(
            "b h s l, b h l d -> b h s d", nn.Dropout(self.dropout)(attn_weights), v  # << HERE
        )

This prevents us from correctly putting the whole model in eval mode: because the Dropout module is re-created on every forward call, it never sees the model's train/eval state and dropout is always applied. I think we should instantiate the layer in __init__. I will promptly submit a PR to fix this.
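As a quick illustration (a standalone snippet, not taken from the library): a Dropout registered as a submodule follows the parent's eval() call, whereas an nn.Dropout constructed on the fly starts in training mode and keeps zeroing activations no matter what:

import torch
import torch.nn as nn

x = torch.ones(1000)

registered = nn.Dropout(p=0.5)
registered.eval()                              # eval() propagates to registered submodules
print(registered(x).eq(1.0).all())             # tensor(True): dropout is disabled

# a Dropout created inside forward() defaults to training mode,
# so it drops units even when the surrounding model is in eval mode
print(nn.Dropout(p=0.5)(x).eq(1.0).all())      # tensor(False) with overwhelming probability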

jrzaurin commented 1 year ago

You are right. In fact, it was like that a few versions ago; no idea why I changed it 🤷🏻‍♂️. Thank you very much for finding the bug.
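For reference, a minimal sketch of the kind of fix discussed above, with only the attention/dropout piece shown; the constructor signature and the attribute name attn_dropout are illustrative, not necessarily what the actual PR uses:

import math
from typing import Tuple

import torch.nn as nn
from torch import Tensor, einsum


class MultiHeadedAttention(nn.Module):
    def __init__(self, input_dim: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.head_dim = input_dim // n_heads
        # instantiated once in __init__, so it is registered as a submodule
        # and respects the parent's train()/eval() state
        self.attn_dropout = nn.Dropout(dropout)

    def _standard_attention(self, q: Tensor, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
        scores = einsum("b h s d, b h l d -> b h s l", q, k) / math.sqrt(self.head_dim)
        attn_weights = scores.softmax(dim=-1)
        attn_output = einsum(
            "b h s l, b h l d -> b h s d", self.attn_dropout(attn_weights), v
        )
        return attn_weights, attn_output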