PlaytikaOSS / tft-torch

A Python library that implements "Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting"
MIT License

Regarding ensemble of attention score #10

Open nerdy314 opened 1 year ago

nerdy314 commented 1 year ago

You mention in www.playtika-blog.com/playtika-ai/multi-horizon-forecasting-using-temporal-fusion-transformers-a-comprehensive-overview-part-2/ that "The different heads simply take care of the interactions between the Queries and the Keys, and the outputs of the heads are aggregated and averaged before multiplying by the projected values". However, in your implementation you do not multiply the values by the ensemble of attention scores $\tilde{A}(\boldsymbol{Q},\boldsymbol{K})$; instead, you ensemble the attention scores only after multiplying by the values:

attention_scores = attn_scores_all_heads.mean(dim=1)
attention_outputs = attn_outputs_all_heads.mean(dim=1)

I have seen other implementations as well, and they do the same thing: ensembling after multiplying by the values. I may be completely wrong, but ensembling after multiplying by the values doesn't seem intuitive. Can you please shed some light on this matter? Thank you.

Dvirbeno commented 1 year ago

Hi,

Ultimately, the relation between the attention_scores and the attention_outputs is that the former is multiplied by the values to get the latter. Now, because $\widetilde{A}(\textit{\textbf{Q}},\textit{\textbf{K}})$ is defined as the average of the attention scores across all heads, and because $\textit{\textbf{V}}\textit{\textbf{W}}_V$ is shared among the heads, you can move this multiplication inside the average. In practice this is done by simply repeating $\textit{\textbf{V}}\textit{\textbf{W}}_V$ and letting each attention head act on its own copy, when in fact all heads attend to the same set of values. The two orderings therefore yield the same result.

This is explained in section 4.4 of the original paper (and in the blog post mentioned in your message).
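To make the equivalence concrete, here is a minimal sketch (not the library's actual code) verifying that averaging the per-head outputs equals multiplying the averaged scores by the values, provided the value projection is shared across heads. The tensor shapes are illustrative assumptions:

```python
import torch

torch.manual_seed(0)

num_heads, seq_len, d_v = 4, 5, 8

# Per-head attention scores A_h (each row sums to 1 after softmax).
scores = torch.softmax(torch.randn(num_heads, seq_len, seq_len), dim=-1)

# A single value projection V @ W_V, identical ("shared") for every head.
shared_values = torch.randn(seq_len, d_v)

# Ordering used in the implementation: each head multiplies the shared
# values, then the head outputs are averaged.
avg_of_outputs = (scores @ shared_values).mean(dim=0)

# Ordering in the paper's formula: average the scores across heads first,
# then multiply by the shared values.
output_from_avg_scores = scores.mean(dim=0) @ shared_values

# Matrix multiplication is linear, so mean_h(A_h) @ V == mean_h(A_h @ V)
# whenever V is the same for all heads.
print(torch.allclose(avg_of_outputs, output_from_avg_scores, atol=1e-6))  # True
```

Note that the equivalence breaks if each head had its own value projection, as in standard multi-head attention; sharing $\textit{\textbf{V}}\textit{\textbf{W}}_V$ is exactly what makes the interpretable averaged scores meaningful.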

I hope this helps. Let me know if I misinterpreted your question.

Dvirbeno commented 1 year ago

@nerdy314 Does this answer your question? Can we close the issue?