benedekrozemberczki / pytorch_geometric_temporal

PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models (CIKM 2021)
MIT License

Extreme RAM consumption by multilayer models #159

Closed · funbotan closed this issue 2 years ago

funbotan commented 2 years ago

Hi Benedek! First of all, thank you for this project; I hope it has a bright future ahead. My issue is with building deep (multi-layer) models out of recurrent graph convolutional layers. There are no examples of this in the repo (or anywhere else on the internet, as far as I have searched), so I might be doing something wrong. Here is my project. What I observe when running this code is that RAM utilization grows nearly exponentially with the number of RGC layers. Of course, most of it ends up in swap, making training too slow to be viable. This does not appear to be a memory leak, since most of the memory is held not by Python objects but internally by PyTorch. Have you encountered this issue, and is there a way to fix it?
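For reference, here is a stripped-down sketch of what I'm doing (the full code is in the linked project; layer sizes and the linear readout here are placeholders). Each `GConvGRU` layer keeps its own hidden state:

```python
import torch
from torch_geometric_temporal.nn.recurrent import GConvGRU

class StackedRGCN(torch.nn.Module):
    """A stack of recurrent graph conv layers, each with its own hidden state."""
    def __init__(self, in_channels, hidden_channels, num_layers, K=2):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [GConvGRU(in_channels if i == 0 else hidden_channels,
                      hidden_channels, K)
             for i in range(num_layers)]
        )
        self.readout = torch.nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, edge_weight, hidden):
        # hidden: a list with one state tensor (or None) per layer
        out = x
        for i, layer in enumerate(self.layers):
            hidden[i] = layer(out, edge_index, edge_weight, H=hidden[i])
            out = hidden[i]
        return self.readout(out), hidden
```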

SherylHYX commented 2 years ago

Based on your code, it seems you are training in a "cumulative" manner, which keeps the gradient-related objects for every snapshot in memory until the single backward pass at the end. To alleviate this, one option is to switch to the "incremental" manner, which does backpropagation after each snapshot during training.
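In pseudocode, the difference looks like this (a sketch; `model`, `optimizer`, and `train_dataset` stand in for your own objects, with `train_dataset` being an iterable of snapshots as in this library's examples, and `model` mapping a snapshot directly to predictions for brevity):

```python
import torch

# Cumulative: accumulate the loss over every snapshot, then one backward pass.
# The autograd graph for *all* snapshots stays in memory until .backward().
cost = 0
for time, snapshot in enumerate(train_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = cost + torch.mean((y_hat - snapshot.y) ** 2)
cost = cost / (time + 1)
cost.backward()
optimizer.step()
optimizer.zero_grad()

# Incremental: backward pass and optimizer step per snapshot.
# Each snapshot's graph is freed right after its backward pass,
# so memory stays flat in the number of snapshots.
for snapshot in train_dataset:
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = torch.mean((y_hat - snapshot.y) ** 2)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()
```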

SherylHYX commented 2 years ago

As another option, during evaluation you can wrap the loop in `with torch.no_grad()`. See for example here.
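Something like this (a sketch, with the same placeholder names as above):

```python
import torch

# No autograd graph is built inside no_grad, so evaluation memory stays flat.
model.eval()
with torch.no_grad():
    cost = 0
    for time, snapshot in enumerate(test_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = cost + torch.mean((y_hat - snapshot.y) ** 2)
    cost = cost / (time + 1)
print("MSE: {:.4f}".format(cost.item()))
```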

funbotan commented 2 years ago

> Based on your code, it seems you are training in a "cumulative" manner, which keeps the gradient-related objects for every snapshot in memory until the single backward pass at the end. To alleviate this, one option is to switch to the "incremental" manner, which does backpropagation after each snapshot during training.

Thank you, I see how that would help. However, won't that also discard the temporal data correlations?

SherylHYX commented 2 years ago

> Thank you, I see how that would help. However, won't that also discard the temporal data correlations?

You can see from our paper that for some inputs, e.g. Wikipedia Math, the incremental backprop regime actually leads to better performance than the cumulative one. This is similar to using mini-batches for SGD instead of the full batch.

funbotan commented 2 years ago

Well, this did solve the memory consumption problem. Training time has only marginally improved, though, because of the frequent backward passes; I'll try to find a compromise later (rough idea sketched below). I can't really compare predictive performance, because the previous configuration never even finished training. Do you think this problem is inherent to GCNs, or is it just an implementation issue?
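The compromise I have in mind is something like truncated backpropagation through time: do a backward pass every k snapshots and detach the hidden states in between, so the autograd graph never spans the whole sequence. A rough, untested sketch, reusing the stacked model from my first comment (`K_SNAPSHOTS` is just a knob to tune against memory and speed):

```python
import torch

K_SNAPSHOTS = 16                     # hypothetical window size; tune it
hidden = [None] * len(model.layers)  # one state per recurrent layer
cost = 0

for time, snapshot in enumerate(train_dataset):
    y_hat, hidden = model(snapshot.x, snapshot.edge_index,
                          snapshot.edge_attr, hidden)
    cost = cost + torch.mean((y_hat - snapshot.y) ** 2)
    if (time + 1) % K_SNAPSHOTS == 0:
        (cost / K_SNAPSHOTS).backward()
        optimizer.step()
        optimizer.zero_grad()
        # Detach so the next window's graph cannot reach back into this one.
        hidden = [h.detach() for h in hidden]
        cost = 0
```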

SherylHYX commented 2 years ago

Good that it helps! This is related to how you do backpropagation across temporal snapshots; it is not a GCN issue.



funbotan commented 2 years ago

Alright, I'll close the issue for the time being, but I very much hope that someone comes up with a better solution down the line. Thank you very much, @SherylHYX!