aptx1231 opened 1 year ago
Here is how to adapt models designed for point-based data so that they can run on grid-based data.
(1) If the model uses TrafficStatePointDataset as its dataset class (e.g., AGCRN, ASTGCNCommon, CCRNN), you can directly set dataset_class to TrafficStateGridDataset, either in task_config.json or in a custom configuration file passed with --config_file. Then set the TrafficStateGridDataset parameter use_row_column to False.
(2) If the model's dataset class is a subclass of TrafficStatePointDataset (e.g., ASTGCNDataset, CONVGCNDataset, STG2SeqDataset), modify the source file of that dataset class so that it inherits from TrafficStateGridDataset instead of TrafficStatePointDataset. Then set use_row_column to False in its `__init__()` method.
Example for (1), using RNN.

Before:

task_config.json
```json
"RNN": {
    "dataset_class": "TrafficStatePointDataset",
    ...
}
```

TrafficStateGridDataset.json
```json
{
    "use_row_column": true
}
```

After:

task_config.json
```json
"RNN": {
    "dataset_class": "TrafficStateGridDataset",
    ...
}
```

TrafficStateGridDataset.json
```json
{
    "use_row_column": false
}
```
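As noted in (1), the same override can go into a custom configuration file instead of editing task_config.json. The minimal sketch below writes such a file. The file name grid_override.json and the helper write_grid_override are hypothetical; it assumes the file is placed in the LibCity root directory and passed to --config_file without its .json suffix.

```python
import json
from pathlib import Path


def write_grid_override(path: str = "grid_override.json") -> dict:
    """Write a custom config that switches a point-based model to grid data."""
    overrides = {
        # Use the grid dataset class instead of the one listed in task_config.json.
        "dataset_class": "TrafficStateGridDataset",
        # Flatten the grid so that each cell is treated as one node.
        "use_row_column": False,
    }
    # json.dumps writes the Python value False as the JSON literal false.
    Path(path).write_text(json.dumps(overrides, indent=4))
    return overrides
```

After calling write_grid_override() in the LibCity root directory, a run would then look something like `python run_model.py --task traffic_state_pred --model RNN --dataset <grid dataset> --config_file grid_override`, with `<grid dataset>` replaced by one of your grid-based datasets.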
Example for (2), using STG2Seq.

Before:

task_config.json
```json
"STG2Seq": {
    "dataset_class": "STG2SeqDataset",
    ...
}
```

STG2SeqDataset.json
```json
{
    "use_row_column": false
}
```

stg2seq_dataset.py
```python
from libcity.data.dataset import TrafficStatePointDataset


class STG2SeqDataset(TrafficStatePointDataset):
    def __init__(self, config):
        super().__init__(config)
```

After (task_config.json and STG2SeqDataset.json stay the same; only stg2seq_dataset.py changes):

stg2seq_dataset.py
```python
from libcity.data.dataset import TrafficStateGridDataset


class STG2SeqDataset(TrafficStateGridDataset):
    def __init__(self, config):
        super().__init__(config)
        # Treat every grid cell as a node instead of indexing by (row, column).
        self.use_row_column = False
```
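For intuition about what setting use_row_column to False changes, here is a minimal pure-Python sketch: a grid series indexed as [time][row][column][feature] is flattened into [time][node][feature], so a point-based model sees each grid cell as one node. It assumes cells are numbered in row-major order (node = row * num_columns + column); the function names are hypothetical and this is not LibCity's actual loading code.

```python
from typing import List

Grid = List[List[List[List[float]]]]  # [time][row][column][feature]
Nodes = List[List[List[float]]]       # [time][node][feature]


def node_index(row: int, col: int, num_cols: int) -> int:
    # Row-major numbering: all cells of row 0 first, then row 1, and so on.
    return row * num_cols + col


def flatten_grid(grid: Grid) -> Nodes:
    # Concatenate the rows of every time frame into a single list of nodes.
    return [[cell for row in frame for cell in row] for frame in grid]


def unflatten_nodes(nodes: Nodes, num_rows: int, num_cols: int) -> Grid:
    # Inverse of flatten_grid: slice each frame back into num_rows rows.
    return [
        [frame[r * num_cols:(r + 1) * num_cols] for r in range(num_rows)]
        for frame in nodes
    ]
```

For a single time step on a 2 x 3 grid with one feature, flatten_grid turns the two rows into six nodes, and the cell at row 1, column 2 becomes node 5. unflatten_nodes reverses the mapping, which is handy if you want to inspect node-level predictions on the original grid.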
Additional processing may be required for some special models.
For details, see https://bigscity-libcity-docs.readthedocs.io/en/latest/user_guide/data/dataset_for_task.html
Note 5. Some models, such as STResNet, ACFM, and ASTGCN, require three inputs: CLOSENESS, PERIOD, and TREND. For fair comparisons, we implemented generalized versions of these models that take only the CLOSENESS input: STResNetCommon, ACFMCommon, and ASTGCNCommon.
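To illustrate the difference, here is a minimal sketch of which historical time steps feed each input when predicting step t, following the common STResNet-style convention: CLOSENESS uses the most recent steps, PERIOD the same time on previous days, and TREND the same time in previous weeks. The helper names, the parameters (points_per_day, len_closeness, len_period, len_trend), and the 7-day trend interval are assumptions for illustration, not LibCity's implementation.

```python
from typing import Dict, List


def closeness_indices(t: int, len_closeness: int) -> List[int]:
    # The len_closeness steps immediately before t, oldest first.
    return [t - i for i in range(len_closeness, 0, -1)]


def period_indices(t: int, len_period: int, points_per_day: int) -> List[int]:
    # The same time of day on each of the previous len_period days, oldest first.
    return [t - d * points_per_day for d in range(len_period, 0, -1)]


def trend_indices(t: int, len_trend: int, points_per_day: int) -> List[int]:
    # The same time of day and weekday in each of the previous len_trend weeks.
    return [t - w * 7 * points_per_day for w in range(len_trend, 0, -1)]


def input_indices(t: int, points_per_day: int, len_closeness: int = 3,
                  len_period: int = 1, len_trend: int = 1,
                  closeness_only: bool = False) -> Dict[str, List[int]]:
    """Return the history indices used to predict step t.

    closeness_only=True mimics a closeness-only ("Common") variant.
    """
    inputs = {"closeness": closeness_indices(t, len_closeness)}
    if not closeness_only:
        inputs["period"] = period_indices(t, len_period, points_per_day)
        inputs["trend"] = trend_indices(t, len_trend, points_per_day)
    earliest = min((i for idx in inputs.values() for i in idx), default=t)
    if earliest < 0:
        raise ValueError(f"step {t} does not have enough history")
    return inputs
```

With 30-minute intervals (48 points per day), input_indices(1000, 48) returns {"closeness": [997, 998, 999], "period": [952], "trend": [664]}, while input_indices(1000, 48, closeness_only=True) returns only {"closeness": [997, 998, 999]}. The closeness-only variant also needs much less history, so fewer early time steps have to be discarded.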