chufanchen opened 6 months ago
CL-plugin is a full continual learning module designed to interact with a pre-trained model.
Freezing the pre-trained model to avoid forgetting is not the best choice.
CTR inserts a CL-plugin at two locations inside each transformer layer of BERT (similar to adapters).
During training, only the two CL-plugins and the classification heads are trained; the BERT backbone stays frozen.
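A minimal NumPy sketch of this setup: a stand-in frozen sublayer followed by a trainable plugin with a residual connection. The bottleneck-adapter form and all dimensions here are illustrative assumptions, not the paper's exact plugin (which also contains the capsule modules described below).

```python
import numpy as np

rng = np.random.default_rng(0)

d_e = 8   # hidden size (BERT uses 768; kept small here)
d_t = 4   # number of tokens

# Frozen sublayer weights: stand-in for a pre-trained BERT sublayer.
W_frozen = rng.standard_normal((d_e, d_e)) / np.sqrt(d_e)

# Trainable CL-plugin parameters: a small bottleneck with a residual
# connection (an assumed adapter-style form for illustration).
d_bottleneck = 2
W_down = rng.standard_normal((d_e, d_bottleneck)) / np.sqrt(d_e)
W_up = rng.standard_normal((d_bottleneck, d_e)) / np.sqrt(d_bottleneck)

def sublayer(h):
    """Frozen pre-trained sublayer (its weights are never updated)."""
    return h @ W_frozen

def cl_plugin(h):
    """Trainable plugin inserted after the sublayer; the residual keeps
    the frozen representation intact."""
    return h + np.tanh(h @ W_down) @ W_up

h = rng.standard_normal((d_t, d_e))   # hidden states h^(t)
out = cl_plugin(sublayer(h))
print(out.shape)   # (4, 8) -- the plugin preserves the hidden-state shape
```

Because the plugin output has the same shape as its input, it can be dropped into each transformer layer without touching the frozen weights.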
Inputs: hidden states $h^{(t)} \in \mathbb{R}^{d_t \times d_e}$, task ID $t$
where $t$ is the current task ID, $d_t$ is the number of tokens, and $d_e$ is the hidden dimension.
Knowledge Sharing Module (KSM) and Task Specific Module (TSM)
Capsule: a 2-layer fully-connected network, $f_i(\cdot)=\mathrm{MLP}_i(\cdot)$
Each capsule represents a task. Assume we have learned $t$ tasks so far; the capsule for task $i \leq t$ is $p_i^{(t)}=f_i(h^{(t)})$.
$u_{j \vert i}^{(t)}=W_{ij}\,p_i^{(t)}$, where $W_{ij} \in \mathbb{R}^{d_k \times d_s}$.
$d_s$ and $d_k$ are the dimensions of task capsule $i$ and transfer capsule $j$, respectively.
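The two capsule layers above can be sketched as follows. The per-task 2-layer MLPs produce task capsules $p_i^{(t)}$, and per-pair matrices $W_{ij}$ map them to transfer capsules $u_{j \vert i}^{(t)}$. Mean-pooling the token axis to get a vector input, the hidden width, and all sizes are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

d_t, d_e = 4, 8    # tokens, hidden size
d_s, d_k = 6, 5    # task-capsule and transfer-capsule dimensions
t = 3              # tasks learned so far

h = rng.standard_normal((d_t, d_e))   # hidden states h^(t)

def make_mlp(d_in, d_hidden, d_out):
    """One 2-layer fully-connected network f_i (hidden width assumed)."""
    W1 = rng.standard_normal((d_in, d_hidden)) / np.sqrt(d_in)
    W2 = rng.standard_normal((d_hidden, d_out)) / np.sqrt(d_hidden)
    return lambda x: np.maximum(x @ W1, 0.0) @ W2

# One MLP per learned task; mean-pool tokens (an assumption) so each
# task capsule p_i^(t) is a vector in R^{d_s}.
mlps = [make_mlp(d_e, 16, d_s) for _ in range(t)]
p = np.stack([f(h.mean(axis=0)) for f in mlps])   # shape (t, d_s)

# Transfer capsules: u_{j|i}^(t) = W_ij p_i^(t), W_ij in R^{d_k x d_s}.
W = rng.standard_normal((t, t, d_k, d_s))
u = np.einsum('ijks,is->ijk', W, p)   # u[i, j] = u_{j|i}, shape (t, t, d_k)

print(p.shape, u.shape)
```

Each learned task thus contributes one task capsule, and every (task, task) pair gets its own projection, which is what the routing step between capsule layers then operates on.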
https://arxiv.org/abs/2112.02706