NorbertZheng / read-papers

My paper reading notes.
MIT License

ICLR '19 | How to train your MAML. #26


NorbertZheng commented 2 years ago

Antoniou A, Edwards H, Storkey A. How to train your MAML.

NorbertZheng commented 2 years ago

Related Reference

NorbertZheng commented 2 years ago

Introduction

Enter meta-learning, a universe where computational models, composed of multiple levels of learning abstraction, can improve their own ability to learn by learning some or all of their own building blocks through experience over a large number of tasks. Meta-learning, often referred to as learning to learn, is achieved by abstracting learning into two or more levels. The inner-most levels acquire task-specific knowledge (e.g. fine-tuning a model on a new dataset), whereas the outer-most level acquires across-task knowledge (e.g. learning to transfer between tasks more efficiently). If the models in the inner-most levels make use of components with learnable parameters, the outer-most optimization process can meta-learn the parameters of such components, thus enabling automatic learning of inner-loop components.

Few-shot learning is a perfect example of a problem area where meta-learning can be used to effectively meta-learn a solution. Using meta-learning, one can formulate and train systems that can very quickly learn from a small training set (i.e. a support set), containing only 1-5 samples from each class, such that they can generalize strongly on a corresponding small validation set (i.e. a target set). The constraints in this instance are that the model will only have access to very few data points from each class, and the target metric is the target set’s cross-entropy error.
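To make the support/target split concrete, here is a minimal Python sketch of sampling one N-way, K-shot task. The dataset format (a dict from class name to a list of sample ids) and the function name `sample_task` are hypothetical, chosen only for illustration; real data loaders will differ.

```python
import random
from typing import Dict, List, Tuple

# (sample id, task-local label); the id format is a hypothetical placeholder
Example = Tuple[str, int]


def sample_task(dataset: Dict[str, List[str]], n_way: int, k_shot: int,
                n_target: int, rng: random.Random) -> Tuple[List[Example], List[Example]]:
    """Sample a support set (k_shot per class) and a disjoint target set (n_target per class)."""
    classes = rng.sample(sorted(dataset), n_way)
    support: List[Example] = []
    target: List[Example] = []
    for label, cls in enumerate(classes):
        # Labels are re-assigned per task (0..n_way-1), so a label only has meaning within its task.
        samples = rng.sample(dataset[cls], k_shot + n_target)
        support += [(s, label) for s in samples[:k_shot]]
        target += [(s, label) for s in samples[k_shot:]]
    return support, target


if __name__ == "__main__":
    toy = {f"class{c}": [f"c{c}_img{i}" for i in range(20)] for c in range(20)}
    support, target = sample_task(toy, n_way=5, k_shot=1, n_target=15, rng=random.Random(0))
    print(len(support), len(target))  # 5 75
```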

NorbertZheng commented 2 years ago

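Meta-learning can be achieved through a variety of learning paradigms. The most effective, as of the time of writing this, are reinforcement learning, genetic algorithms, and gradient-based, end-to-end differentiable supervised learning.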

Both reinforcement learning and genetic algorithms have been demonstrated to be very computationally expensive, often requiring hundreds of GPU hours for a single experiment. However, gradient-based, end-to-end differentiable supervised meta-learning schemes, such as Meta-learner LSTM and Model Agnostic Meta-Learning (MAML), can be run on a single GPU, within 12-24 hours.

Gradient-based, end-to-end differentiable meta-learning presents an incredible opportunity for efficient and effective meta-learning. However, methods from this family of meta-learning are currently in their infancy, often suffering from a variety of issues.

For example, MAML’s inner loop SGD optimizer can outperform the Meta-learner LSTM method, which has parameterized its update rule as an LSTM that receives gradients and predicts updates. This is counter-intuitive. One would expect the learnable optimizer to outperform the manually built one. One possible reason for this might be that the Meta-Learner LSTM’s architecture affected its modelling capacity, thus rendering it inferior to a manually set optimizer. It is perhaps reasonable to assume that one of the deciding factors that make or break such systems is the architecture of the parameterized component. In MAML, we have a system that can achieve very strong results, with a relatively simple learning scheme composed of learning a parameter initialization for quick adaptation. However, even such a relatively parameter-light model can have instability problems depending on architecture and hyper-parameter choices. If MAML has these types of problems, then anything more complicated than that will suffer from such issues as well.

So, in order to improve gradient-based, end-to-end differentiable meta-learning models in general, we focus on MAML (which is relatively simple), identifying some of its problems and proposing solutions that stabilize the training, increase the convergence speed and improve the generalization performance.

NorbertZheng commented 2 years ago

In this blog-post, we’ll go over the MAML model, identify key problems, and then formulate a number of methodologies that attempt to solve them, as proposed in the recent paper How to train your MAML and implemented in the How to train your MAML github repo. The improved MAML variant is called MAML++. Figure 2 showcases how the proposed variant improves stability and convergence speed over the original when strided convolutions are used.

Finally, we’ll have an in-depth look at how the proposed model learned its own learning rates, and draw insights from what the model automatically learned for itself.

NorbertZheng commented 2 years ago

Model Agnostic Meta Learning (MAML)

MAML is a meta-learning framework that attempts to learn a parameter initialization $\theta=\theta_{0}$ for a neural network such that, after the model takes a small number ($N=1,\dots,5$) of standard SGD steps with respect to a particular task’s support set (i.e. $S=\{x_{S},y_{S}\}$), it can generalize very well on the task’s target set (i.e. $T=\{x_{T},y_{T}\}$).

Now, the next paragraph is probably one of the two most important ones in this whole blog post. So take a breath, sip some of that coffee/water/tea and get ready for a crash course on MAML.

Figure 1 (parsed from top to bottom) illustrates the MAML computation graph. In MAML we are given a task, composed of two sets: a support set (i.e. a small training set), composed of a batch of input-output pairs $\{x_{S},y_{S}\}$, and a target set (i.e. a small validation set), composed of input-output pairs $\{x_{T},y_{T}\}$. Upon receiving a task, MAML sets $\theta_{0}=\theta$ and then begins executing the inner loop optimization process. During this process, a neural network $f$ receives the support set inputs $x_{S}$ and some weights $\theta_{i-1}$ (where $i=1,\dots,N$) and returns some predictions $f(x_{S},\theta_{i-1})$. The predictions are compared against the true target labels $y_{S}$ using some loss function $L$ to compute the loss of the network, $L_{i-1}^{S}$, given the current weights and support set. The computed loss is then used to update the current weights $\theta_{i-1}$ towards the current task’s distribution: $\theta_{i}=\theta_{i-1}-\alpha\nabla_{\theta_{i-1}}L_{i-1}^{S}$. This process is repeated $N$ times, after which $\theta_{N}$ is obtained. At this point the inner loop optimization process has completed. Next, the fully updated model $f_{\theta_{N}}$ is applied to the target set inputs $x_{T}$ to obtain some predictions $f(x_{T},\theta_{N})$.

[Figure 1: MAML computation graph]

These predictions are then used along with the target set’s labels $y_{T}$ and a loss function $L$ to compute the task’s target set loss $L_{N}^{T}$. At this point, using the target set’s loss, we compute the gradients with respect to $\theta=\theta_{0}$, denoted as $\nabla_{\theta}L_{N}^{T}$, by backpropagating through the full computation graph, including the inner loop gradient computations and updates. Yes, you did read that correctly: we backpropagate through the inner loop gradient computations themselves. Does the term meta-learning begin to make sense now? So, why do we backpropagate through gradient computations and updates? We do this because we want to learn parameter initializations that can reach a generalizable state for a particular task after a number of updates. Thus, by backpropagating through the optimization process itself, we can obtain precise, information-rich gradients that push our model towards such an initialization very efficiently (when compared to RL and GAs). The part where we do this massive backpropagation through everything and update our network is called the outer loop optimization process. And that’s it, really! This is how MAML works. When implementing MAML, we usually evaluate a batch of tasks and use the sum or mean of their losses to update our model. This has the same effect that mini-batch training has on standard deep neural networks (i.e. the gradients obtained push the network towards a state that improves the performance on a batch of tasks, hence that direction is a more regularized and thus more generalizable one).
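The inner and outer loops can be checked end to end on a toy problem. The sketch below is not the authors’ implementation: it assumes a one-parameter linear model $f(x,\theta)=\theta x$ with a mean-squared-error loss, whose second derivative is constant, so the derivative of each inner update, $\partial\theta_{i}/\partial\theta_{i-1}=1-\alpha\,\partial^{2}L^{S}/\partial\theta^{2}$, can be written by hand. All function names are hypothetical.

```python
from typing import List, Sequence, Tuple

Data = Tuple[List[float], List[float]]  # (inputs, labels)


def loss(theta: float, data: Data) -> float:
    """Mean squared error of the toy model f(x, theta) = theta * x."""
    xs, ys = data
    return sum((theta * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def grad(theta: float, data: Data) -> float:
    """dL/dtheta for the toy loss."""
    xs, ys = data
    return sum(2 * x * (theta * x - y) for x, y in zip(xs, ys)) / len(xs)


def hess(data: Data) -> float:
    """d^2L/dtheta^2; constant because the toy loss is quadratic in theta."""
    xs, _ = data
    return sum(2 * x * x for x in xs) / len(xs)


def maml_task(theta0: float, support: Data, target: Data,
              alpha: float, n_steps: int) -> Tuple[float, float]:
    """Run the inner loop, then return (L_N^T, dL_N^T / dtheta0)."""
    theta = theta0
    jac = 1.0  # d theta_i / d theta_0, accumulated over the inner loop
    for _ in range(n_steps):
        theta = theta - alpha * grad(theta, support)  # theta_i = theta_{i-1} - alpha * grad
        jac *= 1 - alpha * hess(support)              # backprop through the update itself
    return loss(theta, target), grad(theta, target) * jac


def outer_step(theta0: float, tasks: Sequence[Tuple[Data, Data]],
               alpha: float, beta: float, n_steps: int) -> float:
    """One outer loop update using the mean meta-gradient over a batch of tasks."""
    grads = [maml_task(theta0, s, t, alpha, n_steps)[1] for s, t in tasks]
    return theta0 - beta * sum(grads) / len(grads)
```

Dropping the `jac` factor gives the first-order approximation of MAML, which ignores the backpropagation through the inner loop gradients.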

NorbertZheng commented 2 years ago

The idea is very elegant and effective. However, after attempting to use such a system as the base for more complicated systems, we realised that MAML is very sensitive to architecture and hyperparameter choices.

Changing something as simple as the stride of a convolutional layer, or replacing max-pooling, can have disastrous consequences for the training of the system, often producing unstable behaviour in which the system attempts to recover and converge; this sometimes results in the system requiring multiple times more compute time and achieving a substantially lower generalization score. It soon became very clear that to build on top of MAML in any meaningful way (i.e. learning complicated parameterized components which have their own architectures and hyperparameters), we would need to improve and stabilize MAML first.

The improvements list is as follows:

NorbertZheng commented 2 years ago

Stabilizing MAML: Multi Step Loss Optimization (MSL)

Problem

MAML optimizes the initialization parameters of a given model such that, after $N$ steps on a task’s support set, it will generalize well to that task’s target set. In other words, the initial parameters $\theta=\theta_{0}$ are optimized in the outer loop optimization process with respect to the target set loss, which is computed using the predictions of the resulting model after $N$ updates.

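[Figure 2: training stability and convergence speed of MAML vs. MAML++ when strided convolutions are used]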

However, we noted that changes as simple as using strided convolutions or adding more layers into the network, rendered it very unstable (as supported by the training loss oscillations in figure 2), in the sense that the train loss itself oscillated instead of smoothly being minimized.

This effectively meant that the model required far longer to train, and the model’s final generalization performance was potentially lower than what it could have been if the model was more stable.

NorbertZheng commented 2 years ago

So, why is it unstable?

One of the most frequent sources of instability in deep neural networks is gradient degradation, that is, vanishing and exploding gradients. Furthermore, the signature of this instability was eerily similar to some experiments I’d done in the past with extremely deep networks.

In this instance we have a standard 4-layer convolutional network followed by a single linear layer, unrolled 5 times. Since the inference graph that backprop has to traverse is effectively a cascade of 25 layers with no skip connections, gradient degradation is a plausible explanation. In addition, between each model update operation we also compute derivatives with respect to the gradient computation itself, which introduces additional backpropagation operations that can cause gradients to degrade even further.
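The chain rule makes the depth argument explicit. Writing $H_{i}=\nabla^{2}_{\theta_{i}}L_{i}^{S}$ for the (symmetric) Hessian of the support loss at $\theta_{i}$, each inner update contributes one Jacobian factor to the gradient that reaches the initialization. This is the standard second-order MAML derivation, added here as a sketch rather than taken from the post:

$$\frac{\partial\theta_{i}}{\partial\theta_{i-1}}=I-\alpha H_{i-1},\qquad \nabla_{\theta}L_{N}^{T}=\big(I-\alpha H_{0}\big)\big(I-\alpha H_{1}\big)\cdots\big(I-\alpha H_{N-1}\big)\,\nabla_{\theta_{N}}L_{N}^{T}$$

Each $H_{i}$ is itself obtained by backpropagating through the whole network, and if the factors $I-\alpha H_{i}$ have eigenvalues consistently below (or above) 1 in magnitude, the product shrinks (or grows) geometrically with $N$, which is one way to read the vanishing/exploding behaviour described above.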

Ok, fair enough, are there any other symptoms or interesting model behaviour patterns that might shed additional light?

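The instability disappears when we use 1-3 inner loop steps, and re-appears when we use 4-5 steps. This indicates that the instability is tied to the number of inner loop steps, i.e. to the depth of the unrolled computation graph.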

Furthermore, a brief inspection of gradient summary statistics indicated that in 5-step models the gradients returning to the parameter initialization model appear to have very high max values and very low min values (close to 0).

NorbertZheng commented 2 years ago

Solution

Let’s assume that gradient degradation is indeed the problem, how can we solve it?

Adding skip connections that connect the various iterations of the model is probably not a very good idea, as that might introduce additional gradient computation overheads and make things slower and more complex. Thus, another possible route is to give the parameter initialization more direct gradients from every inner loop step, not only from the last one.

If the model received gradients immediately after each inner loop step, then the gradient degradation problem could be decreased significantly. (In tem, this corresponds to applying $L$ at each inner step instead of only $loss$ at each outer step; tem is much deeper here, since we set $n_{rollout}$ to $75 \gg 5$.)

At the same time, we could gradually shift the emphasis back onto the final step’s loss, so that our model remains focused on the main goal as well. Explicit gradients can be introduced by computing the target set loss after every inner loop update, then computing a weighted average of the per-step losses to be the optimization loss.

This is important to do, to ensure that the model attempts to minimize the last step’s loss the most by the end of the experiment. This ensures that the additional update steps are utilized as much as possible.

After implementing the proposed method, we observed that the training performance stabilized, thus improving convergence speed and having a minor positive impact on the final generalization performance.

NorbertZheng commented 2 years ago

Figure 3 illustrates how multi-step loss optimization works. The only difference between MAML and MAML with MSL is that after each inner loop update step $\theta_{i}=\theta_{i-1}-\alpha\nabla_{\theta_{i-1}}L_{i-1}^{S}$, we compute the target set loss $L_{i}^{T}$ using the current weights $\theta_{i}$, instead of directly proceeding to execute another update with respect to the support set. After $N$ steps have been completed, $N$ target set losses (one for each parameter state after each update step) have been obtained. We take a weighted sum of these losses, $L_{1\ldots N}^{T}=\sum_{i=1}^{N}w_{i}L_{i}^{T}$, and optimize the outer loop parameters $\theta$ using the combined loss.

[Figure 3: multi-step loss optimization (MSL)]
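The weighted multi-step loss can be sketched on a toy problem. The Python below is illustrative only: it assumes a one-parameter linear model with a squared-error loss (not the paper’s network) and only computes the MSL objective $\sum_{i}w_{i}L_{i}^{T}$; in practice an autodiff framework would backpropagate through it to $\theta$. The name `msl_loss` is hypothetical.

```python
from typing import List, Sequence, Tuple

Data = Tuple[List[float], List[float]]  # (inputs, labels)


def loss(theta: float, data: Data) -> float:
    """Squared-error loss of the toy model f(x, theta) = theta * x."""
    xs, ys = data
    return sum((theta * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def grad(theta: float, data: Data) -> float:
    """dL/dtheta for the toy loss."""
    xs, ys = data
    return sum(2 * x * (theta * x - y) for x, y in zip(xs, ys)) / len(xs)


def msl_loss(theta0: float, support: Data, target: Data, alpha: float,
             weights: Sequence[float]) -> Tuple[float, List[float]]:
    """Return (sum_i w_i * L_i^T, [L_1^T, ..., L_N^T]) with N = len(weights)."""
    theta = theta0
    per_step: List[float] = []
    for _ in weights:
        theta = theta - alpha * grad(theta, support)  # inner update on the support set
        per_step.append(loss(theta, target))          # target loss after this update
    total = sum(w * l for w, l in zip(weights, per_step))
    return total, per_step
```

Setting `weights` to a one-hot vector on the last step recovers the original MAML objective $L_{N}^{T}$.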

$w$ here is an $N$-dimensional importance weight vector that indicates the importance of each step’s loss towards the overall loss. During early training, all losses have about equal importance; as training progresses, the importance weights are annealed such that earlier step losses have increasingly lower importance, and the $N^{th}$ step loss has increasingly higher importance assigned to it. As a result, the model slowly transitions into the original MAML loss, whilst making sure that the gradients received at each update step are cleaner and less likely to cause gradient degradation issues.
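One simple way to realise this annealing is a linear schedule. The sketch below is an assumption for illustration (the official MAML++ code may use a different rule, and `msl_weights` is a hypothetical name): the first $N-1$ weights decay linearly from $1/N$ to a small floor over `anneal_epochs`, and the last weight takes up the remainder so the weights always sum to 1.

```python
from typing import List


def msl_weights(epoch: int, n_steps: int, anneal_epochs: int,
                min_weight: float = 0.001) -> List[float]:
    """Per-step importance weights w_1..w_N (hypothetical schedule; needs min_weight <= 1/N)."""
    start = 1.0 / n_steps
    progress = min(epoch / anneal_epochs, 1.0)       # 0 at the start, 1 once annealing is done
    early = start + progress * (min_weight - start)  # linear decay of steps 1..N-1
    final = 1.0 - (n_steps - 1) * early              # step N absorbs the remaining mass
    return [early] * (n_steps - 1) + [final]
```

Once `epoch >= anneal_epochs`, the combined loss is almost exactly the plain MAML loss, matching the transition described above.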

Well done. This is all one needs to understand to be able to implement and use the multi-step loss technique (assuming a good level of skill in deep learning frameworks).

NorbertZheng commented 2 years ago

If we add MAML’s MSL to tem, every BPTT update step (over $n_{rollout}$ steps) acts as the inner loop. After each inner loop, we use the updated model to evaluate $loss$ on the test set, and $loss$ aggregates $L$ over multiple steps (e.g. $n_{chunk}^{S}\times n_{rollout}$ steps), so this is indeed a very deep network.

Then the number of inner loops ($N$) will be meaningless.

NorbertZheng commented 2 years ago

Results

Objectives achieved by MSL:


NorbertZheng commented 2 years ago

Step-by-step Batch Normalization for Meta-Learning (BNWB + BNRS)

Problem

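In the original MAML paper, the authors implemented batch normalization without storing any running statistics, normalizing with the statistics of the current batch at every step instead.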

Doing so makes the optimization landscape of learning the betas and gammas far more complex, as we are now effectively sharing the parameters across all possible means and standard deviations that the millions of mini-batches will have. Furthermore, the mean and standard deviation used for normalization are very far from the true mean and standard deviation, which consequently reduces the generalization performance and convergence speed.

So one might ask, why would they not use standard batch-normalization with stored statistics?

The answer is, in fact, simpler than you might think. After running numerous experiments using standard batch normalization, I found that it simply did not work. Furthermore, the authors chose to only learn betas, whilst keeping gammas fixed at some default value. Again, this practice seems counter-intuitive, and again, the reason is that it won’t work otherwise. I found this fact perplexing.

NorbertZheng commented 2 years ago

The above issue looks similar to the problem that Layer Normalization tries to solve: the activity evolves over time! With one extremely long sequence, will MAML get lost in the accumulation of error during the path integration process?

NorbertZheng commented 2 years ago

Solution

Why would one of the most powerful, well-tested and highly reliable normalization layers fail to work in this instance?

Well, after sleeping on it, I had a potential answer, and again, it is simpler than one might think. As with most problems in science, this one also stemmed from a wrongly placed assumption: we were assuming that the initialization model and all its updated iterations had, in fact, similar feature distributions.

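Obviously this assumption is far from correct, especially in the case of MAML, where we are literally learning fast-adapting networks, i.e. the model is trained to change as much as possible in order to learn a new task quickly. Fixing this issue was as simple as collecting batch normalization running statistics separately for each inner loop step (BNRS), and learning a separate set of betas and gammas for each step (BNWB).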

Once these simple changes were made, batch normalization improved the convergence speed and generalization performance substantially. The above tables showcase the improvements.
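A minimal, framework-free sketch of the two changes follows, assuming a single-channel feature and an exponential-moving-average update of the running statistics (as standard batch norm uses). Each inner loop step index gets its own running mean/variance (BNRS) and its own gamma/beta (BNWB). The class name and arguments are hypothetical; in a real implementation the gammas and betas would be meta-learned in the outer loop.

```python
import math
from typing import List


class PerStepBatchNorm:
    """Single-channel batch norm with separate stats and affine params per inner loop step."""

    def __init__(self, n_steps: int, momentum: float = 0.1, eps: float = 1e-5):
        self.momentum, self.eps = momentum, eps
        self.running_mean = [0.0] * n_steps  # BNRS: one running mean per step
        self.running_var = [1.0] * n_steps   # BNRS: one running variance per step
        self.gamma = [1.0] * n_steps         # BNWB: one learnable scale per step
        self.beta = [0.0] * n_steps          # BNWB: one learnable shift per step

    def __call__(self, xs: List[float], step: int, training: bool) -> List[float]:
        if training:
            mean = sum(xs) / len(xs)
            var = sum((x - mean) ** 2 for x in xs) / len(xs)
            m = self.momentum
            # Update only the statistics that belong to this inner loop step.
            self.running_mean[step] = (1 - m) * self.running_mean[step] + m * mean
            self.running_var[step] = (1 - m) * self.running_var[step] + m * var
        else:
            mean, var = self.running_mean[step], self.running_var[step]
        scale = self.gamma[step] / math.sqrt(var + self.eps)
        return [scale * (x - mean) + self.beta[step] for x in xs]
```

Because each step index normalizes with its own statistics, the initialization and each of its updated iterations no longer have to share one feature distribution.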

NorbertZheng commented 2 years ago

Results

Objectives achieved by BNRS + BNWB:

NorbertZheng commented 2 years ago

Per-Layer Per-Step Learnable Learning Rates (LSLR)

Problem

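Selecting an inner loop learning rate can be an arduous process, requiring significant amounts of GPU hours. However, since MAML is framed within a meta-learning setting, one can choose to learn the inner loop learning rate as part of the outer loop optimization instead of setting it by hand.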

Learnable learning rates have the following benefits:

In fact, allowing the network to learn its own inner loop learning rate opens the door to new possibilities. In a recent paper by Li et al., called Meta-SGD, the authors propose learning a learning rate and direction for each parameter of the network.

The results they showcase are state of the art. However, learning one learning rate for each parameter is very expensive memory-wise and computation-wise. Furthermore, in a multi-step setup (e.g. multiple inner-loop steps), one could also learn a learning rate and direction for each update step.
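A quick count illustrates why per-parameter learning rates are costly. The layer sizes below are assumptions for illustration only (four 3x3 convolutions with 64 filters on single-channel input, then a 5-way linear layer on a 64-dimensional feature, biases included, batch norm parameters ignored); they are not taken from the post.

```python
from typing import Dict, List


def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights plus biases of a k x k convolution."""
    return c_in * c_out * k * k + c_out


def count_learning_rates(layer_params: List[int], n_steps: int) -> Dict[str, int]:
    """Number of learnable inner loop learning rates under different schemes."""
    n_params = sum(layer_params)
    return {
        "meta_sgd": n_params,                     # one per parameter
        "meta_sgd_per_step": n_params * n_steps,  # one per parameter and per step
        "per_layer_per_step": len(layer_params) * n_steps,
    }


if __name__ == "__main__":
    layers = [conv_params(1, 64)] + [conv_params(64, 64)] * 3 + [64 * 5 + 5]
    print(count_learning_rates(layers, n_steps=5))
    # {'meta_sgd': 111749, 'meta_sgd_per_step': 558745, 'per_layer_per_step': 25}
```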

NorbertZheng commented 2 years ago

Solution

In an attempt to improve the expressivity of the system whilst keeping the memory and computational expenses similar to MAML, we instead propose learning a learning rate and direction for each layer of the network and for each inner loop step.

What this effectively means is that we allow the network to learn its own learning rate scheduler over the $N$ inner loop steps, whilst allowing it to learn different learning rates for different layers (a trade-off between expressivity and memory and computational expense). This lets the network keep some layers pretty much identical to the initialization, let others learn with much higher learning rates, and even allow negative learning to take place.

By negative learning we mean taking update steps with negative learning rates. One can refer to that as forgetting, but we think that the term forgetting is not precise. More than likely, MAML implicitly learns, as part of training, to model how gradients at specific parameters/layers/time-steps influence future gradients (perhaps to avoid local minima?). Thus, by allowing the network to learn its own learning rates for each step and layer, we effectively add more expressivity and thus freedom in how it can go about learning things. Perhaps most learning rate choices (even negative ones) are only there to steer current and future gradients in the right direction.
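The per-layer, per-step update can be sketched as follows. This is a hypothetical, framework-free illustration (a dict of flat parameter lists stands in for network layers, and all names are made up): each layer owns one learning rate per inner step, so a rate of zero leaves that layer unchanged for that step, and a negative rate moves along the gradient instead of against it, i.e. the negative learning discussed above.

```python
from typing import Dict, Iterable, List

Params = Dict[str, List[float]]         # layer name -> flat list of parameters
LearningRates = Dict[str, List[float]]  # layer name -> one learning rate per inner step


def init_lslr(layer_names: Iterable[str], n_steps: int, init_lr: float = 0.1) -> LearningRates:
    """One learnable learning rate per layer and per inner loop step."""
    return {name: [init_lr] * n_steps for name in layer_names}


def lslr_update(params: Params, grads: Params, lrs: LearningRates, step: int) -> Params:
    """theta_i = theta_{i-1} - alpha[layer][step] * grad, applied layer by layer."""
    return {
        name: [p - lrs[name][step] * g for p, g in zip(values, grads[name])]
        for name, values in params.items()
    }
```

In the outer loop, the entries of `lrs` would be meta-learned together with the initialization $\theta$; the storage cost is `n_layers * n_steps` numbers rather than one per parameter.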

NorbertZheng commented 2 years ago

Results

The above tables showcase the effect of LSLR on MAML. One can clearly see that the generalization performance has increased. From inspecting the figures we can also see that convergence speed was increased in a similar manner to how step-by-step batch norm improved the system.

Targets achieved by LSLR:

NorbertZheng commented 2 years ago

Step by Step MAML / MAML++

The next step was to combine all methods. In addition to the previously mentioned methods, we trained the model using the Adam optimizer with $lr=0.001$, $\beta_{1}=0.9$, $\beta_{2}=0.99$, cosine annealing the learning rate down to 0.00001 over 150 epochs, each consisting of 500 outer loop updates. The above table shows the results of the approach, showcasing that the combined approaches allow for even further improvement in generalization performance and convergence speed, all whilst being highly stable and robust.
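For reference, the cosine schedule can be written out explicitly. This sketch assumes the annealing is applied per outer loop update over all 150 x 500 = 75,000 updates (it could equally be stepped per epoch). In PyTorch, a comparable setup might be `torch.optim.Adam(params, lr=1e-3, betas=(0.9, 0.99))` combined with `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=75000, eta_min=1e-5)`, though that is our assumption about how one could reproduce it, not the authors’ code.

```python
import math

EPOCHS, UPDATES_PER_EPOCH = 150, 500
TOTAL_UPDATES = EPOCHS * UPDATES_PER_EPOCH  # 75,000 outer loop updates


def cosine_lr(update: int, total: int = TOTAL_UPDATES,
              lr_max: float = 1e-3, lr_min: float = 1e-5) -> float:
    """Cosine-annealed meta learning rate: lr_max at update 0, lr_min at update `total`."""
    progress = min(update / total, 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for u in (0, TOTAL_UPDATES // 2, TOTAL_UPDATES):
        print(u, cosine_lr(u))
```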

NorbertZheng commented 2 years ago

Can we learn anything from the learned learning rates? (Meta-meta-learning?)

One of the most important premises of meta-learning is, in a way, the automation of research. The learned per-step per-layer learning rates represent an automatically learned solution. Since these learning rates work with the learned parameter initialization, they are not directly applicable to standard deep neural network training; however, they may provide interesting hints and insights into strong ways of learning to few-shot learn.

NorbertZheng commented 2 years ago

Why did you spend all this time just to improve a very specific meta-learning framework like MAML?

We weren’t really trying to improve MAML in particular. We were trying to improve gradient-based, end-to-end differentiable meta-learning in general.

The potential of this particular type of meta-learning is immense, as it provides very sample-efficient meta-learning (when compared to RL and GA counterparts). MAML is a high performance and light-weight instance of gradient-based meta-learning, which makes it an ideal base on which one can build more complicated models. However, after attempting to use MAML we noticed many of its shortcomings (which are almost definitely present in other gradient-based meta-learning systems). Thus solving the problems that MAML has exposed is vital, not only for MAML itself, but for meta-learning as a whole. Once those problems are solved, one can begin to tap the potential of gradient-based meta-learning by building more interesting and complicated systems on top of existing gradient-based meta-learning frameworks (such as MAML).

NorbertZheng commented 2 years ago

Conclusion

The potential of meta-learning in solving hard problems is immense and at the current stage completely untapped. For the first time ever, we have the compute and software required to train very complicated meta-learning systems that can learn their own internal inference blocks. In this blog post we demonstrated how one can stabilize and improve one such elegant and powerful system, called MAML.

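In our attempts to improve it, we introduced a number of modifications, including multi-step loss optimization (MSL), per-step batch normalization running statistics and weights and biases (BNRS + BNWB), and per-layer per-step learnable learning rates (LSLR).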

The resulting learned learning rates provided a lot of insight into how a good few-shot learning system can be built, and made for very interesting and information-dense visualizations that can provide further information on how we can build even better meta-learning systems. There is a lot more to be discussed about that aspect, but more on that in our next blog post.