LibCity / Bigscity-LibCity

LibCity: An Open Library for Urban Spatial-temporal Data Mining
https://libcity.ai/
Apache License 2.0

How to adapt the models to grid data and the correspondence between data and models? #368

Open aptx1231 opened 11 months ago

aptx1231 commented 11 months ago

For the details, please visit https://bigscity-libcity-docs.readthedocs.io/en/latest/user_guide/data/dataset_for_task.html

Note 5. Some models require three inputs (CLOSENESS, PERIOD, and TREND), such as STResNet, ACFM, and ASTGCN. For such models, we implemented corresponding generalized versions that take only the CLOSENESS input for fair comparison: STResNetCommon, ACFMCommon, and ASTGCNCommon.
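The correspondence in Note 5 can be written down as a small lookup. This is only a sketch: the mapping comes from the note, while `COMMON_VERSION` and `closeness_only_model` are hypothetical names used for illustration and are not part of LibCity.

```python
# Mapping taken from Note 5: three-input model -> CLOSENESS-only counterpart.
COMMON_VERSION = {
    "STResNet": "STResNetCommon",
    "ACFM": "ACFMCommon",
    "ASTGCN": "ASTGCNCommon",
}


def closeness_only_model(model_name):
    """Return the CLOSENESS-only variant if Note 5 lists one, else the name unchanged."""
    return COMMON_VERSION.get(model_name, model_name)
```

For example, `closeness_only_model("ASTGCN")` gives `"ASTGCNCommon"`, while a model that Note 5 does not list, such as `"AGCRN"`, is returned as-is.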

aptx1231 commented 11 months ago

Here is how to adapt models designed for point-based data to grid-based data.

(1) If the model uses the dataset class TrafficStatePointDataset directly (for example AGCRN, ASTGCNCommon, CCRNN), set dataset_class to TrafficStateGridDataset in task_config.json or in a custom configuration file (--config_file). Then set the use_row_column parameter of TrafficStateGridDataset to False.

(2) If the model uses a subclass of TrafficStatePointDataset (for example ASTGCNDataset, CONVGCNDataset, STG2SeqDataset), edit the file that defines the dataset class so that it inherits from TrafficStateGridDataset instead of TrafficStatePointDataset. Then set use_row_column to False in its __init__() method.
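Both cases set use_row_column to False. Presumably this is because point-based models expect a single node axis rather than separate row and column axes, so the grid has to be presented as a flat list of cells. The sketch below illustrates that idea in plain Python. The function names are hypothetical, and the row-major ordering is an assumption made for illustration, not a statement about how LibCity orders grid cells internally.

```python
def flatten_grid(frame):
    """Flatten one time step of grid data (a list of rows) into a flat node list.

    Row-major order is an assumption for this illustration.
    """
    return [value for row in frame for value in row]


def node_index(row, col, n_cols):
    """Index of grid cell (row, col) in the flattened list, assuming row-major order."""
    return row * n_cols + col


# A 2 x 3 grid with one value per cell.
frame = [[10, 11, 12],
         [20, 21, 22]]
nodes = flatten_grid(frame)

# Each grid cell is now addressed by a single node index, as in point data.
assert nodes[node_index(1, 2, n_cols=3)] == frame[1][2]
```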

Example (1):

Before modification:

task_config.json

    "RNN": {
        "dataset_class": "TrafficStatePointDataset",
    },

TrafficStateGridDataset.json

    {
        "use_row_column": true
    }

After modification:

task_config.json

    "RNN": {
        "dataset_class": "TrafficStateGridDataset",
    },

TrafficStateGridDataset.json

    {
        "use_row_column": false
    }
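If you would rather not edit the default files, the same two settings from Example (1) can go into a custom configuration file. The helper below is a minimal sketch of generating one. `write_grid_override` and the file name are hypothetical, and it assumes that both keys can be overridden from a custom configuration file.

```python
import json
from pathlib import Path


def write_grid_override(path):
    """Write a custom config that points a point-based model at grid data.

    Both keys come from Example (1); the file name is chosen by the caller.
    """
    override = {
        "dataset_class": "TrafficStateGridDataset",
        "use_row_column": False,  # required for point-based models, per step (1)
    }
    Path(path).write_text(json.dumps(override, indent=4), encoding="utf-8")
    return override
```

Calling `write_grid_override("grid_rnn.json")` produces a file that could then be passed with `--config_file`. Check the LibCity documentation for whether that option expects the name with or without the `.json` extension.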

Example (2):

Before modification:

task_config.json

    "STG2Seq": {
        "dataset_class": "STG2SeqDataset",
    },

STG2SeqDataset.json

    {
        "use_row_column": false
    }

stg2seq_dataset.py

    from libcity.data.dataset import TrafficStatePointDataset

    class STG2SeqDataset(TrafficStatePointDataset):
        def __init__(self, config):
            super().__init__(config)
            pass

After modification:

task_config.json

    "STG2Seq": {
        "dataset_class": "STG2SeqDataset",
    },

STG2SeqDataset.json

    {
        "use_row_column": false
    }

stg2seq_dataset.py

    from libcity.data.dataset import TrafficStateGridDataset

    class STG2SeqDataset(TrafficStateGridDataset):
        def __init__(self, config):
            super().__init__(config)
            # Treat the grid as a flat set of nodes, as point-based models expect.
            self.use_row_column = False
            pass

aptx1231 commented 11 months ago

Additional processing may be required for some special models.