pmixer / SASRec.pytorch

PyTorch(1.6+) implementation of https://github.com/kang205/SASRec
Apache License 2.0

When using multihead_attention, why are the queries normalized while the keys and values are not? #33

Open YTZ01 opened 1 year ago

YTZ01 commented 1 year ago

```python
for i in range(len(self.attention_layers)):
    seqs = torch.transpose(seqs, 0, 1)
    Q = self.attention_layernorms[i](seqs)
    mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs,
                                              attn_mask=attention_mask)
                                              # key_padding_mask=timeline_mask
                                              # need_weights=False) this arg do not work?
    seqs = Q + mha_outputs
    seqs = torch.transpose(seqs, 0, 1)
```

In the SASRec paper, Section III (Methodology), part B (Self-Attention Block), the formula uses the same embedding object as queries, keys, and values, and then converts it through linear projections. Why are the queries normalized in the code, while the keys and values are not?
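For reference, the self-attention block in the paper is, up to notation, the standard scaled dot-product attention applied to the same input embedding Ê projected three ways:

```math
S = \mathrm{SA}(\hat{E}) = \mathrm{Attention}(\hat{E}W^{Q},\ \hat{E}W^{K},\ \hat{E}W^{V}),
\qquad
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)V
```

In the code above, only the tensor passed as queries goes through `self.attention_layernorms[i]`; the keys and values are the raw `seqs`.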

pmixer commented 1 year ago

I guess you mean this line:

https://github.com/pmixer/SASRec.pytorch/blob/master/model.py#L81

Personally, I believe you can try Q, K, V with or without layernorm in your own experiments; it's not required to do it this way.

Well, since the query input is the projected output of the previous layer, it's better to normalize it for numerical stability and easier training.

K and V could also be layernormed, but since they are only used in the dot product that produces the attention weights for each query, I guess layernorm may not greatly affect those weights.
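If you want to try that variant, here is a minimal, self-contained sketch (hypothetical tensor shapes and variable names, single attention head, not the repo's actual training loop) where the LayerNormed tensor is also used as keys and values:

```python
import torch

# Hypothetical sketch: pass the LayerNormed tensor to multi-head attention as
# queries, keys, AND values, instead of normalizing only the query input as
# model.py does. Shapes, num_heads and names here are illustrative only.
torch.manual_seed(0)
seq_len, batch, hidden = 10, 4, 50
seqs = torch.randn(seq_len, batch, hidden)     # (seq, batch, dim): MultiheadAttention default layout
layernorm = torch.nn.LayerNorm(hidden)
attn = torch.nn.MultiheadAttention(hidden, num_heads=1)

# Causal mask: True marks positions that must NOT be attended to (future items).
attention_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

normed = layernorm(seqs)                       # normalize once, reuse for Q, K and V
mha_outputs, _ = attn(normed, normed, normed, attn_mask=attention_mask)
seqs = normed + mha_outputs                    # residual, mirroring `Q + mha_outputs` in model.py
print(seqs.shape)                              # torch.Size([10, 4, 50])
```

Inside model.py this would amount to passing `Q` instead of `seqs` as the second and third arguments of `self.attention_layers[i]`; whether it helps is an empirical question.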

In summary, please try some modifications in your own experiments and draw conclusions from the experimental results, which is the most reliable and fruitful approach. You are not forced to follow all the settings in the current implementation; some of them are purely empirical ("it works well, so I keep using it this way").