We propose RFA, a linear time and space attention mechanism that uses random feature methods to approximate the softmax function, and explore its application in transformers. Compared to existing efficient transformer variants, RFA is competitive in both accuracy and efficiency on three long text classification datasets.
RFA approximates the dot-then-exponentiate operation with a kernel trick, $\exp (\mathbf{x} \cdot \mathbf{y}) \approx \phi(\mathbf{x}) \cdot \phi(\mathbf{y})$, where $\phi$ is a random feature map.
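A minimal NumPy sketch of this kernel trick, using the paper's trigonometric random features (sines and cosines of Gaussian random projections). The sizes, seed, and the simplifications $\sigma = 1$ and unit-norm inputs are assumptions for illustration; under them $\exp(\mathbf{x} \cdot \mathbf{y}) = e \cdot \exp(-\|\mathbf{x}-\mathbf{y}\|^2/2)$, and the constant $e$ later cancels inside the softmax ratio.

```python
import numpy as np

# Trigonometric random feature map (sigma = 1 for simplicity):
# E[phi(x) . phi(y)] = exp(-||x - y||^2 / 2), so for unit-norm x, y
# exp(x . y) = e * exp(-||x - y||^2 / 2) ~ e * phi(x) . phi(y);
# the constant e cancels inside the softmax ratio below.
def phi(x, W):
    proj = W @ x                # projections onto w_i ~ N(0, I), shape (D,)
    return np.concatenate([np.sin(proj), np.cos(proj)]) / np.sqrt(len(proj))

rng = np.random.default_rng(0)  # illustrative seed and sizes
d, D = 64, 1024
W = rng.normal(size=(D, d))
x = rng.normal(size=d); x /= np.linalg.norm(x)
y = rng.normal(size=d); y /= np.linalg.norm(y)

print(np.exp(x @ y))                 # exact dot-then-exponentiate
print(np.e * phi(x, W) @ phi(y, W))  # random feature approximation
```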
$$
\begin{aligned}
\text{attn}\left(\mathbf{q}_t,\left\{\mathbf{k}_i\right\},\left\{\mathbf{v}_i\right\}\right) & =\sum_i \frac{\exp \left(\mathbf{q}_t \cdot \mathbf{k}_i / \sigma^2\right)}{\sum_j \exp \left(\mathbf{q}_t \cdot \mathbf{k}_j / \sigma^2\right)} \mathbf{v}_i^{\top} \\
& \approx \sum_i \frac{\phi\left(\mathbf{q}_t\right) \cdot \phi\left(\mathbf{k}_i\right)}{\sum_j \phi\left(\mathbf{q}_t\right) \cdot \phi\left(\mathbf{k}_j\right)} \mathbf{v}_i^{\top} \\
& =\frac{\phi\left(\mathbf{q}_t\right)^{\top} \sum_i \phi\left(\mathbf{k}_i\right) \otimes \mathbf{v}_i}{\phi\left(\mathbf{q}_t\right) \cdot \sum_j \phi\left(\mathbf{k}_j\right)}=\text{RFA}\left(\mathbf{q}_t,\left\{\mathbf{k}_i\right\},\left\{\mathbf{v}_i\right\}\right) .
\end{aligned}
$$
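To make the last line concrete, here is a minimal NumPy sketch of the non-causal case: the summaries $\sum_i \phi(\mathbf{k}_i) \otimes \mathbf{v}_i$ and $\sum_j \phi(\mathbf{k}_j)$ are built once in $O(n)$, after which each query costs $O(1)$ in the sequence length. The sizes and the L2-normalized queries and keys are illustrative assumptions, not the paper's exact training setup.

```python
import numpy as np

def phi(X, W):
    # Row-wise trigonometric random features with sigma = 1 (see sketch above).
    proj = X @ W.T                                            # (n, D)
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=-1) / np.sqrt(W.shape[0])

def rfa(Q, K, V, W):
    """RFA(q_t, {k_i}, {v_i}) for all t at once (non-causal case).

    Q, K: (n, d) with L2-normalized rows; V: (n, d_v); W: (D, d) Gaussian.
    """
    phi_q, phi_k = phi(Q, W), phi(K, W)                       # (n, 2D)
    S = phi_k.T @ V             # sum_i phi(k_i) (outer) v_i  -> (2D, d_v)
    z = phi_k.sum(axis=0)       # sum_j phi(k_j)              -> (2D,)
    return (phi_q @ S) / (phi_q @ z)[:, None]                 # (n, d_v)

rng = np.random.default_rng(0)                                # toy sizes, illustrative only
n, d, D = 128, 64, 512
Q = rng.normal(size=(n, d)); Q /= np.linalg.norm(Q, axis=1, keepdims=True)
K = rng.normal(size=(n, d)); K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.normal(size=(n, d))
W = rng.normal(size=(D, d))

A = np.exp(Q @ K.T)                                           # exact softmax attention, sigma = 1
exact = (A / A.sum(axis=-1, keepdims=True)) @ V
print(np.abs(exact - rfa(Q, K, V, W)).max())                  # error shrinks as D grows
```

For causal attention the paper replaces these full sums with prefix sums, $S_t = S_{t-1} + \phi(\mathbf{k}_t) \otimes \mathbf{v}_t$ and $z_t = z_{t-1} + \phi(\mathbf{k}_t)$, so RFA decodes like an RNN with constant-size state per step.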
https://arxiv.org/abs/2103.02143