jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Cannot load model #702

Open mijosch opened 2 years ago

mijosch commented 2 years ago

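For saving:

```python
import torch
from pytorch_lightning.callbacks import Callback


class MyModelSave(Callback):
    # Overwrites model.mdl with the module's weights at the end of every training epoch
    def on_train_epoch_end(self, trainer, pl_module):
        torch.save(pl_module.state_dict(), "model.mdl")
```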

For loading:

```python
import torch

# `tft` is the TemporalFusionTransformer instance, already constructed
checkpoint = torch.load("model.mdl")
print("loaded")
tft.load_state_dict(checkpoint)
print("tft from statedict")
```

The model never gets loaded (I waited 4 hours). The state dict is larger than 250 MB.

How do you save and load the TemporalFusionTransformer? This does not work.
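For reference, the state-dict round trip from the post above can be exercised in isolation. This is a minimal sketch, assuming a small `nn.Sequential` built by a hypothetical `build_model()` in place of the TemporalFusionTransformer and an in-memory buffer in place of `model.mdl`. It shows that `load_state_dict` only copies weights into a model that has already been constructed with the same architecture; it does not reproduce the hang.

```python
import io

import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for the TFT; any nn.Module behaves the same way here.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))


torch.manual_seed(0)
trained = build_model()

# Save only the parameters and buffers, as the callback above does
# (an in-memory buffer stands in for "model.mdl").
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# load_state_dict copies values into an existing model, so an instance with
# the same architecture must be built first. map_location="cpu" puts the
# tensors on the CPU regardless of the device they were saved from.
state_dict = torch.load(buffer, map_location="cpu")
restored = build_model()
result = restored.load_state_dict(state_dict)  # strict=True by default

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(trained(x), restored(x))
```

With `strict=True`, any missing or unexpected key raises an error, so an empty `result` means every saved tensor found its matching parameter.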

adamwuyu commented 2 years ago

Hi @mijosch, I ran into the same problem. Have you figured out how to save and load a TFT model?

nicocheh commented 2 years ago

Same here. Is there any news on this?

ozanozyegen commented 2 years ago

I use the `load_from_checkpoint` function, as shown in the TFT tutorial:

```python
# best_model_path: path to a checkpoint written by the Lightning Trainer
best_tft = tft.load_from_checkpoint(best_model_path)
```

There are alternative ways to do it, described in the Lightning docs.
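Unlike a bare state dict, a Lightning checkpoint also stores the hyperparameters recorded by `save_hyperparameters()`, so `load_from_checkpoint` can construct the model itself before loading the weights. The sketch below imitates that idea in plain PyTorch; `TinyModel`, `save_checkpoint` and `load_from_checkpoint` here are hypothetical helpers, not Lightning's actual implementation.

```python
import io

import torch
from torch import nn


class TinyModel(nn.Module):
    # Hypothetical model whose layer sizes depend on a constructor argument.
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.hparams = {"hidden_size": hidden_size}
        self.net = nn.Sequential(
            nn.Linear(4, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        return self.net(x)


def save_checkpoint(model: TinyModel, f) -> None:
    # Store constructor arguments next to the weights (key names chosen to
    # mirror Lightning's "hyper_parameters" and "state_dict" entries).
    torch.save({"hyper_parameters": model.hparams, "state_dict": model.state_dict()}, f)


def load_from_checkpoint(f) -> TinyModel:
    checkpoint = torch.load(f, map_location="cpu")
    model = TinyModel(**checkpoint["hyper_parameters"])  # rebuild the architecture
    model.load_state_dict(checkpoint["state_dict"])      # then copy in the weights
    return model


trained = TinyModel(hidden_size=16)
buffer = io.BytesIO()
save_checkpoint(trained, buffer)
buffer.seek(0)
restored = load_from_checkpoint(buffer)

x = torch.randn(2, 4)
with torch.no_grad():
    same_output = torch.equal(trained(x), restored(x))
```

Loading the same weights into `TinyModel()` with the default `hidden_size=8` would raise a size-mismatch error, which is why the hyperparameters travel with the checkpoint.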