amazon-science / chronos-forecasting

Chronos: Pretrained (Language) Models for Probabilistic Time Series Forecasting
https://arxiv.org/abs/2403.07815
Apache License 2.0

Add MLX inference support #41

Closed: abdulfatir closed this pull request 5 months ago

abdulfatir commented 5 months ago

Issue #, if available: #28

Description of changes: This PR adds MLX inference support.

Summary of changes

Sample inference code

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from chronos_mlx import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    dtype="bfloat16",
)

df = pd.read_csv(
    "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
)

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
# (a batched example follows this snippet)
context = df["#Passengers"].values
prediction_length = 12
forecast = pipeline.predict(
    context, prediction_length
)  # shape [num_series, num_samples, prediction_length]

# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(
    forecast_index,
    low,
    high,
    color="tomato",
    alpha=0.3,
    label="80% prediction interval",
)
plt.legend()
plt.grid()
plt.show()
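As noted in the comment above, the context can also be a list of 1D arrays (or a left-padded 2D array), in which case the first dimension of the output indexes the input series. A minimal batched sketch, continuing from the snippet above (it reuses pipeline and prediction_length); series_a and series_b are made-up series for illustration only:

# Batched inference: pass a list of 1D arrays of (possibly) different lengths.
series_a = np.sin(np.linspace(0, 8 * np.pi, 120))  # hypothetical series
series_b = np.cos(np.linspace(0, 4 * np.pi, 80))   # hypothetical series

batched_forecast = pipeline.predict([series_a, series_b], prediction_length)
# expected shape: (2, num_samples, prediction_length), one entry per input series
print(batched_forecast.shape)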

Benchmark


import timeit

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm

from chronos import ChronosPipeline as ChronosPipelineTorch
from chronos_mlx import ChronosPipeline as ChronosPipelineMLX

def benchmark_torch_model(
    pipeline: ChronosPipelineTorch,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    test_data_input = list(test_data.input)

    start_time = timeit.default_timer()
    forecasts = []
    for idx in tqdm(range(0, len(test_data_input), batch_size)):
        batch = [
            torch.tensor(item["target"])
            for item in test_data_input[idx : idx + batch_size]
        ]
        batch_forecasts = pipeline.predict(batch, prediction_length)
        forecasts.append(batch_forecasts)
    forecasts = torch.cat(forecasts)
    end_time = timeit.default_timer()

    print(f"Inference time: {end_time-start_time:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst.numpy(), start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = end_time - start_time
    return results_df

def benchmark_mlx_model(
    pipeline: ChronosPipelineMLX,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    test_data_input = list(test_data.input)

    start_time = timeit.default_timer()
    forecasts = []
    for idx in tqdm(range(0, len(test_data_input), batch_size)):
        batch = [item["target"] for item in test_data_input[idx : idx + batch_size]]
        batch_forecasts = pipeline.predict(batch, prediction_length)
        forecasts.append(batch_forecasts)
    forecasts = np.concatenate(forecasts)
    end_time = timeit.default_timer()

    print(f"Inference time: {end_time-start_time:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst, start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = end_time - start_time
    return results_df

def main(
    version: str = "cpu",  # cpu, mps, mlx
    dtype: str = "bfloat16",
    gluonts_dataset: str = "australian_electricity_demand",
    model_name: str = "amazon/chronos-t5-small",
    batch_size: int = 4,
):
    if version == "cpu" or version == "mps":
        pipeline = ChronosPipelineTorch.from_pretrained(
            model_name,
            device_map=version,
            torch_dtype=getattr(torch, dtype),
        )
        benchmark_fn = benchmark_torch_model
    else:
        pipeline = ChronosPipelineMLX.from_pretrained(model_name, dtype=dtype)
        benchmark_fn = benchmark_mlx_model

    result_df = benchmark_fn(
        pipeline, gluonts_dataset=gluonts_dataset, batch_size=batch_size
    )
    result_df["model"] = model_name
    return result_df

if __name__ == "__main__":
    gluonts_dataset: str = "m4_hourly"
    model_name: str = "amazon/chronos-t5-mini"
    batch_size: int = 8
    dfs = []
    for version in ["cpu", "mps", "mlx"]:
        for dtype in ["float32"]:
            try:
                df = main(
                    version=version,
                    dtype=dtype,
                    model_name=model_name,
                    gluonts_dataset=gluonts_dataset,
                    batch_size=batch_size,
                )
                df["version"] = version
                df["dtype"] = dtype
                dfs.append(df)
            except TypeError:
                pass

    result_df = pd.concat(dfs).reset_index(drop=True)
    result_df.to_csv("benchmark.csv", index=False)

    result_df["version"] = result_df["version"].map(
        {"cpu": "Torch (CPU)", "mps": "Torch (MPS)", "mlx": "MLX"}
    )
    fig = plt.figure(figsize=(8, 5))
    g = sns.barplot(
        data=result_df,
        x="dtype",
        y="inference_time",
        hue="version",
        alpha=0.6,
    )
    plt.ylabel("Inference time in seconds (on M1 Pro)")
    plt.title(f"{model_name} inference times on {gluonts_dataset} dataset")
    plt.savefig("benchmark.png", dpi=200)
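Not part of the PR itself, just a small convenience sketch for inspecting the saved results: it reads benchmark.csv back and averages every numeric column (the accuracy metrics from evaluate_forecasts plus inference_time) per backend and dtype. Column names beyond version, dtype, and inference_time depend on what evaluate_forecasts emitted, hence the generic numeric aggregation.

import pandas as pd

# Summarize the benchmark results per backend/dtype; only numeric columns
# (metrics and timings) are averaged.
results = pd.read_csv("benchmark.csv")
summary = results.groupby(["version", "dtype"]).mean(numeric_only=True)
print(summary.round(3).to_string())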

TODOs:

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.