Description of changes: This PR adds MLX inference support.
Summary of changes
Update pyproject.toml with mlx dependencies.
Create the chronos_mlx package, which will host all mlx inference code.
All classes from main:src/chronos/chronos.py are copy-pasted into mlx:src/chronos_mlx/chronos.py and modified to use numpy and mlx arrays instead. Note that numpy arrays are used as input and output because mlx doesn't support some operations that are required for the input and output transforms.
MLX implementation of T5 is in src/chronos_mlx/t5.py. It has been adapted from ml-explore/mlx-examples with the following main modifications:
Added support for attention mask.
Added support for top_k and top_p sampling.
src/chronos_mlx/translate.py translates weights from a torch HF model to mlx.
Add THIRD-PARTY-LICENSES.txt for third party code from mlx-examples.
Add tests and CI for mlx version.
Sample inference code
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from chronos_mlx import ChronosPipeline
# Load the pretrained Chronos model through the MLX pipeline.
pipeline = ChronosPipeline.from_pretrained("amazon/chronos-t5-small", dtype="bfloat16")

df = pd.read_csv(
    "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
)

# Here the context is a single 1D numpy array (`.values` of a DataFrame column).
# The original note also allowed a list of 1D series or a left-padded 2D batch
# with batch as the first dimension, and spoke of "tensors" -- presumably
# carried over from the torch pipeline; TODO confirm for the MLX pipeline.
context = df["#Passengers"].values
prediction_length = 12

# forecast has shape [num_series, num_samples, prediction_length]
forecast = pipeline.predict(context, prediction_length)

# visualize the forecast
history_length = len(df)
forecast_index = range(history_length, history_length + prediction_length)

# Quantiles are taken over the samples axis of the first (only) series.
quantile_bands = np.quantile(forecast[0], [0.1, 0.5, 0.9], axis=0)
low, median, high = quantile_bands

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(
    forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval"
)
plt.legend()
plt.grid()
plt.show()
Benchmark
import timeit
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm
from chronos import ChronosPipeline as ChronosPipelineTorch
from chronos_mlx import ChronosPipeline as ChronosPipelineMLX
def benchmark_torch_model(
    pipeline: ChronosPipelineTorch,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    """Time the torch Chronos pipeline on a GluonTS dataset and score it.

    The test split is cut at ``offset=-prediction_length`` (the prediction
    length comes from the dataset metadata), the resulting inputs are
    forecast batch by batch, and the sample forecasts are evaluated with
    MASE and the mean weighted sum quantile loss over quantiles 0.1 to 0.9.

    Args:
        pipeline: Torch Chronos pipeline whose ``predict`` is benchmarked.
        gluonts_dataset: Dataset name passed to ``get_dataset``.
        batch_size: Number of series handed to each ``predict`` call.

    Returns:
        The result of ``evaluate_forecasts`` with an extra
        ``inference_time`` column (``timeit.default_timer`` difference).
    """
    dataset = get_dataset(gluonts_dataset)
    horizon = dataset.metadata.prediction_length
    _, template = split(dataset.test, offset=-horizon)
    test_data = template.generate_instances(horizon)
    inputs = list(test_data.input)

    # Timed region: tensor conversion, prediction and final concatenation.
    tic = timeit.default_timer()
    per_batch = [
        pipeline.predict(
            [torch.tensor(entry["target"]) for entry in inputs[lo : lo + batch_size]],
            horizon,
        )
        for lo in tqdm(range(0, len(inputs), batch_size))
    ]
    forecasts = torch.cat(per_batch)
    toc = timeit.default_timer()

    elapsed = toc - tic
    print(f"Inference time: {elapsed:.2f}s")

    # SampleForecast is given numpy samples, so each tensor is converted.
    sample_forecasts = [
        SampleForecast(samples.numpy(), start_date=label["start"])
        for samples, label in zip(forecasts, test_data.label)
    ]
    quantile_levels = np.arange(0.1, 1, 0.1)
    results_df = evaluate_forecasts(
        forecasts=sample_forecasts,
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(quantile_levels)],
    )
    results_df["inference_time"] = elapsed
    return results_df
def benchmark_mlx_model(
    pipeline: ChronosPipelineMLX,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    """Time the MLX Chronos pipeline on a GluonTS dataset and score it.

    Mirrors the torch benchmark, except that the raw ``"target"`` entries
    are passed to ``predict`` unchanged and the per-batch outputs are
    joined with ``np.concatenate``.

    Args:
        pipeline: MLX Chronos pipeline whose ``predict`` is benchmarked.
        gluonts_dataset: Dataset name passed to ``get_dataset``.
        batch_size: Number of series handed to each ``predict`` call.

    Returns:
        The result of ``evaluate_forecasts`` with an extra
        ``inference_time`` column (``timeit.default_timer`` difference).
    """
    dataset = get_dataset(gluonts_dataset)
    horizon = dataset.metadata.prediction_length
    _, template = split(dataset.test, offset=-horizon)
    test_data = template.generate_instances(horizon)
    inputs = list(test_data.input)

    # Timed region: prediction for every batch plus the final concatenation.
    tic = timeit.default_timer()
    per_batch = [
        pipeline.predict(
            [entry["target"] for entry in inputs[lo : lo + batch_size]],
            horizon,
        )
        for lo in tqdm(range(0, len(inputs), batch_size))
    ]
    forecasts = np.concatenate(per_batch)
    toc = timeit.default_timer()

    elapsed = toc - tic
    print(f"Inference time: {elapsed:.2f}s")

    sample_forecasts = [
        SampleForecast(samples, start_date=label["start"])
        for samples, label in zip(forecasts, test_data.label)
    ]
    quantile_levels = np.arange(0.1, 1, 0.1)
    results_df = evaluate_forecasts(
        forecasts=sample_forecasts,
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(quantile_levels)],
    )
    results_df["inference_time"] = elapsed
    return results_df
