openclimatefix / metnet

PyTorch Implementation of Google Research's MetNet and MetNet-2
MIT License

Attention Layer Bottleneck #24

Open ValterFallenius opened 2 years ago

ValterFallenius commented 2 years ago

Found a bottleneck: the attention layer. I have found a potential explanation for why bug #22 occurred: the axial attention layer seems to act as some kind of bottleneck. I ran the network for 1000 epochs to try to overfit a small subset of 4 samples (see the run on WandB). The network barely reduces the loss and does not overfit the data; it yields a very poor result that looks like some kind of mean. See the image below:

[Image: yhat against y]

After removing the axial attention layer, the model behaves as expected and overfits the training data; see below, after 100 epochs:

[Image: yhat against y without attention]
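To tell whether a model has collapsed to "some kind of mean" rather than just training slowly, one quick check (a generic debugging heuristic, not something taken from the MetNet code) is to compare its loss with a constant baseline that predicts the per-pixel mean of the targets. The helper names `mse` and `mean_collapse_report` are hypothetical, and the data below is synthetic, not the WandB run.

```python
import numpy as np


def mse(a, b):
    """Mean squared error between two arrays of the same shape."""
    return float(np.mean((a - b) ** 2))


def mean_collapse_report(yhat, y):
    """Compare predictions against a constant per-pixel-mean baseline.

    yhat, y: arrays of shape (num_samples, height, width).
    A model that has collapsed to "some kind of mean" scores about the same
    as the baseline and produces almost the same map for every sample.
    """
    # Baseline: predict the per-pixel mean of the targets for every sample.
    baseline = np.broadcast_to(y.mean(axis=0, keepdims=True), y.shape)
    return {
        "model_mse": mse(yhat, y),
        "mean_baseline_mse": mse(baseline, y),
        # Average per-pixel std of the predictions across samples (~0 if collapsed).
        "prediction_spread": float(yhat.std(axis=0).mean()),
    }


# Usage with 4 synthetic 28x28 samples (made-up data).
rng = np.random.default_rng(0)
y = rng.random((4, 28, 28))
collapsed = np.broadcast_to(y.mean(axis=0), y.shape)  # outputs the mean map every time
overfit = y + 0.01 * rng.standard_normal(y.shape)     # has memorised the targets
report_collapsed = mean_collapse_report(collapsed, y)
report_overfit = mean_collapse_report(overfit, y)
print(report_collapsed)
print(report_overfit)
```

A collapsed model matches the baseline loss with near-zero spread across samples, while a model that can overfit drives its loss far below the baseline.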

The message from the author quoted in #19 does mention that our implementation of axial attention seems to be very different from theirs. He says: "Our (Google's) heads were small MLPs as far as I remember (I'm not at google anymore so do not have access to the source code)." I am not experienced enough to dig into the source code of the axial attention library we use to see how it differs from theirs.

  1. What are the heads in axial attention, and what does the number of heads change?
  2. Are we doing both vertical and horizontal attention passes in our implementation?

jacobbieker commented 2 years ago

So, this is the actual implementation being used for axial attention: https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L153-L178. It seems to do both vertical and horizontal passes. But I just realized that we don't actually use any position embeddings, other than the lat/lon inputs, before passing to the axial attention, so we might need to add them and see what happens. The number of heads is the number of heads for multi-headed attention, so I think we can probably just set it to one and be fine.
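To make the two questions above concrete, here is a minimal NumPy sketch of axial attention. It is not the lucidrains implementation linked above: it applies plain scaled dot-product multi-head self-attention once along the height axis (each column separately) and once along the width axis (each row separately), then sums the two results. "Heads" here means splitting the channel dimension into `heads` equal chunks that attend independently, so `heads=1` is ordinary single-head attention. The function names, the weight layout, and the choice to sum rather than chain the passes are assumptions for illustration; per the author's comment in #19, Google's heads may have been small MLPs rather than the linear projections used here.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_self_attention(x, w_qkv, w_out, heads):
    """Multi-head self-attention over the second-to-last axis of x.

    x: (..., seq_len, dim); w_qkv: (dim, 3 * dim); w_out: (dim, dim).
    """
    *batch, n, dim = x.shape
    assert dim % heads == 0, "dim must be divisible by heads"
    d_head = dim // heads
    q, k, v = np.split(x @ w_qkv, 3, axis=-1)  # each (..., n, dim)

    def split_heads(t):
        # (..., n, dim) -> (..., heads, n, d_head): each head gets its own channel chunk
        return np.swapaxes(t.reshape(*batch, n, heads, d_head), -2, -3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_head)   # (..., heads, n, n)
    out = softmax(scores) @ v                               # (..., heads, n, d_head)
    out = np.swapaxes(out, -2, -3).reshape(*batch, n, dim)  # concatenate the heads
    return out @ w_out


def axial_attention(x, params_h, params_w, heads):
    """Axial attention on a (C, H, W) map: a height pass plus a width pass."""
    feats = np.transpose(x, (1, 2, 0))  # (H, W, C), channels last
    # Vertical pass: sequence axis is H, one sequence per column -> (W, H, C)
    vert = multi_head_self_attention(np.transpose(feats, (1, 0, 2)), *params_h, heads)
    vert = np.transpose(vert, (1, 0, 2))  # back to (H, W, C)
    # Horizontal pass: sequence axis is W, one sequence per row -> (H, W, C)
    horiz = multi_head_self_attention(feats, *params_w, heads)
    return np.transpose(vert + horiz, (2, 0, 1))  # back to (C, H, W)


def init_params(rng, dim):
    """Random (untrained) projection weights for one attention pass."""
    scale = 1.0 / np.sqrt(dim)
    return (rng.standard_normal((dim, 3 * dim)) * scale,
            rng.standard_normal((dim, dim)) * scale)


# Usage: an 8-channel 28x28 map, like the RNN output described in this thread.
rng = np.random.default_rng(0)
C, H, W = 8, 28, 28
x = rng.standard_normal((C, H, W))
p_h, p_w = init_params(rng, C), init_params(rng, C)
out_1_head = axial_attention(x, p_h, p_w, heads=1)
out_8_heads = axial_attention(x, p_h, p_w, heads=8)
print(out_1_head.shape, out_8_heads.shape)  # (8, 28, 28) (8, 28, 28)
```

Because each pixel attends only within its own row and column, a change at pixel (0, 0) can affect only row 0 and column 0 after one such layer; information reaches the rest of the map through stacked layers (or by chaining the passes instead of summing them).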

ValterFallenius commented 2 years ago

Okay, how do we do this?

If we have 8 channels in the RNN output with a 28×28 height and width, is the embedding the information about which pixel we are in? I am struggling a bit to wrap my head around attention and axial attention...

Also, when you say set the number of heads to 1, you mean for debugging, right? We still want multi-head attention to replicate their model in the end.

jacobbieker commented 2 years ago

Yeah, set it to 1 for debugging, to get it to overfit first. And yes, the position embedding encodes which pixel we are in and where that pixel sits relative to the other pixels in the input. The library has a function for it, so we can probably just add it wherever we use the axial attention: https://github.com/openclimatefix/metnet/pull/25
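PR #25 uses the library's own position-embedding helper, which is not reproduced here. As a hedged illustration of what "which pixel we are in" means for an 8-channel 28×28 map, the sketch below builds a fixed 2D sinusoidal embedding: half of the channels encode the row index and the other half the column index, and the result is added to the feature map before attention. Fixed sinusoids are a swapped-in choice (learned per-row and per-column vectors are a common alternative), and `sinusoid_table` and `position_embedding_2d` are hypothetical names.

```python
import numpy as np


def sinusoid_table(length, dim):
    """1D sinusoidal encoding of shape (dim, length): sin on even channels, cos on odd."""
    pos = np.arange(length)[None, :]  # (1, length)
    i = np.arange(dim)[:, None]       # (dim, 1)
    angle = pos / np.power(10000.0, 2 * (i // 2) / dim)
    return np.where(i % 2 == 0, np.sin(angle), np.cos(angle))


def position_embedding_2d(channels, height, width):
    """(C, H, W) embedding: the first C/2 channels encode the row, the rest the column."""
    assert channels % 2 == 0, "need an even number of channels"
    half = channels // 2
    rows = np.broadcast_to(sinusoid_table(height, half)[:, :, None], (half, height, width))
    cols = np.broadcast_to(sinusoid_table(width, half)[:, None, :], (half, height, width))
    return np.concatenate([rows, cols], axis=0)


# Usage: add the embedding to a hypothetical 8-channel 28x28 RNN output before attention.
rng = np.random.default_rng(0)
features = rng.standard_normal((8, 28, 28))
emb = position_embedding_2d(8, 28, 28)
features_with_pos = features + emb
print(emb.shape)  # (8, 28, 28)
```

Splitting the channels between rows and columns keeps every pixel's code distinct; simply summing identical row and column tables would give pixels (h, w) and (w, h) the same code on a square map. The TFT argument raised in this thread suggests the preceding RNN may already provide a learned position signal, so a hand-crafted embedding like this is an experiment, not a requirement.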

JackKelly commented 2 years ago

I just realized that we don't actually do any position embeddings, other than the lat/lon inputs, before passing to the axial attention. So we might need to add that and see what happens?

I don't know if it's relevant, but it recently occurred to me that MetNet version 1 is quite similar to the Temporal Fusion Transformer (TFT) (also from Google!), except that MetNet has 2 spatial dimensions, whilst TFT is for time series without any (explicit) spatial dimensions. In particular, both TFT and MetNet use an RNN followed by multi-head attention. In the TFT paper, they claim that the RNN generates a kind of learnt position encoding, so they don't bother with a "hand-crafted" position encoding.

The TFT paper says:

[The LSTM] also serves as a replacement for standard positional encoding, providing an appropriate inductive bias for the time ordering of the inputs.

ValterFallenius commented 2 years ago

I can confirm that initial tests show promising results; the network seems to be learning something now :) I'll be back with more results in a few days.

peterdudfield commented 2 years ago

@all-contributors please add @jacobbieker for code

allcontributors[bot] commented 2 years ago

@peterdudfield

I've put up a pull request to add @jacobbieker! :tada:

peterdudfield commented 2 years ago

@all-contributors please add @JackKelly for code

allcontributors[bot] commented 2 years ago

@peterdudfield

I've put up a pull request to add @JackKelly! :tada:

peterdudfield commented 2 years ago

@all-contributors please add @ValterFallenius for userTesting

allcontributors[bot] commented 2 years ago

@peterdudfield

I've put up a pull request to add @ValterFallenius! :tada: