Reduce the memory cost of the trilinear function from "N × BatchSize × C_len × Q_len × HiddenSize" to "N × BatchSize × C_len × Q_len", where N is a small integer. That works out to roughly N × 0.23 GB -> N × 0.002 GB for HiddenSize 96, and roughly N × 0.31 GB -> N × 0.002 GB for HiddenSize 128.
The main idea behind the optimization is that:
The attention function "[C, Q, C ∘ Q] dot W" can be split into "C dot W1 + Q dot W2 + (C ∘ Q) dot W3".
Given that, we can perform the dot products before the expand_dims and tile, so the last dimension is reduced from HiddenSize to 1 (since the last dimension of W is 1).
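A minimal sketch of what this split could look like (the function name and signature here are illustrative, not the repo's actual trilinear helper; w1, w2, w3 stand for the three HiddenSize-sized slices of the original weight W):

```python
import tensorflow as tf

def optimized_trilinear(C, Q, w1, w2, w3):
    """Memory-efficient trilinear similarity sketch.

    C:  [batch, c_len, d]  context representations
    Q:  [batch, q_len, d]  question representations
    w1, w2, w3: [d]        slices of the original weight W (shape [3d, 1])

    Returns S of shape [batch, c_len, q_len] with
    S[b, i, j] = C[b, i].w1 + Q[b, j].w2 + sum_k C[b, i, k] * Q[b, j, k] * w3[k],
    i.e. the same result as tiling C and Q to [batch, c_len, q_len, d] first,
    but without ever materializing that HiddenSize-wide tensor.
    """
    # [batch, c_len, 1]: the C . W1 term, identical for every column j
    part_c = tf.tensordot(C, w1, axes=[[2], [0]])[:, :, None]
    # [batch, 1, q_len]: the Q . W2 term, identical for every row i
    part_q = tf.tensordot(Q, w2, axes=[[2], [0]])[:, None, :]
    # [batch, c_len, q_len]: (C * w3) Q^T realizes the (C ∘ Q) . W3 term
    part_cq = tf.matmul(C * w3, Q, transpose_b=True)
    # Broadcasting adds the three parts into the [batch, c_len, q_len] scores
    return part_c + part_q + part_cq
```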
Btw, I think the current inputs to the trilinear function take up a lot of memory, even more than the multi-head self-attention does.
Add early stopping: the current pipeline only saves the last five checkpoints, and the model overfits easily, so we can't recover the actual EM and F1 on the dev set (the logged scores are computed with a max length of 400, which makes them lower than the actual numbers).
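A rough sketch of the early-stopping logic, assuming the usual keep-the-best-checkpoint pattern; the `train_step`, `evaluate_dev`, and `save_checkpoint` callables are hypothetical placeholders, not the repo's real API:

```python
def train_with_early_stopping(train_step, evaluate_dev, save_checkpoint,
                              num_steps=60000, eval_every=1000, patience=3):
    """Stop training once dev F1 stops improving, keeping the best checkpoint."""
    best_f1, bad_evals = 0.0, 0
    for step in range(1, num_steps + 1):
        train_step(step)
        if step % eval_every == 0:
            em, f1 = evaluate_dev()            # dev-set EM / F1
            if f1 > best_f1:
                best_f1, bad_evals = f1, 0
                save_checkpoint(step)          # keep the best model, not only the last five
            else:
                bad_evals += 1
                if bad_evals >= patience:      # no improvement for `patience` evals
                    break
    return best_f1
```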