openclimatefix / power_perceiver

Machine learning experiments using the Perceiver IO model to forecast the electricity system (starting with solar)
MIT License

[ML Idea] Hierarchical Perceiver #14

Open jacobbieker opened 2 years ago

jacobbieker commented 2 years ago

Detailed Description

https://arxiv.org/abs/2202.10890 is a new Perceiver model that is supposed to scale better, among other things. I've only skimmed it, but it might be a nice improvement.

Context

Possible Implementation

JackKelly commented 2 years ago

Thanks loads for the heads up! I hadn't seen that paper! Looks great.

JackKelly commented 2 years ago

I've only skim-read a few paragraphs so far (!) but, yeah, this could be really useful to allow us to include large geographical regions in the input! Nice! Thanks again.

JackKelly commented 2 years ago

I've started reading the paper. Some notes:

The paper is very focused on creating a one-size-fits-all "AutoML" model. But in our case, we're happy to modify the architecture for our domain.

We could use the HiP idea by using one group per timestep (instead of trying to use a Perceiver as an RNN, where each cross-attend sees a timestep).

Advantages of using "one-HiP-group-per-timestep":

Disadvantages: In Perceiver, I liked the idea that cross-attend i gets to query cross-attend i+1. But maybe that's just a pretty idea, rather than something that's definitely skillful in practice!

The paper shows that learned position encodings kind of suck when trained purely on a classification task. In the paper, they use (heavily) masked auto-encoding to pre-train the learned position encoding. For us, we have a very natural dense prediction task (predicting a future satellite image) which may be sufficient to learn position encodings. That said, the paper's results on learned position encodings weren't that amazing. So I think we're probably fine to stick with hand-crafted Fourier features for now.
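For reference, a minimal sketch of what hand-crafted Fourier features could look like for a single coordinate. The function name, the `num_bands` and `max_freq` parameters, and the linear frequency spacing are all assumptions for illustration, not what our code necessarily does:

```python
import math

def fourier_features(pos, num_bands=4, max_freq=8.0):
    """Encode one position in [-1, 1] as [pos, sines..., cosines...].

    Hypothetical helper: frequencies are linearly spaced from 1 to max_freq.
    """
    if num_bands == 1:
        freqs = [1.0]
    else:
        step = (max_freq - 1.0) / (num_bands - 1)
        freqs = [1.0 + i * step for i in range(num_bands)]
    sines = [math.sin(math.pi * f * pos) for f in freqs]
    cosines = [math.cos(math.pi * f * pos) for f in freqs]
    # Keep the raw position alongside the sinusoids.
    return [pos] + sines + cosines
```

Each coordinate produces `2 * num_bands + 1` features, so a pixel at (x, y, t) would just concatenate three of these vectors.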

One thing I'm a bit nervous about is the "fan out" (where they use, say, a hierarchy of 1-4-8 groups to decode the signal). Instead, I wonder if it might be nice to use HiP as the "history encoder", and then repeatedly cross-attend to the output of that "history encoder" to generate each timestep of prediction. Although I can't entirely put my finger on why I think that, right now! Maybe one thought is that I feel like every prediction timestep should have access to the entire latent bottleneck vector, and not just one group. And, because, if we want to predict days into the future, then we'll need, like, 48 output groups. Which sounds onerous. And each output group would only get a tiny input.

Right now, without thinking about it very deeply (and without having read the entire paper), I quite like the idea of using HiP as the "history encoder", and using repeated cross-attends to generate the predictions.
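A rough sketch of that idea, under heavy assumptions: single-head attention with no learned projections, one group of latents per history timestep, and one query vector per prediction timestep. The names `cross_attend`, `encode_history` and `decode` are hypothetical and just for illustration:

```python
import math

def cross_attend(queries, keys, values):
    """Single-head scaled dot-product attention over lists of vectors."""
    d = len(keys[0])
    outputs = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        outputs.append([
            sum(w / total * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))
        ])
    return outputs

def encode_history(history_groups, latents_per_group):
    """One group per history timestep: each group's latents only see that timestep."""
    latents = []
    for inputs, group_latents in zip(history_groups, latents_per_group):
        latents.extend(cross_attend(group_latents, inputs, inputs))
    return latents

def decode(latents, timestep_queries):
    """Every prediction timestep cross-attends to the *entire* set of latents."""
    return [cross_attend([q], latents, latents)[0] for q in timestep_queries]
```

The encoder runs once, and adding more prediction timesteps only adds more queries to `decode`, rather than more output groups.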

JackKelly commented 2 years ago

Or do the MetNet thing: run the entire model (encoder and decoder) for every prediction timestep. I.e., the model always outputs exactly one timestep. And all the queries are conditioned on the target timestep.

I kind of find this a bit ugly, because it feels like the encoder should only have to run once for all prediction timesteps. But can't really argue with MetNet's performance!
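To make the shape of that loop concrete, here's a minimal sketch. It assumes the lead time is encoded as a one-hot vector appended to every query, and that `model` is some hypothetical callable taking `(inputs, queries)` and returning one timestep of prediction:

```python
def one_hot_lead_time(step, num_steps):
    """One-hot encoding of the target timestep (an assumed encoding)."""
    return [1.0 if i == step else 0.0 for i in range(num_steps)]

def condition_queries(queries, step, num_steps):
    """Append the lead-time encoding to every query vector."""
    tag = one_hot_lead_time(step, num_steps)
    return [list(q) + tag for q in queries]

def forecast_all_steps(model, inputs, queries, num_steps):
    """Run the full model once per lead time; each call yields one timestep."""
    return [
        model(inputs, condition_queries(queries, step, num_steps))
        for step in range(num_steps)
    ]
```

The cost is visible in `forecast_all_steps`: the encoder work inside `model` is repeated `num_steps` times.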

JackKelly commented 2 years ago

After sleeping on it, I really like the idea of kind of merging ideas from Hierarchical Perceiver (Carreira et al. 2022) with ideas from MetNet (Sønderby et al. 2020).

To be specific: A model which:

image

Link to Google Drawing of the diagram above.

Link to Twitter thread

JackKelly commented 2 years ago

@jacobbieker absolutely no rush but I'd be curious if you have any thoughts on the architecture sketched out above? :slightly_smiling_face:

jacobbieker commented 2 years ago

Yeah, I like the idea! One thing that MetNet-2 found was that, instead of just giving the forecast timestep in the input, it helps to condition all the convolutional blocks on the timestep. I do it here in our implementation, where a scale and bias are computed for each conv block based on the forecast timestep and added to that layer, to ensure it's strongly conditioned on the lead time. So I think something like that could help here too. But other than that, I think this looks great!
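Roughly, the conditioning looks something like the sketch below. This is a simplified illustration, not the actual implementation: it assumes a one-hot lead time, a per-channel `scale * x + bias` form, and plain weight matrices (`scale_weights`, `bias_weights`) standing in for the learned layers. All names are hypothetical:

```python
def linear(weights, vec):
    """Matrix-vector product for a weight matrix stored as a list of rows."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weights]

def condition_block(features, lead_time_one_hot, scale_weights, bias_weights):
    """Apply a lead-time-dependent scale and bias to each channel.

    features: list of channels, each a flat list of activations.
    scale_weights / bias_weights: one row per channel, one column per lead time.
    """
    scales = linear(scale_weights, lead_time_one_hot)
    biases = linear(bias_weights, lead_time_one_hot)
    return [
        [s * x + b for x in channel]
        for channel, s, b in zip(features, scales, biases)
    ]
```

Each block would get its own pair of weight matrices, so every layer sees the lead time directly instead of relying on it surviving from the input.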

JackKelly commented 2 years ago

MetNet-2 found was that instead of just giving the forecast timestep in the input, condition all the convolutional blocks with the timestep

ooh, that's really interesting, I had missed that detail! Very useful, thank you @jacobbieker!

JackKelly commented 2 years ago

After thinking about graph neural networks, I'm now thinking about an encoder / decoder model, more like this:

image

(link to Google Drawing).

See issue #34 for more discussion of graph neural nets, and why the model above might be a nice compromise.