gsamaras closed this issue 2 years ago.
Hey @gsamaras, have you tried enabling checkpointing once and then using NBEATSModel.load_from_checkpoint()?
We switched to PyTorch Lightning a couple of versions ago, which changed the save/load logic.
If this doesn't work and you can't upgrade your Darts version, then you could theoretically change the TorchForecastingModel.load_model() function signature in the source code (keeping in mind all the downsides of doing this) to pass the device mapping to the internal torch.load() call.
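A minimal sketch of that idea that avoids editing the Darts sources: `load_model_on_cpu` below is a hypothetical helper, not a Darts API. It assumes the saved file is an object written with `torch.save()` (which is how I understand `save_model()` to work in Darts 0.16.x, unverified) and forwards a device mapping to `torch.load()`.

```python
import torch


def load_model_on_cpu(path: str, device: str = "cpu", **load_kwargs):
    """Hypothetical helper (not part of Darts): unpickle an object saved with
    torch.save() and remap all of its tensors onto `device`.

    Assumption: `path` was written by torch.save(), as Darts 0.16.x's
    TorchForecastingModel.save_model() is assumed to do.
    Extra keyword arguments are forwarded to torch.load(); for example, on
    PyTorch >= 2.6 unpickling a whole model object needs weights_only=False
    (a keyword that does not exist before PyTorch 1.13).
    """
    with open(path, "rb") as fin:
        # map_location remaps storages that were saved on CUDA, which avoids
        # "Attempting to deserialize object on a CUDA device" on CPU-only hosts.
        return torch.load(fin, map_location=torch.device(device), **load_kwargs)
```

On the CPU machine this would be called as `model = load_model_on_cpu("path/to/saved_model.pth.tar")` (path hypothetical). Whether `predict()` then actually runs on the CPU may also depend on device settings stored inside the unpickled model, which this sketch does not change.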
Hi @dennisbader, so you mean that, with Darts 0.16.1, I should use NBEATSModel.save_checkpoints() on the GPU machine, and then NBEATSModel.load_from_checkpoint() on the CPU machine?
Yes, currently upgrading is not an option for internal, project-related reasons. :/
You should activate checkpointing when creating the model on the GPU machine (see the model docs):
model = NBEATSModel(..., model_name="my_model", work_dir=my_work_dir, save_checkpoints=True)
Then, on the CPU machine, load from the checkpoint (see the docs):
model = NBEATSModel.load_from_checkpoint(model_name="my_model", work_dir=my_work_dir, best=False)
(though I can't say for certain that it will work)
It didn't. On load I got:
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
If you have other ideas, please let me know. Otherwise, when time permits, I will upgrade Darts and revisit this.
Then you would have to adapt the code:
Thank you. I'll see whether I can modify Darts, or whether I'll be able to upgrade it later.
I trained an N-BEATS model on a GPU, and now I want to use it for predictions on a CPU.
Then on a CPU machine I do:
and get this error:
I tried to follow this Stack Overflow question, like:
but this method doesn't accept any such parameter.
If I use torch.load() instead, it works, but then the predict method doesn't work:
because it's a Torch model and not a Darts one.
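If `torch.load()` only gives back plain PyTorch weights, one generic PyTorch pattern (not a Darts feature) is to load them onto the CPU and copy them into a freshly built module of the same architecture. The sketch uses a stand-in `nn.Linear` and a hypothetical helper name `load_weights_on_cpu`; mapping a Lightning checkpoint's keys onto the network inside a Darts 0.16.1 model is not covered here and would have to be checked against the Darts source.

```python
import torch
from torch import nn


def load_weights_on_cpu(module: nn.Module, path: str) -> nn.Module:
    """Hypothetical helper: copy weights saved (possibly on a GPU) into
    `module`, with every tensor mapped onto the CPU."""
    checkpoint = torch.load(path, map_location=torch.device("cpu"))
    # PyTorch Lightning checkpoints keep the weights under "state_dict";
    # a bare state_dict is used as-is.
    state_dict = checkpoint.get("state_dict", checkpoint)
    # load_state_dict() is strict by default, so key names must match the
    # target module exactly (Lightning keys reflect the LightningModule's
    # attribute names and may need a prefix stripped).
    module.load_state_dict(state_dict)
    # Switch to inference mode; eval() returns the module itself.
    return module.eval()
```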
Can you please help?
I cannot use the Trainer parameter in predict(), because I am using Darts 0.16.1 (which I cannot change at this time), whose predict() has this signature: