lhw828 / timesfm

Google timesfm hands-on deployment / detailed local deployment notes

Tensor dimension mismatch #9

Open Lcq2002 opened 2 weeks ago

Lcq2002 commented 2 weeks ago

I modified the code to work with the latest timesfm, but it fails at runtime with the following error:

TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
Loaded PyTorch TimesFM.
Loading local data file: 000001.ss_data.csv
Loading local recent data file: 000001.ss_recent.csv
Traceback (most recent call last):
  File "/root/autodl-tmp/forecast.py", line 67, in <module>
    point_forecast, experimental_quantile_forecast = tfm.forecast(
                                                     ^^^^^^^^^^^^^
  File "/root/miniconda3/envs/timesfm/lib/python3.11/site-packages/timesfm/timesfm_torch.py", line 144, in forecast
    mean_output, full_output = self._model.decode(
                               ^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/timesfm/lib/python3.11/site-packages/timesfm/pytorch_patched_decoder.py", line 787, in decode
    final_out = torch.concatenate([final_out, new_ts], axis=-1)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Tensors must have same number of dimensions: got 3 and 2

The code is as follows:

import datetime
import pandas as pd
import matplotlib.pyplot as plt
import timesfm
from huggingface_hub import login

codelist = ["000001.ss"]

data_file = f"{codelist[0]}_data.csv"  # e.g. "000001.ss_data.csv"
data_recent_file = f"{codelist[0]}_recent.csv"  # e.g. "000001.ss_recent.csv"

print(f"Loading local data file: {data_file}")
data2 = pd.read_csv(data_file, index_col=0, parse_dates=True)

print(f"Loading local recent data file: {data_recent_file}")
data_recent = pd.read_csv(data_recent_file, index_col=0, parse_dates=True)

context_len = 512  # use the most recent 512 days of data as context
horizon_len = 256  # length of the forecast period

if len(data2) < context_len:
    raise ValueError(f"Data length {len(data2)} is less than context length {context_len}")

context_data = data2[-context_len:]

local_model_path = "/root/autodl-tmp/timesfm-1.0-200m-pytorch/torch_model.ckpt"  # adjust to your local path

tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        backend="gpu",
        per_core_batch_size=32,
        horizon_len=128,
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        path=local_model_path),
)

forecast_input = [context_data.values]
frequency_input = [0]  # 0 means high-frequency data

point_forecast, experimental_quantile_forecast = tfm.forecast(
    forecast_input,
    freq=frequency_input,
)

forecast_start_date = data2.index[-1] + pd.Timedelta(days=1)
forecast_dates = pd.date_range(start=forecast_start_date, periods=horizon_len, freq='B')  # 'B' means business-day frequency
forecast_series = pd.Series(point_forecast[0], index=forecast_dates)

plt.figure(figsize=(24, 12))

if not data_recent.empty:
    plt.plot(data_recent.index, data_recent.values, label="Actual price (2024 to present)")

plt.plot(data2.index, data2.values, label="Actual price")

plt.plot(forecast_series.index, forecast_series.values, label="Forecast price")

plt.xlabel("Date")
plt.ylabel("Price")
plt.title(f"{codelist[0]} price comparison and forecast")
plt.legend()

output_image = f'{codelist[0]}_comparison.png'
plt.savefig(output_image, bbox_inches='tight')
print(f"Chart saved to {output_image}")

plt.close('all')

lhw828 commented 1 week ago

forecast.zip Try this one.

Lcq2002 commented 1 week ago

Thanks for your reply. Following that approach I adjusted the code, and the version below runs without errors, but the results don't look very good.

import datetime
import pandas as pd
import matplotlib.pyplot as plt
import timesfm
from huggingface_hub import login
import numpy as np  # make sure NumPy is imported

codelist = ["000001.ss"]

data_file = f"{codelist[0]}_data.csv"  # e.g. "000001.ss_data.csv"
data_recent_file = f"{codelist[0]}_recent.csv"  # e.g. "000001.ss_recent.csv"

print(f"Loading local data file: {data_file}")
data2 = pd.read_csv(data_file, index_col=0, parse_dates=True)

print(f"Loading local recent data file: {data_recent_file}")
data_recent = pd.read_csv(data_recent_file, index_col=0, parse_dates=True)

context_len = 512  # use the most recent 512 days of data as context
horizon_len = 256  # length of the forecast period

if len(data2) < context_len:
    raise ValueError(f"Data length {len(data2)} is less than context length {context_len}")

context_data = data2[-context_len:]

local_model_path = "/root/autodl-tmp/timesfm-1.0-200m-pytorch/torch_model.ckpt"  # adjust to your local path

tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        backend="gpu",
        per_core_batch_size=32,
        horizon_len=horizon_len,  # match the forecast length
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        path=local_model_path),
)

forecast_input = [context_data.values.flatten()]  # reshape forecast_input to 1-D
frequency_input = [0]  # 0 means high-frequency data

print("forecast_input shape:", forecast_input[0].shape)  # debug output
print("frequency_input:", frequency_input)

try:
    point_forecast, experimental_quantile_forecast = tfm.forecast(
        forecast_input,
        freq=frequency_input,
    )
except RuntimeError as e:
    print(f"Error occurred: {e}")
    # if it still fails, debug further
    raise

forecast_start_date = data2.index[-1] + pd.Timedelta(days=1)
forecast_dates = pd.date_range(start=forecast_start_date, periods=horizon_len, freq='B')  # 'B' means business-day frequency
forecast_series = pd.Series(point_forecast[0], index=forecast_dates)

plt.figure(figsize=(24, 12))

if not data_recent.empty:
    plt.plot(data_recent.index, data_recent.values, label="Actual price (2024 to present)")

plt.plot(data2.index, data2.values, label="Actual price")

plt.plot(forecast_series.index, forecast_series.values, label="Forecast price")

plt.xlabel("Date")
plt.ylabel("Price")
plt.title(f"{codelist[0]} price comparison and forecast")
plt.legend()

output_image = f'{codelist[0]}_comparison.png'
plt.savefig(output_image, bbox_inches='tight')
print(f"Chart saved to {output_image}")

plt.close('all')
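
A side note on the shape fix in the script above, as a rough sketch only. The traceback suggests that each series passed to `tfm.forecast` needs to be 1-D, whereas `DataFrame.values` is always 2-D, even for a single column. `flatten()` does produce a 1-D array, but if the CSV happens to contain more than one column (unknown here, this is an assumption), it interleaves the columns row by row, which could distort the context. The helper `to_1d_context` below is hypothetical and not part of timesfm; it picks one column explicitly instead of flattening.

```python
import numpy as np


def to_1d_context(values, column=0):
    """Return a 1-D float array to use as a single context series.

    Hypothetical helper (not part of timesfm). `values` may be 1-D, or 2-D
    with shape (rows, columns) as returned by DataFrame.values. For 2-D
    input one column is selected, so columns are never interleaved.
    """
    arr = np.asarray(values, dtype=float)
    if arr.ndim == 1:
        return arr
    if arr.ndim == 2:
        return arr[:, column]
    raise ValueError(f"expected 1-D or 2-D input, got {arr.ndim}-D")


# A single-column table has shape (n, 1): two dimensions, not one.
single = np.arange(6.0).reshape(6, 1)
print(single.shape, "->", to_1d_context(single).shape)

# With two columns, flatten() mixes both columns into one sequence.
multi = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
print("flatten:      ", multi.flatten())
print("second column:", to_1d_context(multi, column=1))
```

In the script this would presumably replace `context_data.values.flatten()` with something like `to_1d_context(context_data.values)`, after checking which column of the CSV actually holds the price.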

lhw828 commented 1 week ago

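Yes. I backtested it on historical data and the results were indeed mediocre; maybe I just haven't figured out how to use it properly yet.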
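
One way to put a number on "mediocre" is a simple holdout backtest: withhold the last `horizon` points, forecast them from the earlier history, and compare the mean absolute error against a naive "repeat the last value" baseline. The sketch below is only an illustration under assumptions. `holdout_backtest` and `drift_forecast` are hypothetical names, and `drift_forecast` is a stand-in forecaster, not TimesFM; in practice `forecast_fn` would wrap the `tfm.forecast` call from the script above.

```python
def mean_absolute_error(actual, predicted):
    """Average absolute difference between two equal-length sequences."""
    if len(actual) != len(predicted):
        raise ValueError("actual and predicted must have the same length")
    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)


def holdout_backtest(series, horizon, forecast_fn):
    """Score forecast_fn on the last `horizon` points of `series`.

    forecast_fn(history, horizon) must return `horizon` predictions.
    The naive baseline repeats the last observed value of the history.
    """
    if horizon <= 0 or horizon >= len(series):
        raise ValueError("horizon must be between 1 and len(series) - 1")
    history, actual = series[:-horizon], series[-horizon:]
    model_pred = list(forecast_fn(history, horizon))
    naive_pred = [history[-1]] * horizon
    return {
        "model_mae": mean_absolute_error(actual, model_pred),
        "naive_mae": mean_absolute_error(actual, naive_pred),
    }


def drift_forecast(history, horizon):
    """Stand-in forecaster (assumption): extend the average historical step."""
    step = (history[-1] - history[0]) / (len(history) - 1)
    return [history[-1] + step * (i + 1) for i in range(horizon)]


# Synthetic, perfectly linear series, used only to exercise the functions.
series = [100.0 + 2.0 * i for i in range(20)]
scores = holdout_backtest(series, horizon=5, forecast_fn=drift_forecast)
print(scores)
```

If the model's error is not clearly below the naive baseline on several different holdout windows, the forecast is probably adding little for this series, whatever the plot looks like.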