LibCity / Bigscity-LibCity

LibCity: An Open Library for Urban Spatial-temporal Data Mining
https://libcity.ai/
Apache License 2.0

feat: add model Trafformer #435

Closed (hczs closed this 3 months ago)

hczs commented 3 months ago

Task execution configuration

Dataset configuration file (TrafficStatePointDataset.json). The parameters below are for METR_LA. To run PEMS_BAY, set batch_size to 1 and keep every other parameter the same as for METR_LA.

{
  "batch_size": 4,
  "cache_dataset": false,
  "num_workers": 0,
  "pad_with_last_sample": false,

  "train_rate": 0.7,
  "eval_rate": 0.1,

  "scaler": "standard",
  "load_external": true,
  "normal_external": false,
  "ext_scaler": "none",
  "input_window": 12,
  "output_window": 12,
  "add_time_in_day": true,
  "add_day_in_week": false
}
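A minimal sketch of how the two dataset settings relate, assuming the configuration is handled as a plain Python dict. The helpers `dataset_config_for` and `split_sizes` are hypothetical and not part of LibCity's API. The split helper assumes the data is divided chronologically and that the test set takes whatever remains after `train_rate` and `eval_rate` (1 - 0.7 - 0.1 = 0.2); the rounding rule is also an assumption.

```python
# Base dataset settings from TrafficStatePointDataset.json (METR_LA values).
METR_LA_CONFIG = {
    "batch_size": 4,
    "cache_dataset": False,
    "num_workers": 0,
    "pad_with_last_sample": False,
    "train_rate": 0.7,
    "eval_rate": 0.1,
    "scaler": "standard",
    "load_external": True,
    "normal_external": False,
    "ext_scaler": "none",
    "input_window": 12,
    "output_window": 12,
    "add_time_in_day": True,
    "add_day_in_week": False,
}


def dataset_config_for(dataset: str) -> dict:
    """Hypothetical helper: return the dataset config used for `dataset`.

    PEMS_BAY reuses every METR_LA setting except batch_size, which drops to 1.
    """
    config = dict(METR_LA_CONFIG)  # copy so the METR_LA base is never mutated
    if dataset == "PEMS_BAY":
        config["batch_size"] = 1
    return config


def split_sizes(num_samples: int, train_rate: float, eval_rate: float) -> tuple:
    """Hypothetical helper: (train, eval, test) sample counts.

    Assumes a chronological split where the test set is the remainder.
    """
    num_train = round(num_samples * train_rate)
    num_eval = round(num_samples * eval_rate)
    num_test = num_samples - num_train - num_eval
    return num_train, num_eval, num_test


if __name__ == "__main__":
    print(dataset_config_for("PEMS_BAY")["batch_size"])  # 1
    # 1000 is an illustrative sample count, not a property of either dataset.
    print(split_sizes(1000, 0.7, 0.1))  # (700, 100, 200)
```

Copying the base dict rather than maintaining two full configs keeps the "only batch_size differs" rule from the PR in one place.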

Model configuration file (Trafformer.json)

{
  "output_attention": false,
  "d_model": 32,
  "embed": "timeF",
  "freq": "m",
  "dropout": 0.0,
  "enc_attn": "full",
  "factor": 50,
  "n_heads": 4,
  "d_ff": 256,
  "activation": "gelu",
  "e_layers": 2,
  "distil": false,
  "mix": true,
  "d_layers": 1
}
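In standard multi-head attention, `d_model` is divided evenly across the heads, so these values would give a per-head width of 32 / 4 = 8. The PR does not say how Trafformer partitions `d_model`, so the check below is a sketch under that standard assumption; `head_dim` is a hypothetical helper. It parses the model settings from this PR with Python's `json` module and verifies the divisibility the assumption requires.

```python
import json

# Model settings copied from Trafformer.json in this PR.
TRAFFORMER_JSON = """
{
  "output_attention": false, "d_model": 32, "embed": "timeF", "freq": "m",
  "dropout": 0.0, "enc_attn": "full", "factor": 50, "n_heads": 4, "d_ff": 256,
  "activation": "gelu", "e_layers": 2, "distil": false, "mix": true, "d_layers": 1
}
"""


def head_dim(d_model: int, n_heads: int) -> int:
    """Hypothetical helper: per-head width, assuming an even split of d_model."""
    if d_model % n_heads != 0:
        raise ValueError(f"d_model={d_model} is not divisible by n_heads={n_heads}")
    return d_model // n_heads


config = json.loads(TRAFFORMER_JSON)
print(head_dim(config["d_model"], config["n_heads"]))  # 8
# Feed-forward width relative to the model width.
print(config["d_ff"] // config["d_model"])  # 8
```

If `n_heads` is changed in a tuning run, the same check catches combinations that do not divide `d_model` before any training starts.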

Executor configuration (TrafformerExecutor.json)

{
  "gpu": true,
  "gpu_id": 0,
  "max_epoch": 100,
  "learner": "adam",
  "learning_rate": 0.002,
  "weight_decay": 0.0001,
  "lr_decay": false,
  "clip_grad_norm": true,
  "max_grad_norm": 5,
  "use_early_stop": true,
  "patience": 10,
  "gradient_accumulation_steps": 16
}
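The name `gradient_accumulation_steps` suggests that gradients from 16 mini-batches are accumulated before each optimizer update. That would make the effective batch batch_size × 16: 64 samples for METR_LA and 16 for PEMS_BAY. `clip_grad_norm` with `max_grad_norm: 5` presumably rescales all gradients by their global L2 norm, as PyTorch's `torch.nn.utils.clip_grad_norm_` does. The pure-Python sketch below illustrates both ideas on plain lists of numbers. It is not LibCity's executor code, and all function names are hypothetical. It also steps on a trailing partial group of batches; whether the executor does so is not stated in the PR.

```python
import math


def effective_batch_size(batch_size: int, accumulation_steps: int) -> int:
    """Samples whose gradients contribute to one optimizer update."""
    return batch_size * accumulation_steps


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale every gradient down when their combined L2 norm exceeds max_norm.

    Mirrors the global-norm rule: coef = max_norm / (total_norm + eps), capped at 1.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = max_norm / (total_norm + eps)
    if coef < 1.0:
        grads = [g * coef for g in grads]
    return grads, total_norm


def accumulation_schedule(num_batches: int, accumulation_steps: int) -> list:
    """0-based batch indices after which an optimizer step would run.

    Assumption: step every `accumulation_steps` batches, plus once on a
    trailing partial group so no accumulated gradient is discarded.
    """
    steps = []
    for i in range(num_batches):
        if (i + 1) % accumulation_steps == 0 or i + 1 == num_batches:
            steps.append(i)
    return steps


if __name__ == "__main__":
    print(effective_batch_size(4, 16))  # 64 (METR_LA)
    print(effective_batch_size(1, 16))  # 16 (PEMS_BAY)
    print(accumulation_schedule(40, 16))  # [15, 31, 39]
    clipped, norm = clip_by_global_norm([3.0, 4.0, 12.0], max_norm=5)
    print(norm)  # 13.0
```

Clipping by the global norm preserves the direction of the combined gradient and only shrinks its length, which is why a single `max_grad_norm` threshold suffices for all parameters.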

Note: set the seed parameter to 99 when running the task.
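Reproducing a run with seed 99 usually means seeding every random number generator the training code touches. The sketch below is a generic helper (the name `set_seed` is hypothetical), not LibCity's own seeding logic. It seeds Python's `random` module and, only if they are installed, NumPy and PyTorch. Full GPU determinism may need further settings beyond what is shown here.

```python
import random


def set_seed(seed: int = 99) -> None:
    """Hypothetical helper: seed Python, and NumPy/PyTorch when available."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
        # Silently ignored by PyTorch when CUDA is unavailable.
        torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


if __name__ == "__main__":
    set_seed(99)
    first = [random.random() for _ in range(3)]
    set_seed(99)
    second = [random.random() for _ in range(3)]
    print(first == second)  # True
```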

Task execution results

The model evaluation metrics are compared below: