thuml / Autoformer

About Code release for "Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting" (NeurIPS 2021), https://arxiv.org/abs/2106.13008
MIT License

Bug in the Hugging Face example code #179

Open LiuZhihhxx opened 1 year ago

LiuZhihhxx commented 1 year ago

I ran the Hugging Face demo code for AutoformerForPrediction:

from huggingface_hub import hf_hub_download
import torch
from transformers import AutoformerForPrediction

file = hf_hub_download(
    repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
)
batch = torch.load(file)

model = AutoformerForPrediction.from_pretrained("huggingface/autoformer-tourism-monthly")

# during training, one provides both past and future values
# as well as possible additional features
outputs = model(
    past_values=batch["past_values"],
    past_time_features=batch["past_time_features"],
    past_observed_mask=batch["past_observed_mask"],
    static_categorical_features=batch["static_categorical_features"],
    static_real_features=batch["static_real_features"],
    future_values=batch["future_values"],
    future_time_features=batch["future_time_features"],
)

loss = outputs.loss
loss.backward()

# during inference, one only provides past values
# as well as possible additional features
# the model autoregressively generates future values
outputs = model.generate(
    past_values=batch["past_values"],
    past_time_features=batch["past_time_features"],
    past_observed_mask=batch["past_observed_mask"],
    static_categorical_features=batch["static_categorical_features"],
    static_real_features=batch["static_real_features"],
    future_time_features=batch["future_time_features"],
)

mean_prediction = outputs.sequences.mean(dim=1)

The outputs = model(...) call fails with a matrix shape mismatch: RuntimeError: mat1 and mat2 shapes cannot be multiplied (1536x23 and 22x64)

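For this dataset, the batch size is 64, the input length is 61, the prediction length is 24, and there are two time features. The most I can work out is that 1536 = 64*24; I cannot see where the other dimensions come from. The earlier AutoformerModel demo is very similar, yet its outputs = model(...) step runs without error. How can this be fixed? Many thanks!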
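The shape arithmetic in the error message can be unpacked with a small helper. The sketch below is plain Python and not part of transformers; `explain_matmul_error` is a hypothetical name. It assumes the usual PyTorch case in which the second shape is the transposed weight of a linear layer, so 22 would be the feature width the layer expects and 64 its output width. It does not identify which layer inside Autoformer raises the error.

```python
import re


def explain_matmul_error(message, batch_size, prediction_length):
    """Decode a 'mat1 and mat2 shapes cannot be multiplied (AxB and CxD)' message.

    Hypothetical helper: assumes mat2 is a transposed linear-layer weight,
    so C is the expected input width and D is the output width.
    """
    match = re.search(r"\((\d+)x(\d+) and (\d+)x(\d+)\)", message)
    if match is None:
        raise ValueError("no '(AxB and CxD)' shape pair found in message")
    rows, received, expected, out_width = map(int, match.groups())
    return {
        "rows": rows,
        # True when the flattened rows equal batch_size * prediction_length
        "rows_match_batch_times_prediction": rows == batch_size * prediction_length,
        "features_received": received,
        "features_expected": expected,
        "surplus_features": received - expected,
        "layer_output_width": out_width,
    }


# Values taken from the report above: batch size 64, prediction length 24.
message = "RuntimeError: mat1 and mat2 shapes cannot be multiplied (1536x23 and 22x64)"
report = explain_matmul_error(message, batch_size=64, prediction_length=24)
for key, value in report.items():
    print(f"{key}: {value}")
```

Under that assumption, the row count is consistent with 64 * 24, and the input carries one more feature column (23) than the layer was built for (22), which suggests the mismatch lies in the feature dimension rather than in the batch size or sequence length.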

1024djy commented 1 year ago

Did you manage to solve this? I ran into the same problem.