def main(
    version: str = "cpu",  # cpu, mps, mlx
    dtype: str = "bfloat16",
    gluonts_dataset: str = "australian_electricity_demand",
    model_name: str = "amazon/chronos-t5-small",
    batch_size: int = 4,
):
    """Run one benchmark configuration and return its results.

    Args:
        version: Backend to benchmark. ``"cpu"`` and ``"mps"`` load the torch
            pipeline with that value as ``device_map``; ``"mlx"`` loads the
            MLX pipeline.
        dtype: Dtype name. For torch it is resolved with
            ``getattr(torch, dtype)``; for MLX it is passed through as a string.
        gluonts_dataset: Dataset name forwarded to the benchmark function.
        model_name: Pretrained model identifier passed to ``from_pretrained``.
        batch_size: Batch size forwarded to the benchmark function.

    Returns:
        The benchmark function's result with an added ``model`` column.

    Raises:
        ValueError: If ``version`` is not one of ``"cpu"``, ``"mps"``, ``"mlx"``.
    """
    # Previously any unrecognised value fell through to the MLX branch, so a
    # typo would silently benchmark MLX under the wrong label.
    if version not in ("cpu", "mps", "mlx"):
        raise ValueError(
            f"Unknown version {version!r}; expected one of 'cpu', 'mps', 'mlx'."
        )

    if version == "mlx":
        pipeline = ChronosPipelineMLX.from_pretrained(model_name, dtype=dtype)
        benchmark_fn = benchmark_mlx_model
    else:
        pipeline = ChronosPipelineTorch.from_pretrained(
            model_name,
            device_map=version,
            torch_dtype=getattr(torch, dtype),
        )
        benchmark_fn = benchmark_torch_model

    result_df = benchmark_fn(
        pipeline, gluonts_dataset=gluonts_dataset, batch_size=batch_size
    )
    result_df["model"] = model_name
    return result_df
if __name__ == "__main__":
    # Benchmark sweep: run every backend on one dataset/model, save the raw
    # numbers to benchmark.csv and a bar chart to benchmark.png.
    gluonts_dataset: str = "m4_hourly"
    model_name: str = "amazon/chronos-t5-mini"
    batch_size: int = 8

    dfs = []
    for version in ["cpu", "mps", "mlx"]:
        for dtype in ["float32"]:
            try:
                df = main(
                    version=version,
                    dtype=dtype,
                    model_name=model_name,
                    gluonts_dataset=gluonts_dataset,
                    batch_size=batch_size,
                )
            except TypeError as err:
                # The sweep is best-effort: a configuration that raises
                # TypeError is skipped, but reported so that a missing bar in
                # the chart can be explained.
                print(f"Skipping version={version!r}, dtype={dtype!r}: {err}")
                continue
            df["version"] = version
            df["dtype"] = dtype
            dfs.append(df)

    # pd.concat raises an opaque ValueError on an empty list; fail clearly.
    if not dfs:
        raise SystemExit("No benchmark configuration completed; nothing to report.")

    result_df = pd.concat(dfs).reset_index(drop=True)
    # The CSV keeps the raw version keys; display names are applied afterwards.
    result_df.to_csv("benchmark.csv", index=False)
    result_df["version"] = result_df["version"].map(
        {"cpu": "Torch (CPU)", "mps": "Torch (MPS)", "mlx": "MLX"}
    )

    plt.figure(figsize=(8, 5))
    sns.barplot(
        data=result_df,
        x="dtype",
        y="inference_time",
        hue="version",
        alpha=0.6,
    )
    plt.ylabel("Inference Time (on M1 Pro)")
    plt.title(f"{model_name} inference times on {gluonts_dataset} dataset")
    plt.savefig("benchmark.png", dpi=200)
TODOs:
[x] Implement top_p sampling.
[x] Add tests.
[x] Add CI.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
Issue #, if available: #28
Description of changes: This PR adds MLX inference support.
Summary of changes
Update pyproject.toml with mlx dependencies.
Create the chronos_mlx package, which will host all mlx inference code.
All classes from main:src/chronos/chronos.py are copy-pasted into mlx:src/chronos_mlx/chronos.py and modified to use numpy and mlx arrays instead. Note that the reason for using numpy arrays as input and output is that mlx doesn't support some operations that are required for input and output transform.
MLX implementation of T5 is in src/chronos_mlx/t5.py. It has been adapted from ml-explore/mlx-examples with the following main modifications:
Added support for attention mask.
Added support for top_k and top_p sampling.
src/chronos_mlx/translate.py translates weights from a torch HF model to mlx.
Add THIRD-PARTY-LICENSES.txt for third party code from mlx-examples.
Add tests and CI for mlx version.
Sample inference code
Benchmark
TODOs:
[x] Implement top_p sampling.
[x] Add tests.
[x] Add CI.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.