Open Lcq2002 opened 1 month ago
forecast.zip 试试这个。
谢谢您的回复,我现在根据这个思路调整了一下,下面代码是可以正常运行的,就是效果看着不是很好 import datetime import pandas as pd import matplotlib.pyplot as plt import timesfm from huggingface_hub import login import numpy as np # 确保导入了 NumPy
# End-to-end TimesFM forecast for a single ticker: load local CSVs, run the
# model on the most recent window, and save a comparison chart.
codelist = ["000001.ss"]

# Local CSV files are named after the ticker,
# e.g. "000001.ss_data.csv" and "000001.ss_recent.csv".
data_file = f"{codelist[0]}_data.csv"
data_recent_file = f"{codelist[0]}_recent.csv"

print(f"加载本地数据文件: {data_file}")
data2 = pd.read_csv(data_file, index_col=0, parse_dates=True)

print(f"加载本地近期数据文件: {data_recent_file}")
data_recent = pd.read_csv(data_recent_file, index_col=0, parse_dates=True)

context_len = 512  # number of most recent rows fed to the model as context
horizon_len = 256  # number of steps to forecast

# Refuse to run when the history is shorter than the context window.
if len(data2) < context_len:
    raise ValueError(f"数据长度 {len(data2)} 小于上下文长度 {context_len}")

context_data = data2.iloc[-context_len:]

# Adjust to wherever the checkpoint lives on the local machine.
local_model_path = "/root/autodl-tmp/timesfm-1.0-200m-pytorch/torch_model.ckpt"

model_hparams = timesfm.TimesFmHparams(
    backend="gpu",
    per_core_batch_size=32,
    horizon_len=horizon_len,  # keep in step with the forecast length
)
model_checkpoint = timesfm.TimesFmCheckpoint(path=local_model_path)
tfm = timesfm.TimesFm(hparams=model_hparams, checkpoint=model_checkpoint)

# Flatten the context window into a single 1-D array before forecasting.
forecast_input = [context_data.to_numpy().flatten()]
frequency_input = [0]  # 0 denotes high-frequency data

# Debug output for the model inputs.
print("forecast_input shape:", forecast_input[0].shape)
print("frequency_input:", frequency_input)

try:
    point_forecast, experimental_quantile_forecast = tfm.forecast(
        forecast_input,
        freq=frequency_input,
    )
except RuntimeError as e:
    print(f"发生错误: {e}")
    # Still failing: re-raise so the full traceback is available for debugging.
    raise

# Forecast dates start the day after the last observation and use the
# business-day frequency ('B').
forecast_start_date = data2.index[-1] + pd.Timedelta(days=1)
forecast_dates = pd.date_range(
    start=forecast_start_date,
    periods=horizon_len,
    freq='B',
)
forecast_series = pd.Series(point_forecast[0], index=forecast_dates)

# Draw recent actuals (if any), the full history, and the forecast on one axes.
fig, ax = plt.subplots(figsize=(24, 12))

if not data_recent.empty:
    ax.plot(data_recent.index, data_recent.values, label="实际价格 (2024-至今)")

ax.plot(data2.index, data2.values, label="实际价格")
ax.plot(forecast_series.index, forecast_series.values, label="预测价格")

ax.set_xlabel("日期")
ax.set_ylabel("价格")
ax.set_title(f"{codelist[0]} 价格对比与预测")
ax.legend()

output_image = f'{codelist[0]}_comparison.png'
fig.savefig(output_image, bbox_inches='tight')
print(f"图表已保存到 {output_image}")

plt.close('all')
是,我回溯过历史数据,效果确实一般般,可能是我还不太会使用吧。
我修改了一下代码来适配最新的 timesfm,但是运行时候报错:
TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
Loaded PyTorch TimesFM.
加载本地数据文件: 000001.ss_data.csv
加载本地近期数据文件: 000001.ss_recent.csv
Traceback (most recent call last):
  File "/root/autodl-tmp/forecast.py", line 67, in <module>
point_forecast, experimental_quantile_forecast = tfm.forecast(
^^^^^^^^^^^^^
File "/root/miniconda3/envs/timesfm/lib/python3.11/site-packages/timesfm/timesfm_torch.py", line 144, in forecast
mean_output, full_output = self._model.decode(
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/timesfm/lib/python3.11/site-packages/timesfm/pytorch_patched_decoder.py", line 787, in decode
final_out = torch.concatenate([final_out, new_ts], axis=-1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Tensors must have same number of dimensions: got 3 and 2
代码如下: import datetime import pandas as pd import matplotlib.pyplot as plt import timesfm from huggingface_hub import login
# End-to-end TimesFM forecast for a single ticker: load local CSVs, run the
# model on the most recent window, and save a comparison chart.
codelist = ["000001.ss"]

# Local CSV files are named after the ticker,
# e.g. "000001.ss_data.csv" and "000001.ss_recent.csv".
data_file = f"{codelist[0]}_data.csv"
data_recent_file = f"{codelist[0]}_recent.csv"

print(f"加载本地数据文件: {data_file}")
data2 = pd.read_csv(data_file, index_col=0, parse_dates=True)

print(f"加载本地近期数据文件: {data_recent_file}")
data_recent = pd.read_csv(data_recent_file, index_col=0, parse_dates=True)

context_len = 512  # number of most recent rows fed to the model as context
horizon_len = 256  # number of steps to forecast

# Refuse to run when the history is shorter than the context window.
if len(data2) < context_len:
    raise ValueError(f"数据长度 {len(data2)} 小于上下文长度 {context_len}")

context_data = data2[-context_len:]

# Adjust to wherever the checkpoint lives on the local machine.
local_model_path = "/root/autodl-tmp/timesfm-1.0-200m-pytorch/torch_model.ckpt"

tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        backend="gpu",
        per_core_batch_size=32,
        # Must equal the number of forecast dates generated below; a different
        # value here makes the forecast length disagree with the date index
        # when the pd.Series is built.
        horizon_len=horizon_len,
    ),
    checkpoint=timesfm.TimesFmCheckpoint(path=local_model_path),
)

# Each input series must be a 1-D array. Passing the DataFrame's 2-D
# `.values` (one column per field) is what produced
# "RuntimeError: Tensors must have same number of dimensions: got 3 and 2"
# inside the decoder, so flatten it first.
# NOTE(review): assumes the CSV holds a single value column — confirm; with
# several columns flatten() would interleave them.
forecast_input = [context_data.values.flatten()]
frequency_input = [0]  # 0 denotes high-frequency data

point_forecast, experimental_quantile_forecast = tfm.forecast(
    forecast_input,
    freq=frequency_input,
)

# Forecast dates start the day after the last observation and use the
# business-day frequency ('B').
forecast_start_date = data2.index[-1] + pd.Timedelta(days=1)
forecast_dates = pd.date_range(start=forecast_start_date, periods=horizon_len, freq='B')
forecast_series = pd.Series(point_forecast[0], index=forecast_dates)

plt.figure(figsize=(24, 12))

# Recent actuals are optional; skip the line when the file had no rows.
if not data_recent.empty:
    plt.plot(data_recent.index, data_recent.values, label="实际价格 (2024-至今)")

plt.plot(data2.index, data2.values, label="实际价格")
plt.plot(forecast_series.index, forecast_series.values, label="预测价格")

plt.xlabel("日期")
plt.ylabel("价格")
plt.title(f"{codelist[0]} 价格对比与预测")
plt.legend()

output_image = f'{codelist[0]}_comparison.png'
plt.savefig(output_image, bbox_inches='tight')
print(f"图表已保存到 {output_image}")

plt.close('all')