SalesforceAIResearch / uni2ts

[ICML2024] Unified Training of Universal Time Series Forecasting Transformers
Apache License 2.0

Question: Is it possible to use Moirai without GluonTS? #9

Closed: isaacmg closed this issue 1 month ago

isaacmg commented 3 months ago

Hi, thanks for the repository and the example usage. However, I would like to use the model for inference (and possibly fine-tune it as well) without using GluonTS at all. Could you elaborate on what raw inputs the model takes and how to decouple it from the Gluon dependencies?

Thanks.

gorold commented 3 months ago

The MoiraiForecast class is a PyTorch Lightning Module, so you can use it without the GluonTS components, although you'd have to rewrite the class to remove the GluonTS imports. Btw, GluonTS doesn't have any Gluon dependencies.

The required inputs are:

  1. past_target: your input time series, a float tensor of shape (batch, past_length, target_dim)
  2. past_observed_target: a boolean tensor of the same shape, True where the data point is observed, False otherwise
  3. past_is_pad: a boolean tensor of shape (batch, past_length), True where the data point is padding
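To make these three inputs concrete, here is a minimal sketch of packing raw univariate series of different lengths into them. It is not part of uni2ts: the helper name `make_moirai_inputs` is hypothetical, and it assumes left-padding (real values sit at the end of the window), NaN as the marker for missing values, and that `past_length` is the window length the model expects.

```python
from typing import Dict, Sequence

import numpy as np
import torch


def make_moirai_inputs(
    series: Sequence[np.ndarray], past_length: int
) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: pack 1-D series into Moirai-style past inputs.

    Series longer than past_length keep only their last past_length points;
    shorter ones are left-padded. NaNs are marked unobserved and zero-filled.
    """
    batch = len(series)
    # Start from "everything is padding": zero values, unobserved, is_pad=True
    past_target = torch.zeros(batch, past_length, 1)
    past_observed_target = torch.zeros(batch, past_length, 1, dtype=torch.bool)
    past_is_pad = torch.ones(batch, past_length, dtype=torch.bool)

    for i, s in enumerate(series):
        # Keep the most recent past_length points as float32
        values = torch.as_tensor(np.asarray(s, dtype=np.float32)[-past_length:])
        n = values.shape[0]
        start = past_length - n  # left-pad: real data fills the tail of the window
        past_target[i, start:, 0] = torch.nan_to_num(values, nan=0.0)
        past_observed_target[i, start:, 0] = ~torch.isnan(values)
        past_is_pad[i, start:] = False

    return {
        "past_target": past_target,
        "past_observed_target": past_observed_target,
        "past_is_pad": past_is_pad,
    }
```

With a model loaded as in the example that follows, the dict could be unpacked straight into the call, e.g. `model(**make_moirai_inputs(series, past_length=PAST))`. Padded positions are also marked unobserved, so the model never treats the zero fill as real data.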

See the MoiraiForecast implementation, in particular its type hints, for more information about the inputs. Here's an example:

import torch
from huggingface_hub import hf_hub_download

from uni2ts.model.moirai import MoiraiForecast

SIZE = "small"  # model size: choose from {'small', 'base', 'large'}
PDT = 20  # prediction length: any positive integer
CTX = 200  # context length: any positive integer
PSZ = "auto"  # patch size: choose from {"auto", 8, 16, 32, 64, 128}
BSZ = 32  # batch size: any positive integer
TEST = 100  # test set length: any positive integer

# Prepare pre-trained model by downloading model weights from huggingface hub
model = MoiraiForecast.load_from_checkpoint(
    checkpoint_path=hf_hub_download(
        repo_id=f"Salesforce/moirai-1.0-R-{SIZE}", filename="model.ckpt"
    ),
    prediction_length=PDT,
    context_length=CTX,
    patch_size=PSZ,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=0,
    past_feat_dynamic_real_dim=0,
    map_location="cpu",
)

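# With patch_size="auto" the model expects CTX + PDT past steps, otherwise CTX steps
PAST = CTX + PDT if PSZ == "auto" else CTX

model.eval()  # make sure dropout etc. are disabled for inference
with torch.no_grad():  # inference only: no gradient tracking needed
    pred = model(
        past_target=torch.randn(BSZ, PAST, 1),
        past_observed_target=torch.ones(BSZ, PAST, 1, dtype=torch.bool),
        past_is_pad=torch.zeros(BSZ, PAST, dtype=torch.bool),
    )
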
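`pred` holds sample paths rather than a single forecast (the model above is loaded with `num_samples=100`). A minimal sketch for summarising them, assuming the output is laid out as (batch, num_samples, prediction_length) for a univariate target; check the return type hint of `MoiraiForecast.forward` in your installed version before relying on that layout. The helper name `summarize_samples` is hypothetical.

```python
from typing import Dict, Sequence

import torch


def summarize_samples(
    samples: torch.Tensor, quantiles: Sequence[float] = (0.1, 0.5, 0.9)
) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: reduce samples of shape (batch, num_samples, horizon).

    Returns the per-step mean and the requested quantiles, both computed over
    the sample dimension (dim=1).
    """
    q = torch.tensor(quantiles, dtype=samples.dtype, device=samples.device)
    return {
        "mean": samples.mean(dim=1),  # (batch, horizon)
        # torch.quantile puts the quantile axis first: (len(quantiles), batch, horizon)
        "quantiles": torch.quantile(samples, q, dim=1),
    }
```

With the default quantiles, `summarize_samples(pred)["quantiles"][1]` gives a median point forecast, and the first and last entries bound an 80% interval. Using `torch.quantile` for the median (rather than `torch.median`) keeps it consistent with the other quantiles, since both use the same interpolation.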
gorold commented 2 months ago

Closing due to inactivity, feel free to reopen if you need any further clarifications.