NorbertZheng / read-papers

My paper reading notes.

ICLR '23 | Encoding recurrence into Transformers. #65

Closed NorbertZheng closed 1 year ago

NorbertZheng commented 1 year ago

Feiqing Huang, et al. Encoding Recurrence into Transformers.

NorbertZheng commented 1 year ago

Overview

This paper shows that an RNN layer with linear activation can be represented exactly as a multi-head self-attention whose relative positional encodings are recurrence encoding matrices (REMs), and proposes the Self-Attention with Recurrence (RSA) module, which uses a learnable gate to blend a REM with standard softmax attention.

image

NorbertZheng commented 1 year ago

Breaking Down an RNN Layer

Consider an RNN layer with input variables $\{x_{t}\in \mathbb{R}^{d_{in}}, 1\leq t\leq T\}$, which has the form $h_{t}=g(W_{h}h_{t-1}+W_{x}x_{t}+b)$, where $g(\cdot)$ is the activation function, $h_{t} \in \mathbb{R}^{d}$ is the output or hidden variable with $h_{0}=0$, $b\in \mathbb{R}^{d}$ is the bias term, and $W_{h}\in \mathbb{R}^{d\times d}$ and $W_{x}\in \mathbb{R}^{d\times d_{in}}$ are weights. When the activation function is linear, i.e. $g(x) = x$, the RNN becomes

$$ h_{t}=W_{h}h_{t-1}+W_{x}x_{t},\quad or\ equivalently\quad h_{t}=\sum_{j=0}^{t-1}W_{h}^{j}W_{x}x_{t-j}, $$

where the bias term $b$ is suppressed for simplicity. Although it has a feedforward form, the RNN cannot be trained in parallel, which is mainly caused by the recurrent weights $W_{h}$. This section block-diagonalizes $W_{h}$ so that the RNN at (1) can be broken down into a sequence of simple RNNs with scalar (hidden) coefficients.
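A quick numerical sanity check of the unrolled form above; the sizes and random weights below are made up purely for illustration.

```python
import numpy as np

# Minimal check (hypothetical sizes) that the linear RNN h_t = W_h h_{t-1} + W_x x_t
# matches its unrolled form h_t = sum_{j=0}^{t-1} W_h^j W_x x_{t-j}.
rng = np.random.default_rng(0)
d, d_in, T = 4, 3, 6
W_h = rng.normal(scale=0.3, size=(d, d))      # keep the spectral radius small
W_x = rng.normal(size=(d, d_in))
x = rng.normal(size=(T, d_in))

h, h_rec = np.zeros(d), []
for t in range(T):                            # recurrent computation, h_0 = 0, bias dropped
    h = W_h @ h + W_x @ x[t]
    h_rec.append(h)

h_unrolled = [
    sum(np.linalg.matrix_power(W_h, j) @ W_x @ x[t - j] for j in range(t + 1))
    for t in range(T)
]
assert np.allclose(h_rec, h_unrolled)
```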

image

NorbertZheng commented 1 year ago

Lemma 1 (Theorem 1 in Hartfiel (1995)). Real matrices with $R$ distinct nonzero eigenvalues are dense in the set of all $d\times d$ real matrices with rank at most $R$, where $0 < R\leq d$.

Suppose that the weight matrix $W_{h}$ has rank $R\leq d$. By Lemma 1, without much loss of generality, we can assume that the nonzero eigenvalues of $W_{h}$ are all distinct. Specifically, $W_{h}$ has $r$ distinct nonzero real eigenvalues $\{\lambda_{k},1\leq k\leq r\}$ and $s$ pairs of complex conjugate eigenvalues $\{\gamma_{k}e^{\pm i\theta_{k}},1\leq k\leq s\}$, where $r+2s=R$.

As a result, from Horn and Johnson (2012), we have the Jordan decomposition in real form,

$$ W_{h}=BJB^{-1}, $$

where $B\in \mathbb{R}^{d\times d}$ is invertible and $J\in \mathbb{R}^{d\times d}$ is a block diagonal matrix. It holds that

$$ (W_{h})^{j}=BJ^{j}B^{-1},\forall j\geq 1, $$

and we can then break down the recurrence induced by $W_{h}$ into a sum of recurrences over the $p\times p$ blocks in $J$, with $p=1$ or $2$. Similar to (1), we define three types of RNNs with linear activation below,

$$ \begin{aligned} h_{R;t}(\lambda)&=\sum_{j=1}^{t-1}\lambda^{j}\cdot W_{R;x}\cdot x_{t-j},\\ h_{C_{1};t}(\gamma,\theta)&=\sum_{j=1}^{t-1}\gamma^{j}\cos(j\theta)\cdot W_{C_{1};x}\cdot x_{t-j},\\ h_{C_{2};t}(\gamma,\theta)&=\sum_{j=1}^{t-1}\gamma^{j}\sin(j\theta)\cdot W_{C_{2};x}\cdot x_{t-j}, \end{aligned} $$

where $W_{R;x}=G_{R}W_{x}$, $W_{C_{1};x}=G_{C_{1}}W_{x}$ and $W_{C_{2};x}=G_{C_{2}}W_{x}$ are head-specific input weight matrices determined by the Jordan decomposition of $W_{h}$ (see the proof of Proposition 1 below).

Note that each of the three RNNs has recurrent weight $\lambda$ or $(\gamma,\theta)$, and its form with a nonlinear activation function is given in the Appendix.
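A small sketch of the three simple RNNs as scalar-kernel aggregations; `simple_rnn` is my own helper and the parameter values are made up. The only difference between the three heads is the decay kernel applied to the lags.

```python
import numpy as np

def simple_rnn(X, W_head, kernel):
    """h_t = sum_{j=1}^{t-1} kernel(j) * (W_head @ x_{t-j}), stacked over t = 1..T."""
    T = X.shape[0]
    H = np.zeros((T, W_head.shape[0]))
    for i in range(T):                  # i corresponds to t = i + 1
        for j in range(1, i + 1):       # lags j = 1, ..., t-1
            H[i] += kernel(j) * (W_head @ X[i - j])
    return H

lam, gamma, theta = 0.9, 0.8, np.pi / 6  # illustrative values
h_R  = lambda X, W: simple_rnn(X, W, lambda j: lam ** j)                        # regular
h_C1 = lambda X, W: simple_rnn(X, W, lambda j: gamma ** j * np.cos(j * theta))  # cyclical 1
h_C2 = lambda X, W: simple_rnn(X, W, lambda j: gamma ** j * np.sin(j * theta))  # cyclical 2
```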

NorbertZheng commented 1 year ago

Proposition 1. Suppose that $W_{h}$ has rank $R=r+2s\leq d$, and its nonzero eigenvalues are as defined above. Let $h_{0,t}=W_{x}x_{t}$, and then the RNN with linear activation at (1) can be equivalently rewritten as

$$ h_{t}=\sum_{k=1}^{r}h_{R;t}(\lambda_{k})+\sum_{k=1}^{s}h_{C_{1};t}(\gamma_{k},\theta_{k})+\sum_{k=1}^{s}h_{C_{2};t}(\gamma_{k},\theta_{k})+h_{0,t}. $$

Proof for Proposition 1. Let $W_{h}$ be a $d\times d$ real matrix with distinct nonzero eigenvalues, and from Chapter 3 of Horn and Johnson (2012), we have the Jordan decomposition, $W_{h}=BJB^{-1}$, where $B\in \mathbb{R}^{d\times d}$ is invertible, and $J\in \mathbb{R}^{d\times d}$ has the real Jordan form $J=diag\{\lambda_{1},...,\lambda_{r},f_{C}(\gamma_{1},\theta_{1}),...,f_{C}(\gamma_{s},\theta_{s}),0\}$ with

$$ f_{C}(\gamma_{k},\theta_{k})=\gamma_{k}\cdot \left(\begin{array}{cc} \cos(\theta_{k}) & \sin(\theta_{k})\\ -\sin(\theta_{k}) & \cos(\theta_{k}) \end{array}\right) \in \mathbb{R}^{2\times 2},\quad 1\leq k\leq s. $$

Then,

$$ (W_{h})^{j}=BJ^{j}B^{-1}=\sum_{k=1}^{r}\lambda_{k}^{j}G_{R;k}+\sum_{k=1}^{s}\gamma_{k}^{j}\{\cos(j\theta_{k})G_{C_{1};k}+\sin(j\theta_{k})G_{C_{2};k}\},\quad \forall j\geq 1, $$

where the $G_{R;k}$'s, $G_{C_{1};k}$'s and $G_{C_{2};k}$'s are $d\times d$ real matrices determined jointly by $B$ and $B^{-1}$. Let $h_{0,t}=W_{x}x_{t}$ and then,

$$ \begin{aligned} h_{t}&=\sum_{j=0}^{t-1}(W_{h})^{j}W_{x}x_{t-j}\\ &=\sum_{j=1}^{t-1}\sum_{k=1}^{r}\lambda_{k}^{j}\cdot G_{R;k}W_{x}\cdot x_{t-j}+\sum_{j=1}^{t-1}\sum_{k=1}^{s}\gamma_{k}^{j}[\cos(j\theta_{k})\cdot G_{C_{1};k}+\sin(j\theta_{k})\cdot G_{C_{2};k}]W_{x}\cdot x_{t-j}+W_{x}x_{t}\\ &=\sum_{k=1}^{r}h_{R;t}(\lambda_{k})+\sum_{k=1}^{s}[h_{C_{1};t}(\gamma_{k},\theta_{k})+h_{C_{2};t}(\gamma_{k},\theta_{k})]+h_{0,t}, \end{aligned} $$

where

$$ \begin{aligned} &h_{R;t}(\lambda_{k})=\sum_{j=1}^{t-1}\lambda_{k}^{j}\cdot G_{R;k}W_{x}\cdot x_{t-j},\quad or\ equivalently,\\ &h_{R;t}(\lambda_{k})=g(\lambda_{k}\cdot h_{R;t-1}(\lambda_{k})+\lambda_{k}\cdot G_{R;k}W_{x}\cdot x_{t-1}), \end{aligned} $$

$g(\cdot)$ being the identity function, i.e. $g(x)=x$, and we also have:

$$ \begin{aligned} &h_{C_{1};t}(\gamma_{k},\theta_{k})=\sum_{j=1}^{t-1}\gamma_{k}^{j}\cos(j\theta_{k})\cdot G_{C_{1};k}W_{x}\cdot x_{t-j},\\ &h_{C_{2};t}(\gamma_{k},\theta_{k})=\sum_{j=1}^{t-1}\gamma_{k}^{j}\sin(j\theta_{k})\cdot G_{C_{2};k}W_{x}\cdot x_{t-j},\quad or\ equivalently,\\ &h_{C_{1};t}(\gamma_{k},\theta_{k})=g(\gamma_{k}\cos(\theta_{k})\cdot h_{C_{1};t-1}(\gamma_{k},\theta_{k})+\gamma_{k}\cos(\theta_{k})\cdot G_{C_{1};k}W_{x}\cdot x_{t-1}-\gamma_{k}\sin(\theta_{k})\cdot h_{C_{2};t-1}(\gamma_{k},\theta_{k})),\\ &h_{C_{2};t}(\gamma_{k},\theta_{k})=g(\gamma_{k}\cos(\theta_{k})\cdot h_{C_{2};t-1}(\gamma_{k},\theta_{k})+\gamma_{k}\sin(\theta_{k})\cdot G_{C_{2};k}W_{x}\cdot x_{t-1}+\gamma_{k}\sin(\theta_{k})\cdot h_{C_{1};t-1}(\gamma_{k},\theta_{k})), \end{aligned} $$

For a more general form, we can further assume that the activation function $g(\cdot)$ is nonlinear for the simple RNNs at (11) and (12).
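To convince myself of the key step $(W_{h})^{j}=\sum_{k}\lambda_{k}^{j}G_{R;k}+\sum_{k}\gamma_{k}^{j}\{\cos(j\theta_{k})G_{C_{1};k}+\sin(j\theta_{k})G_{C_{2};k}\}$, here is a numerical sketch with one real eigenvalue and one complex pair; the explicit $G$ formulas below are my own reading of "determined jointly by $B$ and $B^{-1}$".

```python
import numpy as np

# Build W_h = B J B^{-1} with one real eigenvalue and one complex-conjugate pair,
# then check the power decomposition used in the proof of Proposition 1.
rng = np.random.default_rng(1)
lam, gamma, theta = 0.8, 0.7, np.pi / 5
J = np.zeros((3, 3))
J[0, 0] = lam
J[1:, 1:] = gamma * np.array([[np.cos(theta),  np.sin(theta)],
                              [-np.sin(theta), np.cos(theta)]])
B = rng.normal(size=(3, 3))
B_inv = np.linalg.inv(B)
W_h = B @ J @ B_inv

# G matrices built from the columns of B and the rows of B^{-1} (assumed form).
G_R  = np.outer(B[:, 0], B_inv[0])
G_C1 = np.outer(B[:, 1], B_inv[1]) + np.outer(B[:, 2], B_inv[2])
G_C2 = np.outer(B[:, 1], B_inv[2]) - np.outer(B[:, 2], B_inv[1])

for j in range(1, 8):
    lhs = np.linalg.matrix_power(W_h, j)
    rhs = lam**j * G_R + gamma**j * (np.cos(j * theta) * G_C1 + np.sin(j * theta) * G_C2)
    assert np.allclose(lhs, rhs)
```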

NorbertZheng commented 1 year ago

An Equivalent MHSA Representation

Consider the RNN of $\{h_{R;t}\}$, and let $X=\{x_{1},...,x_{T}\}'\in \mathbb{R}^{T\times d_{in}}$ be an input matrix consisting of $T$ tokens with dimension $d_{in}$, where the transpose of a matrix $A$ is denoted by $A'$ throughout this paper. We first give the value matrix $V$ by projecting $X$ with a linear transformation, i.e. $V=XW_{V}$ with $W_{V}=W_{R;x}=G_{R}W_{x}\in \mathbb{R}^{d_{in}\times d}$, and the relative positional encoding matrix is set to

$$ P_{R;mask}(\lambda)=\left(\begin{array}{ccccc} 0 & 0 & 0 & \cdots & 0\\ f_{1}(\lambda) & 0 & 0 & \cdots & 0\\ f_{2}(\lambda) & f_{1}(\lambda) & 0 & \cdots & 0\\ \vdots & \vdots & \vdots & \ddots & \vdots\\ f_{T-1}(\lambda) & f_{T-2}(\lambda) & f_{T-3}(\lambda) & \cdots & 0 \end{array}\right), $$

where $f_{t}(\lambda)=\lambda^{t}$ for $1\leq t\leq T-1$. As a result, the first RNN at (2) can be represented as a self-attention (SA) form,

$$ (h_{R;1},...,h_{R;T})'=SA_{R}(X)=[softmax(QK')+P_{R;mask}(\lambda)]V, $$

where $Q$ and $K$ are zero query and key matrices, respectively. We call $P_{R;mask}(\lambda)$ the recurrence encoding matrix (REM), since it summarizes all the recurrence in $\{h_{R;t}\}$. However, if $Q$ and $K$ are both zero matrices, $softmax(QK')$ results in a uniform matrix, so e.g. $h_{R;1}$ would contain information from tokens after $x_{1}$, which is incorrect.

For the RNN of $\{h_{C_{1};t}\}$, the REM is denoted by $P_{C_{1};mask}(\gamma,\theta)$, which has the form of (3) with $f_{t}(\lambda)$ replaced by $f_{t}(\gamma,\theta)=\gamma^{t}\cos(t\theta)$ for $1\leq t\leq T-1$, and the value matrix has the form $V=XW_{V}$ with $W_{V}=W_{C_{1};x}=G_{C_{1}}W_{x}\in \mathbb{R}^{d_{in}\times d}$. Similarly, for the RNN of $\{h_{C_{2};t}\}$, the REM, $P_{C_{2};mask}(\gamma,\theta)$, has the form of (3) with $f_{t}(\lambda)$ replaced by $f_{t}(\gamma,\theta)=\gamma^{t}\sin(t\theta)$ for $1\leq t\leq T-1$, and the value matrix is defined as $V=XW_{V}$ with $W_{V}=W_{C_{2};x}=G_{C_{2}}W_{x}\in \mathbb{R}^{d_{in}\times d}$. Thus, these two RNNs at (2) can also be represented in SA form,

$$ (h_{C_{i};1},...,h_{C_{i};T})'=SA_{C_{i}}(X)=[softmax(QK')+P_{C_{i};mask}(\gamma,\theta)]V,\quad with\ i=1\ or\ 2, $$

where query and key matrices $Q$ and $K$ are both zero.

Finally, for the remaining term in Proposition 1, $h_{0,t}$ depends on $x_{t}$ only, and there is no inter-dependence involved. Mathematically, we can represent it as an SA with the identity relative positional encoding matrix and zero query and key matrices.
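A small check that $P_{R;mask}(\lambda)V$ alone reproduces $\{h_{R;t}\}$; the softmax term is dropped here on purpose, which sidesteps the uniform-matrix concern raised above. Sizes and weights are made up.

```python
import numpy as np

def rem_R_mask(lam, T):
    """Regular masked REM: strictly lower-triangular Toeplitz with entries f_t(lam) = lam^t."""
    P = np.zeros((T, T))
    for i in range(T):
        for j in range(i):
            P[i, j] = lam ** (i - j)
    return P

rng = np.random.default_rng(2)
T, d_in, d, lam = 7, 3, 4, 0.9
W_V = rng.normal(size=(d_in, d))          # stands in for W_V = W_{R;x} = G_R W_x
X = rng.normal(size=(T, d_in))
V = X @ W_V

H = rem_R_mask(lam, T) @ V                # rows are h_{R;1}, ..., h_{R;T}

H_direct = np.zeros((T, d))               # h_{R;t} = sum_{j=1}^{t-1} lam^j W_{R;x} x_{t-j}
for i in range(T):
    for j in range(1, i + 1):
        H_direct[i] += lam**j * (W_V.T @ X[i - j])
assert np.allclose(H, H_direct)
```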

NorbertZheng commented 1 year ago

Proposition 2. If the conditions of Proposition 1 hold, then the RNN with linear activation at (1) can be represented as a multi-head self-attention (MHSA) with $r+2s+1$ heads, where the query and key matrices are zero, and the relative positional encoding matrices are $\{P_{R;mask}(\lambda_{k}),1\leq k\leq r\}$, $\{P_{C_{1};mask}(\gamma_{k},\theta_{k}),P_{C_{2};mask}(\gamma_{k},\theta_{k}),1\leq k\leq s\}$ and an identity matrix, respectively.

Proof for Proposition 2. Using the SA form, we can represent the three types of RNNs by

$$ \begin{aligned} &(h_{R;1}(\lambda_{k}),...,h_{R;T}(\lambda_{k}))'=SA_{R;k}(X),\\ &(h_{C_{i};1}(\gamma_{k},\theta_{k}),...,h_{C_{i};T}(\gamma_{k},\theta_{k}))'=SA_{C_{i};k}(X)\quad for\ i=1\ or\ 2. \end{aligned} $$

Therefore, the first term in Proposition 1 can be represented as

$$ MHSA(X)=concat[SA_{R;1}(X),...,SA_{R;r}(X)]W_{o}=\left(\sum_{k=1}^{r}h_{R;1}(\lambda_{k}),...,\sum_{k=1}^{r}h_{R;T}(\lambda_{k})\right)', $$

where $W_{o}=(I_{d},...,I_{d})'\in \mathbb{R}^{rd\times d}$ with $I_{d}$ being the $d$-dimensional identity matrix. Similarly, the MHSA for the second and third terms in Proposition 1 is given by

$$ MHSA(X)=concat[SA_{C_{i};1}(X),...,SA_{C_{i};s}(X)]W_{o}=\left(\sum_{k=1}^{s}h_{C_{i};1}(\gamma_{k},\theta_{k}),...,\sum_{k=1}^{s}h_{C_{i};T}(\gamma_{k},\theta_{k})\right)', $$

where $W_{o}=(I_{d},...,I_{d})'\in \mathbb{R}^{sd\times d}$.

We define the additional head as $SA_{0}(X)=(h_{0,1},...,h_{0,T})'=[softmax(QK')+I]V$ with $W_{Q}=W_{K}=0$ and $W_{V}=W_{x}'$. Combining this with (13) and (14), we have

$$ MHSA(X)=concat[\{SA_{R;k}(X)\}_{1\leq k\leq r},\{SA_{C_{1};k}(X),SA_{C_{2};k}(X)\}_{1\leq k\leq s},SA_{0}(X)]W_{o}, $$

where $W_{o}=(I_{d},...,I_{d})'\in \mathbb{R}^{(r+2s+1)d\times d}$.
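The role of $W_{o}=(I_{d},...,I_{d})'$ is simply to sum the heads; a short check with stand-in head outputs:

```python
import numpy as np

rng = np.random.default_rng(3)
T, d, n_heads = 5, 4, 3
heads = [rng.normal(size=(T, d)) for _ in range(n_heads)]  # stand-ins for SA_{.;k}(X)
W_o = np.vstack([np.eye(d)] * n_heads)                     # (n_heads*d, d) stacked identities
assert np.allclose(np.concatenate(heads, axis=1) @ W_o, sum(heads))
```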

NorbertZheng commented 1 year ago

The three simple RNNs at (2) provide different temporal decay patterns:

image

From Proposition 2, the combination of these three types of patterns forms the recurrent dynamics of the RNN layer at (1). For each head, the REM has one or two parameters, and $W_{V}$ can be regarded as one $d\times d$ learnable matrix. This leads to a parameter complexity of $O(Rd^{2})$ (but $G_{R}$ only has rank $1$, i.e. $W_{R;x}$ has rank $1$, so maybe $O(Rd)$?), which is slightly larger than that of the RNN at (1), namely $O(d^{2})$, since $R$ is usually much smaller than $d$ (Prabhavalkar et al., 2016). Moreover, the MHSA representation in Proposition 2 gives us a chance to make use of parallel matrix computation on GPU hardware; see Appendix D.3 for an illustration of the computational efficiency of REMs.
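For intuition about the three temporal decay patterns in the figure above, the kernels can be tabulated directly (illustrative parameter values):

```python
import numpy as np

lam, gamma, theta = 0.9, 0.9, np.pi / 6
t = np.arange(1, 21)
regular    = lam ** t                        # f_t(lambda) = lambda^t: monotone geometric decay
cyclical_1 = gamma ** t * np.cos(t * theta)  # gamma^t cos(t*theta): damped cosine oscillation
cyclical_2 = gamma ** t * np.sin(t * theta)  # gamma^t sin(t*theta): damped sine oscillation
print(np.round(np.stack([regular, cyclical_1, cyclical_2]), 3))
```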

NorbertZheng commented 1 year ago

Theoretical and Empirical Gaps between linear and nonlinear RNNs

Theoretical Gap

This subsection theoretically evaluates the prediction error when a linear RNN model is fitted to data generated by a nonlinear RNN. For simplicity, we consider the 1D case only, and many-to-many RNNs are assumed.

Specifically, the nonlinear RNN model used to generate the data has the recursive form of

$$ g_{t}=\sigma_{h}(u_{t})\quad with\ u_{t}=w_{h}^{*}g_{t-1}+w_{x}^{*}x_{t}+b^{*}, $$

and $\sigma_{h}(\cdot)$ is the activation function satisfying

$$ |\sigma_{h}(0)|<1,\quad \sigma_{h}'(0)=1,\quad |\sigma_{h}''(x)|\leq 1\ for\ any\ x\in \mathbb{R}. $$

Note that many commonly used activation functions, including Tanh and Sigmoid, satisfy the above condition. We further consider an additive error $\epsilon_{t}$, i.e. $y_{t}=g_{t}+\epsilon_{t}$, where $\epsilon_{t}$ has mean zero and a finite variance denoted by $\gamma$.

For the generated data $\{y_{i}\}$, we train a linear RNN model,

$$ h_{t}(\theta)=w_{h}h_{t-1}(\theta)+w_{x}x_{t}+b, $$

where the parameter vector is $\theta=(w_{h},w_{x},b)$. The mean squared prediction error can then be defined as

$$ e_{pred}:=\min_{\theta}\mathbb{E}(y_{t}-h_{t}(\theta))^{2}, $$

and its theoretical bound is provided in the following proposition.

NorbertZheng commented 1 year ago

Proposition 3. Suppose that $\mathbb{E}(u_{t}^{2})\leq \alpha$ and $\mathbb{E}(u_{t}^{2}u_{s}^{2})\leq \beta$ for all $t,s\in \mathbb{Z}$. If $|w_{h}^{*}|<1$ and the condition at (6) holds, then

$$ e_{pred}\leq \underbrace{(1-|w_{h}^{*}|)^{-2}\left(1+\alpha+\frac{\beta}{4}\right)}_{misspecification\ error}+\underbrace{\gamma}_{irreducible\ system\ error}, $$

where the first part is due to the misspecification of using the linear activation to approximate $\sigma_{h}(\cdot)$.

Proof for Proposition 3. Let $\theta^{*}=(w_{h}^{*},w_{x}^{*},b^{*})$, and denote $h_{t}=h_{t}(\theta^{*})$ for all $t\in \mathbb{Z}$, i.e.

$$ h_{t}=w_{h}^{*}h_{t-1}+w_{x}^{*}x_{t}+b^{*}. $$

By the definition of $e_{pred}$, it holds that

$$ e_{pred}\leq \mathbb{E}(y_{t}-h_{t})^{2}=\mathbb{E}(g_{t}-h_{t})^{2}+\gamma, $$

where the equality comes from $\mathbb{E}(\epsilon_{t})=0$ and $var(\epsilon_{t})=\gamma$. By applying a second-order Taylor expansion at zero and from (5) and (6), we have

$$ g_{t}=\sigma_{h}(0)+u_{t}+R_{t}(0),\quad where\ |R_{t}(0)|=\left|\frac{1}{2}\sigma_{h}''(\tilde{u})u_{t}^{2}\right|\leq\frac{u_{t}^{2}}{2}, $$

and $\tilde{u}$ lies between $u_{t}$ and zero. This, together with (7), leads to $g_{t}-h_{t}=\sigma_{h}(0)+w_{h}^{*}(g_{t-1}-h_{t-1})+R_{t}(0)$.

Let $\delta_{t}=g_{t}-h_{t}$, and it then holds that

$$ \delta_{t}=\sigma_{h}(0)+w_{h}^{*}\delta_{t-1}+R_{t}(0)=\sum_{j=0}^{\infty}(w_{h}^{*})^{j}\sigma_{h}(0)+\sum_{j=0}^{\infty}(w_{h}^{*})^{j}R_{t-j}(0), $$

where the second equality is obtained by applying the first equality recursively. As a result, by the conditions that $|\sigma_{h}(0)|<1$, $\mathbb{E}(u_{t}^{2})\leq \alpha$ and $\mathbb{E}(u_{t}^{2}u_{s}^{2})\leq \beta$ for all $t,s\in \mathbb{Z}$, we can show that

$$ \begin{aligned} \mathbb{E}(\delta_{t}^{2})&\leq \xi^{2}+2\xi\,\mathbb{E}\left(\sum_{j=0}^{\infty}|w_{h}^{*}|^{j}|R_{t-j}(0)|\right)+\mathbb{E}\left[\left(\sum_{j=0}^{\infty}|w_{h}^{*}|^{j}|R_{t-j}(0)|\right)^{2}\right]\\ &\leq \xi^{2}\left(1+\alpha+\frac{\beta}{4}\right), \end{aligned} $$

where $\xi=\sum_{j=0}^{\infty}|w_{h}^{*}|^{j}$. If $|w_{h}^{*}|<1$, we have $\xi=(1-|w_{h}^{*}|)^{-1}$. This, together with (8), completes the proof.
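A quick numerical check of the Taylor-remainder bound used above, for $\sigma_{h}=\tanh$ (which satisfies condition (6)):

```python
import numpy as np

# |sigma_h(u) - sigma_h(0) - sigma_h'(0) u| <= u^2 / 2 whenever |sigma_h''| <= 1.
u = np.linspace(-5, 5, 10001)
remainder = np.abs(np.tanh(u) - np.tanh(0) - u)
assert np.all(remainder <= u**2 / 2 + 1e-12)
```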

NorbertZheng commented 1 year ago

Proposition 4 for Dilated RNNs. For some positive integer $d$, let $P_{R;mask;T}$ be the block matrix formed by the first $T$ columns and the first $T$ rows of $P_{R;mask}\otimes I_{d}$, and $P_{C_{i};mask;T}$ can be defined similarly for $i=1$ or $2$. Consider a dilated RNN (Chang et al., 2017) with dilating factor $d$. It has the form $h_{t}=g(W_{h}h_{t-d}+W_{x}x_{t}+b)$, where $g(\cdot)$ is the activation function, $h_{t}\in \mathbb{R}^{d}$ is the output or hidden variable with $h_{0}=0$, $b\in \mathbb{R}^{d}$ is the bias term, and $W_{h}\in \mathbb{R}^{d\times d}$ and $W_{x}\in \mathbb{R}^{d\times d_{in}}$ are weights. When the activation function is linear, i.e. $g(x)=x$, the RNN becomes

$$ h_{t}=W_{h}h_{t-d}+W_{x}x_{t},\quad or\ h_{t}=\sum_{j=0}^{t-1}(W_{h})^{j}W_{x}\cdot x_{t-j\cdot d}, $$

where the bias term $b$ is suppressed for simplicity. We have the following proposition.

Proposition 4. If the conditions of Proposition 1 hold for $W_{h}$ in (10), then the RNN with linear activation at (10) can be represented as a multi-head self-attention (MHSA) with $r+2s+1$ heads, where the query and key matrices are zero, and the relative positional encoding matrices are $\{P_{R;mask;T}(\lambda_{k}),1\leq k\leq r\}$, $\{P_{C_{1};mask;T}(\gamma_{k},\theta_{k}),P_{C_{2};mask;T}(\gamma_{k},\theta_{k}),1\leq k\leq s\}$ and an identity matrix, respectively.

Proof for Proposition 4. The proof follows directly from Propositions 1 and 2.
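A sketch of the dilated REM construction (first $T$ rows and columns of $P\otimes I_{d}$); `dilated_rem` is my own helper name.

```python
import numpy as np

def dilated_rem(P, dilation, T):
    """First T rows/columns of P ⊗ I_dilation: nonzero weights only at lags that are
    multiples of `dilation`, matching the dilated RNN at (10)."""
    return np.kron(P, np.eye(dilation))[:T, :T]

# Example: a 4x4 regular masked REM with lambda = 0.5, dilated by a factor of 2.
lam, T0 = 0.5, 4
P = np.array([[lam ** (i - j) if j < i else 0.0 for j in range(T0)] for i in range(T0)])
print(np.round(dilated_rem(P, dilation=2, T=6), 3))
```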

NorbertZheng commented 1 year ago

Empirical Gap

This subsection conducts a synthetic experiment to evaluate the performance of linear RNNs when there is nonlinearity in the data.

We first generate the data by using a two-layer nonlinear RNN model with the form of

$$ z_{t}^{(i)}=\alpha g_{t}^{(i)}+(1-\alpha)h_{t}^{(i)}, $$

where

$$ \begin{aligned} h_{t}^{(i)}&=W_{h}^{(i)}z_{t-1}^{(i)}+W_{z}^{(i)}z_{t}^{(i-1)}+b^{(i)},\\ g_{t}^{(i)}&=\sigma_{h}(W_{h}^{(i)}z_{t-1}^{(i)}+W_{z}^{(i)}z_{t}^{(i-1)}+b^{(i)}), \end{aligned} $$

with $i=1$ or $2$, where $z_{t}^{(0)}=x_{t}$, $z_{t}^{(i)}\in \mathbb{R}^{2}$ for $1\leq i\leq 2$, $\sigma_{h}(\cdot)$ is a nonlinear activation function, and $0\leq \alpha \leq 1$ is the weight of nonlinearity. An additive error is further assumed, i.e.

$$ y_{t}=z_{t}^{(2)}+\epsilon_{t}, $$

where $\{x_{t}\}$ and $\{\epsilon_{t}\}$ are independent and follow the standard multivariate normal distribution. Three nonlinear functions are considered for $\sigma_{h}(\cdot)$: Tanh, Sigmoid and ReLU. As $\alpha$ increases from 0 to 1, the data-generating process gradually changes from a strictly linear RNN to a nonlinear one, i.e. $\alpha$ essentially controls the proportion of nonlinearity involved.
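A sketch of the data-generating process, assuming $d_{in}=2$ so that all dimensions match; the scales of the randomly drawn weights are my own choice to keep the sequence stable.

```python
import numpy as np

def generate_sequence(T, alpha, act=np.tanh, d=2, seed=0):
    """Two-layer generator: z_t^(i) = alpha * g_t^(i) + (1 - alpha) * h_t^(i),
    y_t = z_t^(2) + eps_t, with z_t^(0) = x_t."""
    rng = np.random.default_rng(seed)
    W_h = [rng.normal(scale=0.5, size=(d, d)) for _ in range(2)]
    W_z = [rng.normal(scale=0.5, size=(d, d)) for _ in range(2)]
    b   = [rng.normal(scale=0.1, size=d) for _ in range(2)]
    x, eps = rng.standard_normal((T, d)), rng.standard_normal((T, d))
    z = [np.zeros(d), np.zeros(d)]           # z^(1), z^(2) at time t-1
    y = np.zeros((T, d))
    for t in range(T):
        lower = x[t]                         # z_t^(0) = x_t
        for i in range(2):
            pre = W_h[i] @ z[i] + W_z[i] @ lower + b[i]
            z[i] = alpha * act(pre) + (1 - alpha) * pre   # blend nonlinear g and linear h
            lower = z[i]
        y[t] = z[1] + eps[t]
    return x, y

x, y = generate_sequence(T=10000, alpha=0.6)
```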

The sequence $\{y_{t},1\leq t\leq T\}$ is fitted separately by a linear RNN, a nonlinear RNN and the proposed RSA module.

Specifically, we generate a sequence of length 10000 and then divide it into 100 segments, each of length 100. In each segment, we train on the first 99 observations and calculate the prediction error for the last observation. The Adam optimizer is adopted for training, and the training procedure is terminated when the training loss drops by less than $10^{-5}$. The mean squared prediction errors (MSPE) averaged over the 100 segments are denoted by $e_{pred}^{L}$, $e_{pred}^{NL}$ and $e_{pred}^{RSA}$ for the three models, respectively. Using nonlinear RNNs as the benchmark, the MSPE ratio for the linear RNN or the RSA is defined as

$$ MSPE\ ratio\ for\ model\ i=\frac{e_{pred}^{i}}{e_{pred}^{NL}},\quad where\ i\in \{L,RSA\}. $$

Figure 6 presents the MSPE ratios for the three types of activation functions. It can be seen that, when $\alpha=1$, the nonlinear RNN performs best, while the linear RNN suffers from misspecification error. Alternatively, when $\alpha=0$, the opposite is observed. Moreover, as $\alpha$ increases, i.e. as more nonlinearity is involved, linear RNNs become less favorable as expected, while the proposed RSA can remedy the problem to some extent. In particular, when $\alpha>0.6$, the RSA consistently achieves better prediction performance than the pure linear RNN.

image

NorbertZheng commented 1 year ago

Encoding Recurrence into Self-Attention

While the query and key matrices are set to zero in the MHSA representation at Proposition 2, they play a central role in a standard Transformer. This motivates us to propose the Self-Attention with Recurrence (RSA) module to seamlessly combine the strengths of RNNs and Transformers:

$$ RSA(X)=\{[1-\sigma(\mu)]softmax(QK')+\sigma(\mu)P\}V, $$

for each head, where $P$ is a regular or cyclical REM, and $\sigma(\mu)\in[0,1]$ is a gate with $\sigma$ being the sigmoid function and $\mu$ being the learnable gate-control parameter. Figure 1(d) provides a graphical illustration of one RSA head. image
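A minimal sketch of one RSA head as in (4); the helper names are mine, and the usual $1/\sqrt{d}$ scaling and causal masking of the softmax term are omitted for brevity.

```python
import numpy as np

def softmax(A, axis=-1):
    A = A - A.max(axis=axis, keepdims=True)
    e = np.exp(A)
    return e / e.sum(axis=axis, keepdims=True)

def rsa_head(X, W_Q, W_K, W_V, P, mu):
    """RSA(X) = {(1 - sigma(mu)) * softmax(Q K') + sigma(mu) * P} V for one head."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    gate = 1.0 / (1.0 + np.exp(-mu))                # sigma(mu), the learnable gate
    return ((1.0 - gate) * softmax(Q @ K.T) + gate * P) @ V
```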

Note that the REMs in Section 2 are all lower triangular matrices, which correspond to unidirectional RNNs. Meanwhile, for non-causal sequential learning tasks (Graves et al., 2005; Zhou et al., 2016), bidirectional RNNs are usually applied, and accordingly we can define the unmasked versions of REMs. Specifically, the regular REM is

$$ P{R;unmask}(\lambda)=P{R;mask}(\lambda)+[P_{R;mask}(\lambda)]', $$

and the cyclical REMs are

$$ P_{C_{i};unmask}(\gamma,\theta)=P_{C_{i};mask}(\gamma,\theta)+[P_{C_{i};mask}(\gamma,\theta)]',\quad with\ i=1\ and\ 2. $$

In practice, these REMs will be explosive when $|\lambda|>1$ or $|\gamma|>1$. To avoid this problem, we further bound these two parameters by transformations in this paper, i.e.

$$ \begin{aligned} \lambda&=tanh(\eta),\\ \gamma&=sigmoid(\nu), \end{aligned} $$

and these notations then become

$$ P_{R;k}(\eta),\ P_{C_{1};k}(\nu,\theta),\ P_{C_{2};k}(\nu,\theta),\quad with\ k\in \{mask,unmask\}. $$
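Putting the pieces together, a sketch of how a REM could be built from the bounded parameters; `rem` is my own helper, and the masked/unmasked switch follows the definitions above.

```python
import numpy as np

def rem(T, flavor="R", eta=None, nu=None, theta=None, masked=True):
    """REM with bounded parameters: lambda = tanh(eta) for the regular REM,
    gamma = sigmoid(nu) for the cyclical ones."""
    t = np.arange(1, T)                                     # lags 1, ..., T-1
    if flavor == "R":
        f = np.tanh(eta) ** t
    else:
        gamma = 1.0 / (1.0 + np.exp(-nu))
        f = gamma ** t * (np.cos(t * theta) if flavor == "C1" else np.sin(t * theta))
    P = np.zeros((T, T))
    for lag in range(1, T):
        P += np.diag(np.full(T - lag, f[lag - 1]), k=-lag)  # lag-th subdiagonal = f_lag
    return P if masked else P + P.T                         # unmasked REM adds the transpose

P_masked = rem(T=8, flavor="C1", nu=1.5, theta=np.pi / 4, masked=True)
```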

NorbertZheng commented 1 year ago

The learnable gate $\sigma(\mu)$ is used to measure the proportion or strength of recurrent patterns. When the sample size is relatively small, it is also possible to use REMs to approximate a part of non-recurrent patterns to obtain a better bias-variance trade-off, leading to a higher value of $\sigma(\mu)$.

For the multihead RSA, the gate-control parameter $\mu$ only varies across layers, while the parameters controlling the matrix $P$ vary across all heads and layers.

NorbertZheng commented 1 year ago

Initialization. Two types of parameters are introduced in (4). The gate-control parameter $\mu$ is initialized in the interval $[-3, 3]$ for all layers. The second type of parameters are $(\eta,\nu,\theta)$, which determine the recurrent patterns. To encourage the REMs to be non-zero and well-diversified, we initialize the $\eta$'s at different heads to spread out over $[-2,-1]\cup [1,2]$, the $\nu$'s to spread out over $[1,2]$, and $\theta$ at $\frac{\pi}{4}$.

Dilated REM variants. The dilated REMs can be further obtained by considering the block matrix formed by the first $T$ columns and first $T$ rows of $P\otimes I_{d}$, where $d$ is the dilating factor. In fact, the dilated REMs can encapsulate the recurrence patterns of the dilated RNNs (Chang et al., 2017); see Proposition 4 in the Appendix. They describe potentially periodic recurrence over a long-term timespan, and can significantly enrich the temporal dynamics of our RSA.

Hyperparameters $k_{i}$ and $d$. In each multihead RSA layer with $H$ heads, the numbers of the six types of REMs, namely one regular, two cyclical, one dilated regular and two dilated cyclical REMs, are denoted by $k_{1},\ldots,k_{6}$, respectively. Since it holds that $\sum_{i=1}^{6} k_{i}=H$, we can apply a constrained hyperparameter search to optimize over the choices of the $k_{i}$'s. For simplicity, we can set $k_{2}=k_{3}$ and $k_{5}=k_{6}$, which indicates that the cyclical REMs come in pairs. The search for $d$, on the other hand, can be guided by various data-analytic tools. For instance, $d$ can be read off autocorrelation plots as the seasonal period length for time series data (Cryer and Chan, 2008), while for language-related data, $d$ can be heuristically deduced from recurrence plots (Webber et al., 2016).

Multi-Head RSA operations. We first choose masked or unmasked REMs according to the nature of the task, and then select the hyperparameters, including the dilating factor $d$ and the numbers of the six types of REMs $(k_{1},...,k_{6})$. Each RSA head is calculated by (4) with a different REM. Finally, a linear layer is applied to combine the outputs from all heads, and it is then selectively followed by the remaining operations of the Transformer layer.

More discussions on REMs. Each type of REM is basically a masked or unmasked linear aggregation of tokens, and we may alternatively consider a more general Toeplitz, or even a fully learnable, matrix $P$ at (4) such that all kinds of temporal decay patterns can be automatically captured. Although more flexible than REMs, this would require $O(T)$ or $O(T^{2})$ additional parameters, where $T$ is the sequence length, while each REM requires merely one or two parameters. Note that the proposed RSA at (4) also includes a standard self-attention, which is flexible enough to capture whatever patterns remain after the fitted REMs, and hence it may not be necessary to consider a more general yet less efficient structure for the REMs. On the other hand, the REMs may fail to capture possible nonlinearity in the data since they come from a linear RNN, while the self-attention component in the RSA module can remedy this problem to some extent; see Appendix B for both theoretical and empirical evidence.

NorbertZheng commented 1 year ago

Multi-Scale Recurrence

For sequential learning tasks, the recurrence relationship can be observed at different levels of temporal granularity. This feature has inspired new designs that add recurrence to Transformers at varying scales; see the illustration in Figure 5. Transformer-XL (XL) (Dai et al., 2019) partitions a long input sequence into segments, and places them into consecutive batches.

Temporal Latent Bottleneck (TLB) (Didolkar et al., 2022) further divides the segment within one batch into smaller chunks, and then adopts the state vectors to aggregate both high-level information across layers and temporal information across chunks.

Block-Recurrent Transformer (BRT) (Hutchins et al., 2022) also establishes recurrence across chunks (or blocks), while its recurrent states are layer-specific and updated with an LSTM-style gated design. In comparison, the proposed RSA follows the RNN in accounting for the recurrence between individual inputs; in other words, it models recurrence at the finest temporal scale, i.e. token by token.

Subsequently, it can be easily incorporated into the aforementioned coarser-grained designs, and may potentially bring benefits to their performance. For illustration, we use XL and BRT as our baseline models in Section 4.3.

image