openclimatefix / metnet

PyTorch Implementation of Google Research's MetNet and MetNet-2
MIT License

Memory issues in training loop #19

Open ValterFallenius opened 2 years ago

ValterFallenius commented 2 years ago

Detailed Description

I am running on a slightly downsampled dataset with spatial dimensions 448x448. The shape of my input tensor is (None, t = 7, c = 75, w = 112, h = 112) and the output tensor shape is (None, t = 60, c = 51, w = 28, h = 28). I have implemented a version that works with pytorch-lightning for parallelization and would be happy to share it if anyone wants. My model has the following parameter counts:

Downsampler (same as paper): 1.6M parameters
Temporal encoder (hidden = 384): 6.6M parameters
Temporal aggregation (4 layers, heads = 8, num_dims = 2): 4.7M parameters

But when I run a single training epoch with batch_size = 1 on an NVIDIA A100 GPU, I get the following error:

"RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 39.59 GiB total capacity; 36.22 GiB already allocated; 6.19 MiB free; 37.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF". When I reduce the lead times this error disappears (this is the bottleneck)."

Possible Implementation

The issue is that with this training loop the effective batch size is 60 × batch_size, because every sample is trained on all 60 lead times at once. The paper only uses one random lead time per sample, which now makes sense to me: it solves the memory issue by allowing the effective minimum batch size to be lower than 60. However, I am not certain how to implement this, since the training step is automated by the pytorch-lightning module. A quick fix would be to generate 60 copies of the input tensor, encode each copy with one of the 60 lead times, and pair them with an output tensor of shape t = 1 instead of t = 60. However, I see a potential issue with this solution: it is very memory-inefficient, since we would have 60 nearly identical input tensors.
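A minimal sketch of the "one random lead time per sample" idea, without copying the input. It assumes the targets for a sample can be indexed by lead time and that the model accepts the chosen lead time as a separate argument; sample_lead_time and make_training_batch are hypothetical helpers, not part of the metnet package. In a PyTorch setup, one natural place for this logic would be a Dataset's __getitem__.

```python
import random
from typing import Any, List, Sequence, Tuple


def sample_lead_time(targets: Sequence[Any], rng: random.Random) -> Tuple[int, Any]:
    """Pick one lead time uniformly at random; return (lead_time, target_frame)."""
    lead_time = rng.randrange(len(targets))
    return lead_time, targets[lead_time]


def make_training_batch(
    samples: Sequence[Tuple[Any, Sequence[Any]]], rng: random.Random
) -> List[Tuple[Any, int, Any]]:
    """Pair each input with a single random lead time and its target frame.

    The input object is reused as-is (not copied), and only one target frame
    per sample is kept, so backprop only has to store activations for one
    lead time per sample instead of all of them.
    """
    batch = []
    for inputs, targets in samples:
        lead_time, target = sample_lead_time(targets, rng)
        batch.append((inputs, lead_time, target))
    return batch


if __name__ == "__main__":
    rng = random.Random(0)
    # Toy data: 4 samples, each with 60 target "frames" represented by their index.
    samples = [(f"input_{i}", list(range(60))) for i in range(4)]
    for inputs, lead_time, target in make_training_batch(samples, rng):
        print(inputs, lead_time, target)
```

Over many epochs every lead time is still seen, just not all of them in the same backward pass.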

jacobbieker commented 2 years ago

Yeah, based on the discussion you had, I think changing the default forward pass to take the lead time as an input makes the most sense; it would reduce memory usage quite a lot. Sharing your training code would be great! I would love to see how that is working, and I'll try to get one written up this week too.

jacobbieker commented 2 years ago

I've started writing my training script, and changed MetNet and MetNet2 to only do one lead time per forward pass in #20.
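With one lead time per forward pass, the network has to be told which lead time it is predicting. One common way, and roughly the idea behind the ConditionTime layer that appears in the model summaries in this thread, is to append forecast_steps extra channels that are all zeros except the one for the requested lead time, which is filled with ones. Below is a minimal sketch on plain nested lists (channels × height × width) so it runs without PyTorch; condition_on_lead_time is a hypothetical name and this is not the repository's implementation.

```python
from typing import List

Grid = List[List[float]]  # one channel: height x width


def condition_on_lead_time(
    features: List[Grid], lead_time: int, forecast_steps: int
) -> List[Grid]:
    """Append forecast_steps one-hot planes encoding lead_time to features."""
    if not 0 <= lead_time < forecast_steps:
        raise ValueError(f"lead_time must be in [0, {forecast_steps}), got {lead_time}")
    height = len(features[0])
    width = len(features[0][0])
    # One plane per possible lead time; only the requested one is all ones.
    time_planes = [
        [[1.0 if step == lead_time else 0.0] * width for _ in range(height)]
        for step in range(forecast_steps)
    ]
    return features + time_planes


if __name__ == "__main__":
    features = [[[0.5, 0.5], [0.5, 0.5]]]  # 1 channel, 2x2 grid
    out = condition_on_lead_time(features, lead_time=2, forecast_steps=4)
    print(len(out))  # 1 original channel + 4 time planes
```

The same input features can then be reused for any lead time; only the small block of time planes changes.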

Hinode commented 2 years ago

Hi @jacobbieker, I just re-pulled the master branch and followed your "backward()" example to train the model. My model is MetNet(hidden_dim=32, forecast_steps=18, input_channels=7, output_channels=101, sat_channels=1, input_size=128).

When I start training and monitor the memory, I see that memory is leaking. After ~250 epochs (on a single NVIDIA P100 GPU with 16 GB memory) or ~100 epochs (on a single CPU with 130 GB memory), training crashes with a "segmentation fault". I tried to find out where the memory leak happens, but I'm not sure yet. Could you train the model and check the memory leak issue? Thanks.
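One way to narrow down a slow host-memory leak is to record memory after every epoch and check whether it keeps growing. A minimal sketch using Python's standard tracemalloc module follows. It only sees allocations made through Python's allocator, so it will not show CUDA memory (torch.cuda.memory_allocated() is the usual tool for that) and may miss memory allocated directly by native libraries. The epoch functions in the usage example are hypothetical stand-ins for a real training epoch, and the 1 MB growth threshold is an arbitrary choice.

```python
import tracemalloc
from typing import Callable, List


def track_epoch_memory(train_one_epoch: Callable[[], None], epochs: int) -> List[int]:
    """Run train_one_epoch `epochs` times; return traced bytes after each epoch."""
    tracemalloc.start()
    try:
        usage = []
        for _ in range(epochs):
            train_one_epoch()
            current, _peak = tracemalloc.get_traced_memory()
            usage.append(current)
        return usage
    finally:
        tracemalloc.stop()


def looks_like_leak(usage: List[int], min_growth_bytes: int = 1_000_000) -> bool:
    """Heuristic: memory grew after every epoch and by at least min_growth_bytes overall."""
    grew_every_epoch = all(b > a for a, b in zip(usage, usage[1:]))
    return grew_every_epoch and usage[-1] - usage[0] >= min_growth_bytes


if __name__ == "__main__":
    history = []

    def leaky_epoch() -> None:
        # Keeps a reference every epoch, like storing each loss tensor
        # (and its autograd graph) instead of loss.item().
        history.append(bytearray(500_000))

    def clean_epoch() -> None:
        _ = bytearray(500_000)  # released when the function returns

    print(looks_like_leak(track_epoch_memory(leaky_epoch, 5)))
    print(looks_like_leak(track_epoch_memory(clean_epoch, 5)))
```

tracemalloc.take_snapshot() and Snapshot.compare_to() can then point at the source lines whose allocations keep growing.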

Hinode commented 2 years ago

@jacobbieker BTW, my model summary is below. Does this make sense? Why are most of the parameters in the "DownSampler" layer?

=====================================
Layer (type:depth-idx)                 Param #
=====================================
MetNet                                 --
├─MetNetPreprocessor: 1-1              --
│    └─PixelUnshuffle: 2-1             --
│    └─CenterCrop: 2-2                 --
├─Dropout: 1-2                         --
├─TimeDistributed: 1-3                 --
│    └─DownSampler: 2-3                --
│    │    └─Sequential: 3-1            1,596,640
├─ConditionTime: 1-4                   --
├─TemporalEncoder: 1-5                 --
│    └─ConvGRU: 2-4                    --
│    │    └─ModuleList: 3-2            248,960
│    │    └─RNNDropout: 3-3            --
│    │    └─ModuleList: 3-4            --
├─Sequential: 1-6                      --
│    └─AxialAttention: 2-5             --
│    │    └─ModuleList: 3-5            8,256
├─Conv2d: 1-7                          3,333
=====================================
Total params: 1,857,189
Trainable params: 1,857,189
Non-trainable params: 0

ValterFallenius commented 2 years ago

Hi Hinode, you can play around with the hyperparameters of the network in the MetNet model file. Right now you are using hyperparameters much smaller than those in the original paper.

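The downsampler in the paper has 286 input channels and 3x3 convolutions with 160 → 256 → 256 → 256 kernels. Counting (9·C_in + 1)·C_out parameters per convolution (3x3 weights plus one bias per output channel), this gives n_parameters = (9·286 + 1)·160 + (9·160 + 1)·256 + 2·(9·256 + 1)·256 = 1,961,056.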

I am not certain how to calculate the parameters for the ConvLSTM (ConvGRU in our case) and the axial attention blocks, but the default configuration of this model is significantly smaller than the original paper's. When I initialize it with their hyperparameters, I get something very far from 225M parameters:

  | Name          | Type            | Params
-----------------------------------------------
0 | image_encoder | TimeDistributed | 2.0 M
1 | ct            | ConditionTime   | 0
2 | temporal_enc  | TemporalEncoder | 6.6 M
3 | temporal_agg  | Sequential      | 9.4 M
4 | head          | Conv2d          | 197 K
-----------------------------------------------
18.2 M    Trainable params
0         Non-trainable params
18.2 M    Total params
72.960    Total estimated model params size (MB)
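The convolution and ConvGRU counts can be estimated with a few helper functions. This is a sketch under stated assumptions, not the repository's code: every convolution is a plain Conv2d with a bias; the downsampler is four 3x3 convolutions (286 → 160 → 256 → 256 → 256) with no normalization layers; the ConvGRU cell follows the textbook formulation, where the update gate, reset gate and candidate state each use one k×k convolution over the concatenated [input, hidden] channels; and the head maps 384 hidden channels to an assumed 512 output bins. A k×k convolution has (k²·C_in + 1)·C_out parameters, since the bias adds one parameter per output channel. As a quick check, the Conv2d row of 3,333 in Hinode's summary is (32 + 1)·101.

```python
from typing import Sequence


def conv2d_params(c_in: int, c_out: int, kernel_size: int, bias: bool = True) -> int:
    """One Conv2d layer: k*k*c_in weights per output channel, plus one bias each."""
    return (kernel_size * kernel_size * c_in + (1 if bias else 0)) * c_out


def downsampler_params(
    c_in: int, widths: Sequence[int] = (160, 256, 256, 256), kernel_size: int = 3
) -> int:
    """Chain of convolutions c_in -> widths[0] -> ... (no normalization assumed)."""
    total = 0
    for c_out in widths:
        total += conv2d_params(c_in, c_out, kernel_size)
        c_in = c_out
    return total


def convgru_params(c_in: int, hidden: int, kernel_size: int = 3) -> int:
    """Textbook ConvGRU cell: update gate, reset gate and candidate state,
    each one k x k convolution over the concatenated [input, hidden] channels."""
    return 3 * conv2d_params(c_in + hidden, hidden, kernel_size)


if __name__ == "__main__":
    print(f"downsampler: {downsampler_params(286):,}")      # 286 input channels
    print(f"ConvGRU:     {convgru_params(256, 384):,}")     # 256 in, 384 hidden (assumed)
    print(f"1x1 head:    {conv2d_params(384, 512, 1):,}")   # 384 -> 512 bins (assumed)
```

Under these assumptions the results (1,961,056, 6,636,672 and 197,120) land close to the image_encoder, temporal_enc and head rows of the summary above, but the agreement is only suggestive: if condition-time channels are concatenated to the ConvGRU input, its count grows accordingly.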

Hinode commented 2 years ago

Hi @ValterFallenius, interesting. Based on your analysis, the model looks quite different from the original paper.

Hinode commented 2 years ago

@ValterFallenius regarding your initial error message,

"RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 39.59 GiB total capacity; 36.22 GiB already allocated; 6.19 MiB free; 37.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF". When I reduce the lead times this error disappears (this is the bottleneck)."

The reason is that someone else was also using your GPU while you were testing it, hence the 37.53 GiB reserved. Once the GPU is released, the error message goes away.

ValterFallenius commented 2 years ago

@Hinode I don't think someone else was using the GPU; rather, backprop kept gradients for a batch size that was too big, effectively 60 as I mentioned.

I just realised another thing: assuming the authors' figure of 225M parameters is correct, something is obviously missing in this model. I was wondering if perhaps self.head is a fully connected layer instead of the 1x1 CNN? What do you think @jacobbieker @Hinode?

EDIT: Fully connected would be way too big...
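Rough arithmetic shows why. Using the shapes from the first post as an illustrative assumption (a 28 × 28 output grid with 51 output channels) and a hidden size of 384, a 1x1 convolution shares one small weight matrix across all pixels, while a fully connected head connects every input value to every output value:

```python
def conv1x1_params(hidden: int, out_channels: int) -> int:
    """1x1 convolution: one (hidden -> out_channels) matrix plus bias, shared by all pixels."""
    return (hidden + 1) * out_channels


def dense_head_params(height: int, width: int, hidden: int, out_channels: int) -> int:
    """Fully connected layer from the flattened feature map to the flattened output."""
    n_in = height * width * hidden
    n_out = height * width * out_channels
    return (n_in + 1) * n_out


if __name__ == "__main__":
    print(f"1x1 conv head: {conv1x1_params(384, 51):,}")             # 19,635
    print(f"dense head:    {dense_head_params(28, 28, 384, 51):,}")  # 12,037,463,088
```

That is about 12 billion parameters for the head alone, more than 50 times the paper's 225M total, while the 1x1 convolution needs under 20 thousand.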

ValterFallenius commented 2 years ago

I reached out to the author, Casper, again. He was very helpful and answered the following questions (my questions are numbered, and his answer follows each one):

"Dear Valter,

1. I have implemented a model very similar to yours with help from Open Climate Fix at https://github.com/ValterFallenius/metnet/tree/lightning. I have tried to completely replicate your model by using the specified parameters; however, I am not able to reproduce the same number of parameters (225M) as your model. I am wondering what I'm missing in my implementation. Here are my parameters:

I'm not sure what the exact source of the difference is. However, the most plausible explanation is a difference in the transformer, as the vast majority of the parameters reside there. Our heads were small MLPs, as far as I remember (I'm not at Google anymore, so I don't have access to the source code). Anyway, I don't think the exact parameter count should make a huge difference in performance; you should be able to get to 95% of the performance with a model an order of magnitude smaller.

2. Did you do anything to counter class imbalance? In my data, the 0-1 mm/h rain bins account for more than 99.9% of all data points. Did you have the same property in your training data?

Our data was also severely skewed. During training, we discarded samples in which the number of pixels with precipitation was below a certain threshold, like 5 for example.

3. How did you normalize your data? When you normalize to zero mean and unit variance, do you do it sample by sample (first option) or over the entire dataset (second option)?

The first option will not work. The second one should be fine, although using some percentile range instead of the std is more robust to outliers. However, we simply scaled our inputs with fixed constants to bring them into reasonable ranges. Anything beyond that harmed performance in the end.

4. You can ignore this last question if you don't have time for it. My network doesn't seem to train properly, but I am having a hard time finding the bug. I am trying to run it as you suggested, on a small subset of the data with only a single lead time and far fewer hidden layers, but it doesn't train well. The PyTorch model runs without reporting any errors, but the network still won't reduce the error even when run repeatedly on the same training sample. If you can spare the time, full details are available at https://github.com/openclimatefix/metnet/issues/22.

If the plot is showing log loss, the number seems too high: e.g. your probability mass on the correct class is on average exp(-1.77) ≈ 0.17, and just predicting no rain all the time should give you something much better. However, note that log-loss changes after the first few hundred updates are usually super small. Most likely you have a sign error in the loss, or you are not updating the params even though the grads are not zero? I can't really help beyond that.

Good luck ironing out the bugs :)

Casper"
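Casper's answers to questions 2 to 4 can be sketched as small helpers. These are illustrations under stated assumptions, not his code: a sample is a flat list of rain rates in mm/h; a pixel counts as rainy if its rate is above a hypothetical rain_threshold of 0.0; 5 rainy pixels is his example threshold; and the 1st/99th percentile bounds are illustrative. The scaling range is fitted once over the whole training set, never per sample (he notes that in the end they simply scaled by fixed constants, which amounts to hard-coding the range). The last helper gives the log loss of a constant predictor that always outputs the class frequencies, a baseline a training curve should eventually get below.

```python
import math
from typing import List, Sequence, Tuple


def keep_sample(rain_rates: Sequence[float], min_rainy_pixels: int = 5,
                rain_threshold: float = 0.0) -> bool:
    """Keep a training sample only if enough pixels show precipitation."""
    rainy = sum(1 for r in rain_rates if r > rain_threshold)
    return rainy >= min_rainy_pixels


def percentile(sorted_values: Sequence[float], q: float) -> float:
    """Linearly interpolated percentile (q in [0, 100]) of already-sorted values."""
    pos = (len(sorted_values) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    frac = pos - lo
    return sorted_values[lo] * (1 - frac) + sorted_values[hi] * frac


def fit_percentile_scaler(dataset: Sequence[Sequence[float]],
                          lo_q: float = 1.0, hi_q: float = 99.0) -> Tuple[float, float]:
    """Compute one (low, high) range over the whole training set, not per sample."""
    values = sorted(v for sample in dataset for v in sample)
    return percentile(values, lo_q), percentile(values, hi_q)


def scale(sample: Sequence[float], low: float, high: float) -> List[float]:
    """Map the dataset-wide [low, high] range to [0, 1]."""
    span = (high - low) or 1.0  # guard against a constant dataset
    return [(v - low) / span for v in sample]


def constant_predictor_log_loss(class_freqs: Sequence[float]) -> float:
    """Mean log loss (nats) of always predicting the class frequencies themselves."""
    return -sum(p * math.log(p) for p in class_freqs if p > 0)
```

For a 99.9% / 0.1% split, constant_predictor_log_loss([0.999, 0.001]) is about 0.008 nats, so a loss that stays well above that level means the model is doing worse than simply predicting the class frequencies.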

Hinode commented 2 years ago

@Hinode I don't think someone else was using the GPU; rather, backprop kept gradients for a batch size that was too big, effectively 60 as I mentioned.

I just realised another thing: assuming the authors' figure of 225M parameters is correct, something is obviously missing in this model. I was wondering if perhaps self.head is a fully connected layer instead of the 1x1 CNN? What do you think @jacobbieker @Hinode?

EDIT: Fully connected would be way too big...

This is exactly the question I want to answer. I need to dive deeper into the model structure.

Hinode commented 2 years ago


Thanks for sharing Casper's reply. Does this mean work on MetNet will end, since he has left Google?

Hinode commented 2 years ago

@Hinode I don't think someone else was using the GPU; rather, backprop kept gradients for a batch size that was too big, effectively 60 as I mentioned.

I just realised another thing: assuming the authors' figure of 225M parameters is correct, something is obviously missing in this model. I was wondering if perhaps self.head is a fully connected layer instead of the 1x1 CNN? What do you think @jacobbieker @Hinode?

EDIT: Fully connected would be way too big...

If you look at the MetNet init:

class MetNet(torch.nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        image_encoder: str = "downsampler",
        input_channels: int = 12,
        sat_channels: int = 12,
        input_size: int = 256,
        output_channels: int = 12,
        hidden_dim: int = 64,
        kernel_size: int = 3,
        num_layers: int = 1,
        num_att_layers: int = 1,
        forecast_steps: int = 48,
        temporal_dropout: float = 0.2,
        **kwargs,
    ):
        ...

num_att_layers defaults to 1, but the original paper uses 8 axial attention layers. So, @jacobbieker, how should users customize these values? Any instructions? Thanks.

ValterFallenius commented 2 years ago

If you look at the MetNet init:

class MetNet(torch.nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        image_encoder: str = "downsampler",
        input_channels: int = 12,
        sat_channels: int = 12,
        input_size: int = 256,
        output_channels: int = 12,
        hidden_dim: int = 64,
        kernel_size: int = 3,
        num_layers: int = 1,
        num_att_layers: int = 1,
        forecast_steps: int = 48,
        temporal_dropout: float = 0.2,
        **kwargs,
    ):
        ...

num_att_layers defaults to 1, but the original paper uses 8 axial attention layers. So, @jacobbieker, how should users customize these values? Any instructions? Thanks.

They use 8 layers like you say, but they also have a different implementation. Casper's reply mentions that each head in the attention layers was a small MLP. I don't know the exact details, but the number of heads is also 16 in the original paper, not 8 like the default here. To run it with the original settings, set num_att_layers=8 and heads=16 in the metnet.py file.

However, I still think this model differs a bit from the original paper, since our parameter counts don't add up despite using their hyperparameters. But again, this should not be an issue...

Hinode commented 2 years ago

I agree that this is not an issue. But how do users set up 8 axial attention layers (4 in the x-direction and 4 in the y-direction, separate from one another)?

jacobbieker commented 2 years ago

I agree that this is not an issue. But how do users set up 8 axial attention layers (4 in the x-direction and 4 in the y-direction, separate from one another)?

That would be with

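# Stack num_att_layers axial-attention blocks over the (batch, channels, height, width)
# feature map; dim_index=1 tells AxialAttention that axis 1 holds the channels.
self.temporal_agg = nn.Sequential(
    *[
        AxialAttention(dim=hidden_dim, dim_index=1, heads=4, num_dimensions=2)
        for _ in range(num_att_layers)
    ]
)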

in metnet.py. The axial attention implementation I use here is from https://github.com/lucidrains/axial-attention; see that repository for more details on how it works.
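For reference, a sketch of how the parameter count of that stack scales, assuming lucidrains' implementation with its default dim_heads = dim // heads: each self-attention has bias-free query and key/value projections (dim → dim and dim → 2·dim) and an output projection dim → dim with bias, and AxialAttention with num_dimensions=2 holds one such attention per spatial axis. Under that reading, one AxialAttention module already covers both the x and y directions, so num_att_layers=4 gives 4 attentions along each axis (by default the per-axis outputs are summed, sum_axial_out=True, rather than applied one after the other), and the number of heads does not change the parameter count. The formula reproduces the 8,256 in Hinode's summary (dim 32, one layer) and the 4.7M and 9.4M figures above (dim 384, 4 and 8 layers), so switching heads from 8 to 16 would not by itself close the gap to 225M. The function names below are hypothetical.

```python
from typing import Optional


def self_attention_params(dim: int, heads: int, dim_heads: Optional[int] = None) -> int:
    """Assumed layout: to_q (no bias), to_kv (no bias), to_out (with bias)."""
    dim_heads = dim // heads if dim_heads is None else dim_heads
    dim_hidden = dim_heads * heads
    to_q = dim * dim_hidden
    to_kv = dim * 2 * dim_hidden
    to_out = dim_hidden * dim + dim
    return to_q + to_kv + to_out


def axial_attention_params(dim: int, heads: int, num_dimensions: int = 2) -> int:
    """One self-attention per spatial axis inside a single AxialAttention module."""
    return num_dimensions * self_attention_params(dim, heads)


def temporal_agg_params(dim: int, heads: int, num_att_layers: int) -> int:
    """nn.Sequential of num_att_layers AxialAttention modules."""
    return num_att_layers * axial_attention_params(dim, heads)


if __name__ == "__main__":
    for heads in (8, 16):
        print(heads, f"{temporal_agg_params(384, heads, 8):,}")  # same count for both
```

With dim_heads left at its default, heads only changes how the hidden width is split, not its size; passing a larger dim_heads explicitly is what would add parameters.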