
TMLR 2023 | Replay-enhanced Continual Reinforcement Learning #129

Closed ChufanSuki closed 5 months ago

ChufanSuki commented 5 months ago

https://arxiv.org/abs/2311.11557

ChufanSuki commented 5 months ago

Task sequence $\mathcal{M}=[M_1,M_2,\dots,M_N]$

Policy $\pi_\phi$, value function $Q_\theta$; both use a multi-head neural network.

$$\Theta=\{\phi,\theta\}, \quad \Theta^*=\arg\max_\Theta \sum_{j=1}^i J_{M_j}(\Theta)$$

PopArt

In multi-task RL, the agent needs to train a different value function for each task. Since the output ranges of the tasks' value functions differ, a value-function scaling problem arises. PopArt is a technique used to address this scaling problem in deep RL.

Continual RL can be viewed as multi-task learning over the current and historical tasks. In replay-based continual RL, value functions on both current and past task experiences need to be learned while learning a new task.
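
To make the multi-head setup concrete, here is a minimal PyTorch-style sketch (not the paper's code) of a Q-network with one output head per task and per-head PopArt statistics. The class name `MultiHeadQ`, the layer sizes, and the buffer layout are illustrative assumptions.

```python
import torch
import torch.nn as nn

class MultiHeadQ(nn.Module):
    def __init__(self, obs_dim, act_dim, num_tasks, hidden_dim=256):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        # One linear head per task; each head outputs the *normalized* Q-value.
        self.heads = nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(num_tasks)])
        # Per-head PopArt statistics: shift mu, second moment nu, scale sigma.
        self.register_buffer("mu", torch.zeros(num_tasks))
        self.register_buffer("nu", torch.ones(num_tasks))
        self.register_buffer("sigma", torch.ones(num_tasks))

    def forward(self, obs, act, task_id):
        h = self.trunk(torch.cat([obs, act], dim=-1))
        q_norm = self.heads[task_id](h).squeeze(-1)          # normalized Q for this task
        q = self.sigma[task_id] * q_norm + self.mu[task_id]  # unnormalized Q
        return q_norm, q
```

Keeping $(\mu,\nu,\sigma)$ as buffers rather than trainable parameters reflects that they are updated by a running-statistics rule (see below), not by gradient descent.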

ChufanSuki commented 5 months ago

RECALL

Optimize a normalized value function $Q_{\theta, \text{norm}}=\left[Q_{\theta, \text{norm}}^1, \ldots, Q_{\theta, \text{norm}}^i, \ldots, Q_{\theta, \text{norm}}^N\right]$ with $N$ output heads.

Given targets denoted as $Q_{\bar{\theta},\tau}$, the normalized targets are $\tilde{Q}_{\bar{\theta},\tau}=\sigma^{-1}(Q_{\bar{\theta},\tau} - \mu)$, where $\sigma$ and $\mu$ are scale and shift parameters. Each head has its own $(\sigma, \mu)$ learned from the data of the associated task.

$$\mathcal{L}_{Q_{\text{norm}}}(\theta)=\mathbb{E}_{\left(s_t, a_t\right) \sim \mathcal{D}_{\text{new}} \cup \mathcal{D}_{\text{old}}}\left[\frac{1}{2}\left(Q_{\theta, \text{norm}}\left(s_t, a_t\right)-\tilde{Q}_{\bar{\theta}, \tau}\left(s_t, a_t\right)\right)^2\right]$$

where

$$Q_{\bar{\theta}, \tau}\left(s_t, a_t\right)=r\left(s_t, a_t\right)+\gamma \mathbb{E}_{s_{t+1} \sim p}\left[\mathbb{E}_{a_{t+1} \sim \pi_\phi}\left[\sigma Q_{\bar{\theta}, \text{norm}}\left(s_{t+1}, a_{t+1}\right)+\mu-\alpha \log \pi_\phi\left(a_{t+1} \mid s_{t+1}\right)\right]\right],$$
and the normalized policy loss is
$$\mathcal{L}_{\pi, \text{norm}}(\phi)=\mathbb{E}_{s_t \sim \mathcal{D}_{\text{new}} \cup \mathcal{D}_{\text{old}}}\left[\mathbb{E}_{a_t \sim \pi_\phi}\left[\alpha \log \pi_\phi\left(a_t \mid s_t\right)-Q_{\theta, \text{norm}}\left(s_t, a_t\right)\right]\right]$$

These loss functions are applied to experiences from both old and new tasks, e.g., a 50-50 mixture of experiences from the new task and the replayed tasks. For each sample, only the head associated with its task is updated in the value and policy networks.
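
A rough sketch of how the normalized critic loss could be computed on such a mixed batch, updating only each sample's own head. The batch keys, the `policy.sample(obs, task_id)` interface, and the single-critic simplification are assumptions (SAC would normally use two critics); the normalized policy loss would be handled analogously.

```python
import torch
import torch.nn.functional as F

def critic_loss(q_net, q_target, policy, batch, gamma=0.99, alpha=0.2):
    """Normalized TD loss over a batch mixing new-task and replayed old-task samples."""
    losses = []
    for task_id in batch["task_id"].unique().tolist():
        m = batch["task_id"] == task_id                        # samples belonging to this task
        obs, act, rew, next_obs = (batch[k][m] for k in ("obs", "act", "rew", "next_obs"))
        with torch.no_grad():
            next_act, next_logp = policy.sample(next_obs, task_id)
            _, next_q = q_target(next_obs, next_act, task_id)      # unnormalized target Q
            target = rew + gamma * (next_q - alpha * next_logp)    # SAC-style soft target
            # Normalize the target with this head's scale and shift.
            target_norm = (target - q_net.mu[task_id]) / q_net.sigma[task_id]
        q_norm, _ = q_net(obs, act, task_id)                   # only this task's head gets gradients
        losses.append(F.mse_loss(q_norm, target_norm))
    return torch.stack(losses).mean()
```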

In addition, after each actor-critic update (e.g., a SAC update), RECALL incrementally updates the scale and shift parameters to adaptively rescale the targets:

$$\mu_t=\mu_{t-1}+\beta_t\left(Q_{\bar{\theta}, \tau}-\mu_{t-1}\right) \quad \text{and} \quad \sigma_t^2=\nu_t-\mu_t^2, \quad \text{where} \quad \nu_t=\nu_{t-1}+\beta_t\left(Q_{\bar{\theta}, \tau}^2-\nu_{t-1}\right)$$

where $\nu_t$ estimates the second moment of the targets and $\beta_t \in[0,1]$ is the step size. Then, the last-layer weights $(\mathbf{w}, b)$ of the corresponding head in the normalized Q-network also need to be updated accordingly, so that the outputs of the unnormalized function are preserved exactly after the scale and shift change:

$$\mathbf{w}^{\prime}=\left(\sigma^{\prime}\right)^{-1} \sigma \mathbf{w}, \quad b^{\prime}=\left(\sigma^{\prime}\right)^{-1}\left(\sigma b+\mu-\mu^{\prime}\right),$$
where $\sigma^{\prime}, \mu^{\prime}$ denote the updated scale and shift and $\sigma, \mu$ the previous ones.
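
A sketch of this rescaling step under the same assumed `MultiHeadQ` layout as above. Using the batch mean of the unnormalized targets for the incremental moment estimates, and the step size `beta`, are assumptions.

```python
import torch

@torch.no_grad()
def update_scale_shift(q_net, targets, task_id, beta=1e-3):
    """Incrementally update (mu, nu, sigma) for one head, then rescale its last layer."""
    mu_old, sigma_old = q_net.mu[task_id].clone(), q_net.sigma[task_id].clone()
    # First and second moment estimates of the unnormalized targets.
    q_net.mu[task_id] += beta * (targets.mean() - q_net.mu[task_id])
    q_net.nu[task_id] += beta * ((targets ** 2).mean() - q_net.nu[task_id])
    q_net.sigma[task_id] = (q_net.nu[task_id] - q_net.mu[task_id] ** 2).clamp_min(1e-6).sqrt()
    # Adjust the head's last layer so the unnormalized output is unchanged by the rescale:
    # w' = (sigma / sigma') * w,  b' = (sigma * b + mu - mu') / sigma'.
    head = q_net.heads[task_id]
    head.weight.mul_(sigma_old / q_net.sigma[task_id])
    head.bias.copy_((sigma_old * head.bias + mu_old - q_net.mu[task_id]) / q_net.sigma[task_id])
```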

We employ a policy distillation technique on the replayed tasks to prevent distributional shift between the past experiences and the learned policies of old tasks while learning a new task:

$$\mathcal{L}_{\pi_{\text{distill}}}(\phi)=\mathbb{E}_{s_t \sim \mathcal{D}_{\text{old}}}\left[\mathrm{KL}\left[\pi_\phi\left(\cdot \mid s_t\right), \pi_{\text{old}}\left(\cdot \mid s_t\right)\right]\right]$$
The overall objective combines the two normalized losses with the distillation term:
$$\min_{\theta, \phi} \mathcal{L}_{Q_{\text{norm}}}(\theta)+\mathcal{L}_{\pi, \text{norm}}(\phi)+\lambda \mathcal{L}_{\pi_{\text{distill}}}(\phi)$$
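
A sketch of the distillation term and how the three losses could be combined. The `dist(obs, task_id)` method, the frozen `old_policy` snapshot, and the weight `lambda_` are assumed names, not the paper's API.

```python
import torch
from torch.distributions import kl_divergence

def distill_loss(policy, old_policy, old_obs, task_id):
    """KL[pi_phi(.|s) || pi_old(.|s)] averaged over states replayed from an old task."""
    d_new = policy.dist(old_obs, task_id)          # current policy's action distribution
    with torch.no_grad():
        d_old = old_policy.dist(old_obs, task_id)  # frozen snapshot for that old task
    return kl_divergence(d_new, d_old).mean()

# Overall objective, with lambda_ weighting the distillation term:
# total = critic_loss(...) + actor_loss(...) + lambda_ * distill_loss(...)
```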