timeseriesAI / tsai

State-of-the-art Deep Learning library for Time Series and Sequences in PyTorch / fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0
5.07k stars, 633 forks

learn.get_X_preds RAM Memory Spike #692

Open lesego94 opened 1 year ago

lesego94 commented 1 year ago

My 15 GB GPU crashes when making predictions. After some further reading (link below), I found that get_X_preds very inefficiently reloads gigabytes of data to make predictions, whereas this could be done one at a time at a fraction of the memory cost. This may already have been fixed in fastai, but the fix hasn't made it into tsai yet. Please read the linked forum thread for details.

https://forums.fast.ai/t/learn-get-preds-memory-inefficiency-quick-fix/84029

Does anyone know how to get around this issue, or perhaps how I can run the predictions in batches?
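If the goal is simply to keep memory bounded, one option is to skip get_X_preds and loop over the data yourself. The sketch below assumes you can reach the underlying PyTorch module (fastai exposes it as learn.model) and that X is a NumPy array already preprocessed the way the model expects. Note that it bypasses any transforms the Learner's dataloaders would normally apply. predict_batched is a hypothetical helper, not part of tsai or fastai.

```python
import numpy as np
import torch


def predict_batched(model: torch.nn.Module, X: np.ndarray,
                    bs: int = 256, device: str = "cpu") -> np.ndarray:
    """Run `model` over `X` in mini-batches and return one NumPy array.

    Hypothetical helper: only one batch of inputs lives on `device` at a
    time, and each batch's outputs are copied to CPU as NumPy before the
    next batch is processed.
    """
    model = model.to(device).eval()
    outputs = []
    # no_grad: no autograd graph is kept, so activations are freed per batch
    with torch.no_grad():
        for start in range(0, len(X), bs):
            xb = torch.as_tensor(X[start:start + bs], dtype=torch.float32,
                                 device=device)
            outputs.append(model(xb).cpu().numpy())
    return np.concatenate(outputs, axis=0)
```

For example, `preds = predict_batched(learn.model, X[splits[1]], bs=256, device='cuda')`. Keeping only NumPy copies of each batch's outputs means the peak should be roughly one batch of activations plus the final predictions array.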

oguiza commented 1 year ago

Hello @lesego94, I have fixed another issue (#695) that may be responsible for the GPU memory spike. It would be great if you could install tsai from GitHub using:

pip install git+https://github.com/timeseriesAI/tsai.git

and check whether the problem still persists.

lesego94 commented 1 year ago

Hi Oguiza, I appreciate you looking into this. Unfortunately, it did not fix the problem. I also realized I made a mistake earlier: the memory spike is occurring in my CPU RAM, not on the GPU.
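To confirm where the spike happens, it can help to record the process's peak resident memory around the prediction call. A minimal sketch using Python's standard resource module (Unix only; peak_rss_mb is a hypothetical helper name):

```python
import resource
import sys


def peak_rss_mb() -> float:
    """Peak resident set size of the current process so far, in MB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor
```

Calling peak_rss_mb() before and after learn.get_X_preds(...) shows how much the call raised the process's high-water mark. Because ru_maxrss is a peak, the difference is 0 if the call never exceeded an earlier peak.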


Let me give you some information about what I'm running. I'm running the PatchTST notebook 15_PatchTST_a_new_transformer_for_LTSF.ipynb on my own dataset, and the model has 288,670 total parameters. The spike only occurs when I run:

learn = load_learner('models/patchTST.pt')
scaled_preds, *_ = learn.get_X_preds(X[splits[1]])
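Another workaround that stays within tsai is to call get_X_preds on slices of the data and stitch the results together, so whatever temporary copies it makes should be bounded by one chunk rather than the whole test set. A sketch, assuming (as in the snippet above) that the first element returned by get_X_preds is the predictions, and that it is a NumPy array or a CPU tensor; predict_in_chunks and chunk_size are hypothetical names:

```python
from typing import Callable

import numpy as np


def predict_in_chunks(predict_fn: Callable[[np.ndarray], np.ndarray],
                      X: np.ndarray, chunk_size: int = 1024) -> np.ndarray:
    """Call `predict_fn` on consecutive slices of `X` and stack the results.

    Hypothetical helper: each call only sees `chunk_size` samples, so any
    temporary copies made inside `predict_fn` are bounded by the chunk.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    parts = []
    for start in range(0, len(X), chunk_size):
        # np.asarray accepts NumPy arrays and CPU torch tensors without grad
        parts.append(np.asarray(predict_fn(X[start:start + chunk_size])))
    return np.concatenate(parts, axis=0)
```

For example, `scaled_preds = predict_in_chunks(lambda xb: learn.get_X_preds(xb)[0], X[splits[1]], chunk_size=1000)`. The final array still holds every prediction, so this only helps if the spike comes from intermediate copies rather than from the output itself.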


oguiza commented 1 year ago

OK, I understand. Have you pip-installed tsai from the GitHub repo? Here are a couple of things to test:

lesego94 commented 1 year ago

Thanks Oguiza, I will try your suggestions and get back to you.