SalesforceAIResearch / uni2ts

[ICML2024] Unified Training of Universal Time Series Forecasting Transformers
Apache License 2.0
611 stars · 49 forks

How to load a local fine-tuned model #43

Closed by TPF2017 1 month ago

TPF2017 commented 1 month ago

MoiraiModule.from_pretrained doesn't seem to work. [screenshot of the error attached]

TPF2017 commented 1 month ago

I solved this by:

checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint["state_dict"])

Is there a better way?
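The manual workaround above can be sketched end to end. A minimal, self-contained example, using a toy nn.Linear as a hypothetical stand-in for the fine-tuned Moirai module (Lightning checkpoints store the model weights under the "state_dict" key, alongside metadata such as hyperparameters and optimizer state):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy module standing in for the fine-tuned model (hypothetical stand-in).
model = nn.Linear(4, 2)

# Mimic the relevant part of a Lightning checkpoint: weights live
# under the "state_dict" key.
ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.ckpt")
torch.save({"state_dict": model.state_dict()}, ckpt_path)

# Manual restore: load the checkpoint dict, then push the weights into
# a freshly constructed module with the same architecture.
restored = nn.Linear(4, 2)
checkpoint = torch.load(ckpt_path)
restored.load_state_dict(checkpoint["state_dict"])
```

The drawback of this pattern is that you must rebuild the module with exactly the right constructor arguments yourself; Lightning's `load_from_checkpoint` (suggested below in the thread) handles that reconstruction for you.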

gorold commented 1 month ago

You should use ModelForecast.load_from_checkpoint instead. See https://github.com/SalesforceAIResearch/uni2ts/blob/main/cli/conf/finetune/model/moirai_small.yaml

edit: load_from_pretrained -> load_from_checkpoint

TPF2017 commented 1 month ago

Thanks for your reply! I modified my code like this:

model_path = "./example_sales/checkpoints/epoch=0-step=100.ckpt"
model = MoiraiForecast(
    module=MoiraiFinetune.load_from_checkpoint(model_path),
    prediction_length=PDT,
    context_length=CTX,
    patch_size=PSZ,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=ds.num_feat_dynamic_real,
    past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
)
predictor = model.create_predictor(batch_size=BSZ)
forecasts = predictor.predict(test_data.input)

forecast_it = iter(forecasts)
forecast = next(forecast_it)

but I'm getting a new error when forecasting: [screenshot of the error attached]

gorold commented 1 month ago

It should look something like the code below. For module_kwargs, fill it in with the same arguments used during the pre-training phase.

model = MoiraiForecast.load_from_checkpoint(
    module_kwargs={...},
    prediction_length=PDT,
    context_length=CTX,
    patch_size=PSZ,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=ds.num_feat_dynamic_real,
    past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
    checkpoint_path="./example_sales/checkpoints/epoch=0-step=100.ckpt",
)
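To illustrate the shape of module_kwargs: it mirrors the module section of the fine-tuning config (the linked moirai_small.yaml). A hypothetical sketch only; the key names and values below are placeholders and must be copied from your actual pre-training configuration:

```python
# Hypothetical illustration only: the exact keys must match the module
# arguments used during pre-training (see cli/conf/finetune/model/moirai_small.yaml
# in the repo for the real values).
module_kwargs = dict(
    d_model=...,      # transformer hidden dimension
    num_layers=...,   # number of transformer layers
    patch_sizes=...,  # candidate patch sizes
    max_seq_len=...,  # maximum sequence length
)
```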

TPF2017 commented 1 month ago

I tried this way and it works, thank you very much for your answer!