jacobbieker opened 2 years ago
Thanks loads for the heads up! I hadn't seen that paper! Looks great.
I've only skim-read a few paragraphs so far (!) but, yeah, this could be really useful to allow us to include large geographical regions in the input! Nice! Thanks again.
I've started reading the paper. Some notes:
The paper is very focused on creating a one size fits all "AutoML" model. But in our case, we're happy to modify the architecture for our domain.
We could use the HiP idea by using one group per timestep (instead of trying to use a Perceiver as an RNN, where each cross-attend sees a timestep).
Advantages of using "one-HiP-group-per-timestep":
Disadvantages: In Perceiver, I liked the idea that cross-attend i gets to query cross-attend i+1. But maybe that's just a pretty idea, rather than something that's definitely skillful in practice!
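To make the "one-HiP-group-per-timestep" idea concrete, here's a minimal NumPy sketch (not the HiP authors' code; all names, shapes, and the single-head attention are illustrative assumptions). Each timestep's tokens form one group, summarised by its own small set of latents via cross-attention, so the timesteps are processed in parallel rather than sequentially as in an RNN-style Perceiver:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def cross_attend(latents, tokens):
    """Latents (L, D) query tokens (T, D); returns updated latents (L, D)."""
    scores = latents @ tokens.T / np.sqrt(latents.shape[-1])
    return softmax(scores) @ tokens

rng = np.random.default_rng(0)
n_timesteps, tokens_per_step, d = 4, 64, 16
latents_per_group = 8

# Input: one group of tokens per history timestep.
history = rng.normal(size=(n_timesteps, tokens_per_step, d))
latents = rng.normal(size=(n_timesteps, latents_per_group, d))

# Each timestep's group is encoded independently (parallel, unlike an RNN
# where cross-attend i would have to finish before cross-attend i+1).
encoded = np.stack([cross_attend(latents[t], history[t])
                    for t in range(n_timesteps)])
print(encoded.shape)  # (4, 8, 16)
```

A real implementation would of course use learned projections and multiple heads; the point here is just the grouping of input tokens by timestep.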
The paper shows that learned position encodings kind of suck when trained purely on a classification task. In the paper, they use a (heavily) masked auto encoding to pre-train the learned position encoding. For us, we have a very natural dense prediction task (predicting a future satellite image) which may be sufficient to learn position encodings. That said, the paper's results on learned position encodings weren't that amazing. So I think we're probably fine to stick with hand-crafted Fourier features for now.
One thing I'm a bit nervous about is the "fan out" (where they use, say, a hierarchy of 1-4-8 groups to decode the signal). Instead, I wonder if it might be nice to use HiP as the "history encoder", and then repeatedly cross-attend to the output of that "history encoder" to generate each timestep of prediction. Although I can't entirely put my finger on why I think that, right now! Maybe one thought is that I feel like every prediction timestep should have access to the entire latent bottleneck vector, and not just one group. And if we want to predict days into the future, then we'll need something like 48 output groups, which sounds onerous. And each output group would only get a tiny input.
Right now, without thinking about it very deeply (and without having read the entire paper), I quite like the idea of using HiP as the "history encoder", and using repeated cross-attends to generate the predictions.
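A rough NumPy sketch of that "history encoder + repeated cross-attends" shape (this is my own illustrative sketch, not code from either paper; the sinusoidal lead-time embedding and all shapes are assumptions). The encoder runs once; every prediction timestep then queries the whole latent bottleneck, so no per-timestep output groups are needed:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def cross_attend(queries, keys_values):
    """Queries (Q, D) attend over keys_values (K, D); returns (Q, D)."""
    scores = queries @ keys_values.T / np.sqrt(queries.shape[-1])
    return softmax(scores) @ keys_values

rng = np.random.default_rng(0)
d, n_latents, n_out_tokens = 16, 32, 10

# The "history encoder" (e.g. a HiP) runs ONCE, producing one bottleneck.
bottleneck = rng.normal(size=(n_latents, d))

# Output queries shared across lead times (e.g. one per output pixel).
base_queries = rng.normal(size=(n_out_tokens, d))

def lead_time_embedding(lead, d=d):
    """Toy sinusoidal embedding of the forecast lead time (illustrative)."""
    i = np.arange(d)
    return np.where(i % 2 == 0,
                    np.sin(lead / 10.0 ** (i / d)),
                    np.cos(lead / 10.0 ** (i / d)))

# Each prediction timestep cross-attends to the ENTIRE bottleneck,
# conditioned on its lead time via the query.
predictions = [cross_attend(base_queries + lead_time_embedding(lead),
                            bottleneck)
               for lead in range(1, 49)]   # e.g. 48 future timesteps
print(len(predictions), predictions[0].shape)  # 48 (10, 16)
```

Note the encoder cost is paid once, and only the (much cheaper) decoder cross-attend repeats per lead time, which is exactly the property the MetNet-style "run everything per timestep" approach gives up.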
Or do the MetNet thing: run the entire model (encoder and decoder) for every prediction timestep. I.e. the model always outputs exactly one timestep, and all the queries are conditioned on the target timestep.
I kind of find this a bit ugly, because it feels like the encoder should only have to run once for all prediction timesteps. But can't really argue with MetNet's performance!
After sleeping on it, I really like the idea of kind of merging ideas from Hierarchical Perceiver (Carreira et al. 2022) with ideas from MetNet (Sønderby et al. 2020).
To be specific: A model which:
@jacobbieker absolutely no rush but I'd be curious if you have any thoughts on the architecture sketched out above? :slightly_smiling_face:
Yeah, I like the idea! One thing that MetNet-2 found was that, instead of just giving the forecast timestep in the input, it helps to condition all the convolutional blocks on the timestep. I do it here in our implementation, where a scale and bias are computed for each conv block based on the forecast timestep and added to that layer, to ensure it's strongly conditioned on the lead time. So I think something like that could help here too. But other than that, I think this looks great!
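A hedged NumPy sketch of that per-block lead-time conditioning (a FiLM-style modulation; this is my illustration of the idea, not the MetNet-2 or OCF implementation, and the tiny linear "hypernetwork" is an assumption). A per-channel scale and bias are computed from the lead-time embedding and applied to a conv block's features:

```python
import numpy as np

rng = np.random.default_rng(0)
channels, d_embed = 8, 16

def lead_time_embedding(lead, d=d_embed):
    """Toy sinusoidal embedding of the forecast lead time (illustrative)."""
    i = np.arange(d)
    return np.sin(lead / 10.0 ** (i / d))

# Tiny linear "hypernetwork" mapping the lead-time embedding to
# per-channel (scale, bias) for one conv block.
W = rng.normal(size=(d_embed, 2 * channels)) * 0.1

def film(features, lead):
    """features: (channels, H, W). Returns lead-time-conditioned features."""
    params = lead_time_embedding(lead) @ W
    scale, bias = params[:channels], params[channels:]
    # Scale/bias broadcast over spatial dims; (1 + scale) keeps an
    # identity-like modulation when the hypernetwork outputs are small.
    return features * (1 + scale)[:, None, None] + bias[:, None, None]

x = rng.normal(size=(channels, 12, 12))      # output of some conv block
y1, y2 = film(x, lead=1), film(x, lead=24)
print(y1.shape)  # (8, 12, 12): same features, modulated per lead time
```

Repeating this at every conv block (with a separate `W` per block) is what keeps the whole network, not just the input, strongly conditioned on lead time.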
MetNet-2 found was that instead of just giving the forecast timestep in the input, condition all the convolutional blocks with the timestep
ooh, that's really interesting, I had missed that detail! Very useful, thank you @jacobbieker!
After thinking about graph neural networks, I'm now thinking about an encoder / decoder model, more like this:
(link to Google Drawing).
See issue #34 for more discussion of graph neural nets, and why the model above might be a nice compromise.
Detailed Description
https://arxiv.org/abs/2202.10890 is a new Perceiver model that is supposed to scale better, etc. I've only skimmed it, but it might be a nice improvement.