LibCity / Bigscity-LibCity

LibCity: An Open Library for Urban Spatial-temporal Data Mining
https://libcity.ai/
Apache License 2.0
871 stars 159 forks source link

Add model ST-TSNet #398

Closed hczs closed 5 months ago

hczs commented 5 months ago

数据集说明

使用数据集:TAXINYC20140112 ST-TSNet 使用的数据集与 LibCity 的 TAXINYC20140112 数据集数据不一致,LibCity 中的数据集是 pickup 和 dropoff 字段,并且查阅原始数据形状是: (number_of_timeslots, 2, 15, 5),时间间隔半小时,ST-TSNet 使用的 TAXINYC 数据集是 newflow 和 endflow 字段,形状:(number_of_timeslots, 2, 16, 8) ,时间间隔一小时,所以评估后重新转换原子文件进行训练 转换后的原子文件配置信息如下:

{
    "geo": {
        "including_types": [
            "Polygon"
        ],
        "Polygon": {
            "row_id": "num",
            "column_id": "num"
        }
    },
    "grid": {
        "including_types": [
            "state"
        ],
        "state": {
            "row_id": 15,
            "column_id": 5,
            "new_flow": "num",
            "end_flow": "num"
        }
    },
    "info": {
        "data_col": [
            "new_flow",
            "end_flow"
        ],
        "ext_col": [
            "TimeFeature0",
            "TimeFeature1",
            "TimeFeature2",
            "TimeFeature3",
            "TimeFeature4",
            "TimeFeature5",
            "TimeFeature6",
            "TimeFeature7",
            "Holiday",
            "Weather0",
            "Weather1",
            "Weather2",
            "Weather3",
            "Weather4",
            "Weather5",
            "Weather6",
            "Weather7",
            "Weather8",
            "Weather9",
            "Weather10",
            "Weather11",
            "Weather12",
            "WindSpeed",
            "Temperature"
        ],
        "data_files": [
            "NYCTAXI20140112"
        ],
        "geo_file": "NYCTAXI20140112",
        "ext_file": "NYCTAXI20140112",
        "output_dim": 2,
        "time_intervals": 3600,
        "init_weight_inf_or_zero": "inf",
        "set_weight_link_or_dist": "dist",
        "calculate_weight_adj": false,
        "weight_adj_epsilon": 0.1
    }
}

模型训练说明

模型训练数据集相关参数和模型初始化参数都使用的原论文参数 数据集配置:

  {
    "batch_size": 128,
    "cache_dataset": true,
    "num_workers": 0,
    "pad_with_last_sample": false,
    "train_rate": 0.9,
    "eval_rate": 0.1,
    "scaler": "none",
    "load_external": true,
    "normal_external": false,
    "adjust_ext_timestamp": true,
    "ext_scaler": "none",
    "len_closeness": 10,
    "len_period": 0,
    "len_trend": 4,
    "nb_flow": 2,
    "days_test": 28,
    "prediction_offset": 0
  }

模型训练配置:

{
  "max_epoch": 800,
  "learner": "adam",
  "learning_rate": 0.005,
  "weight_decay": 5E-5,

  "lr_decay": true,
  "lr_scheduler": "lambdalr",
  "lrf": 0.01,

  "use_early_stop": true,
  "patience": 500,

  "device": "cuda",
  "map_height": 16,
  "map_width": 8,
  "m_factor": 1,
  "m_factor_2": 1,
  "random_pick": false,
 "loss_weight": 1.0,
 "time_class": 24,
  "drop_prob": 0.1,
  "conv_channels": 64,
  "pre_conv": 0,
  "seq_pool": true,
  "shortcut": true,
  "patch_size": 8,
  "close_channels": 20,
  "trend_channels": 8,
  "close_dim": 128,
  "trend_dim": 128,
  "close_depth": 2,
  "trend_depth": 2,
  "close_head": 2,
  "trend_head": 2,
  "close_mlp_dim": 512,
  "trend_mlp_dim": 512
}

训练结果

NYCTAXI20140112 LibCity ST-TSNet
new-flow-MAE 15.68 12.63
new-flow-MSE 7865.50 4297.74
new-flow-RMSE 88.68 65.56
end-flow-MAE 14.33 12.55
end-flow-MSE 5361.35 4250.46
end-flow-RMSE 73.22 65.20
hczs commented 5 months ago

上述是自己转换的 ST-TSNet 原版数据集跑的结果,下面是使用 Libcity 的 gird 数据集跑的结果: