Closed: TPF2017 closed this issue 1 month ago
I solved this by:

```python
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint["state_dict"])
```

Is there a better way?
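As background for the manual approach above: a PyTorch Lightning checkpoint loaded with `torch.load` is a plain dict whose weights sit under `"state_dict"`, usually with the submodule's attribute name as a key prefix, which is one reason `load_state_dict` on the bare module can fail. A minimal stdlib-only sketch of that layout and of stripping the prefix — the dict below is a hypothetical stand-in, not a real Moirai checkpoint:

```python
# Hypothetical stand-in for the dict torch.load returns for a
# Lightning .ckpt: weights under "state_dict", keys prefixed with
# the name of the wrapped submodule (here "module.").
checkpoint = {
    "state_dict": {
        "module.encoder.weight": [0.1, 0.2],
        "module.encoder.bias": [0.0],
    },
    "hyper_parameters": {"lr": 1e-3},
}

# If the target module expects un-prefixed keys, strip the prefix first.
prefix = "module."
module_state = {
    key[len(prefix):]: value
    for key, value in checkpoint["state_dict"].items()
    if key.startswith(prefix)
}
print(sorted(module_state))  # ['encoder.bias', 'encoder.weight']
```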
You should use MoiraiForecast.load_from_checkpoint instead. See https://github.com/SalesforceAIResearch/uni2ts/blob/main/cli/conf/finetune/model/moirai_small.yaml
edit: load_from_pretrained -> load_from_checkpoint
Thanks for your reply! I modified my code like this:

```python
model_path = "./example_sales/checkpoints/epoch=0-step=100.ckpt"
model = MoiraiForecast(
    module=MoiraiFinetune.load_from_checkpoint(model_path),
    prediction_length=PDT,
    context_length=CTX,
    patch_size=PSZ,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=ds.num_feat_dynamic_real,
    past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
)
predictor = model.create_predictor(batch_size=BSZ)
forecasts = predictor.predict(test_data.input)
forecast_it = iter(forecasts)
forecast = next(forecast_it)
```
but I run into a new error when forecasting:
It should look something like the snippet below; for module_kwargs, fill it in with the same arguments used in the pre-training phase.

```python
model = MoiraiForecast.load_from_checkpoint(
    module_kwargs={...},
    prediction_length=PDT,
    context_length=CTX,
    patch_size=PSZ,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=ds.num_feat_dynamic_real,
    past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
    checkpoint_path="./example_sales/checkpoints/epoch=0-step=100.ckpt",
)
```
I tried this way and it works, thank you very much for your answer!
MoiraiModule.from_pretrained doesn't seem to work, though:

![image](https://github.com/SalesforceAIResearch/uni2ts/assets/26497745/00888a59-f89f-4471-987b-ea74c72d3aad)
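A possible reason, stated as an assumption rather than a confirmed diagnosis: the two loaders expect different artifacts. `load_from_checkpoint` takes the Lightning `.ckpt` file written by the trainer, while `from_pretrained` is a Hugging Face Hub-style method that expects a model id or a saved model directory, so pointing it at a local fine-tuning checkpoint would fail. A minimal sketch of that distinction, with `pick_loader` as a purely hypothetical helper (not part of uni2ts):

```python
def pick_loader(path: str) -> str:
    """Hypothetical helper: Lightning trainer output is a single .ckpt
    file and needs load_from_checkpoint; from_pretrained expects a
    Hub model id or a saved model directory instead."""
    if path.endswith(".ckpt"):
        return "load_from_checkpoint"
    return "from_pretrained"

print(pick_loader("./example_sales/checkpoints/epoch=0-step=100.ckpt"))
# -> load_from_checkpoint
```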