yuqinie98 / PatchTST

An official implementation of PatchTST: "A Time Series is Worth 64 Words: Long-term Forecasting with Transformers." (ICLR 2023) https://arxiv.org/abs/2211.14730
Apache License 2.0

Pandas dataframe error during pre-training (version incompatible) #59

Closed koseoyoung closed 1 year ago

koseoyoung commented 1 year ago

Hi, I encountered this issue while pre-training.

python patchtst_pretrain.py --dset ettm1 --mask_ratio 0.4
args: Namespace(dset_pretrain='ettm1', context_points=512, target_points=96, batch_size=64, num_workers=0, scaler='standard', features='M', patch_len=12, stride=12, revin=1, n_layers=3, n_heads=16, d_model=128, d_ff=512, dropout=0.2, head_dropout=0.2, mask_ratio=0.4, n_epochs_pretrain=10, lr=0.0001, pretrained_model_id=1, model_type='based_model')
Traceback (most recent call last):
  File "/Users/dorothyko/network-ml/PatchTST/PatchTST_self_supervised/patchtst_pretrain.py", line 148, in <module>
    suggested_lr = find_lr()
  File "/Users/dorothyko/network-ml/PatchTST/PatchTST_self_supervised/patchtst_pretrain.py", line 95, in find_lr
    dls = get_dls(args)    
  File "/Users/dorothyko/network-ml/PatchTST/PatchTST_self_supervised/datautils.py", line 24, in get_dls
    dls = DataLoaders(
  File "/Users/dorothyko/network-ml/PatchTST/PatchTST_self_supervised/src/data/datamodule.py", line 29, in __init__
    self.train = self.train_dataloader()
  File "/Users/dorothyko/network-ml/PatchTST/PatchTST_self_supervised/src/data/datamodule.py", line 35, in train_dataloader
    return self._make_dloader("train", shuffle=self.shuffle_train)
  File "/Users/dorothyko/network-ml/PatchTST/PatchTST_self_supervised/src/data/datamodule.py", line 44, in _make_dloader
    dataset = self.datasetCls(**self.dataset_kwargs, split=split)
  File "/Users/dorothyko/network-ml/PatchTST/PatchTST_self_supervised/src/data/pred_dataset.py", line 137, in __init__
    self.__read_data__()
  File "/Users/dorothyko/network-ml/PatchTST/PatchTST_self_supervised/src/data/pred_dataset.py", line 171, in __read_data__
    data_stamp = df_stamp.drop(['date'], 1).values
TypeError: DataFrame.drop() takes from 1 to 2 positional arguments but 3 were given

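According to the current pandas documentation, DataFrame.drop() accepts only labels as a positional argument, so axis has to be passed as a keyword (axis=1). The positional form was deprecated in earlier pandas releases and removed in pandas 2.0, so this error is most likely caused by the pandas version (I am on pandas 2.0.2).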
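
A minimal sketch of the likely fix for the failing line in pred_dataset.py. The small df_stamp frame here is a hypothetical stand-in for the one built in __read_data__, and the column names other than date are made up for illustration. Both keyword forms should work on older and newer pandas versions:

```python
import pandas as pd

# Hypothetical stand-in for the df_stamp built in pred_dataset.py;
# only the 'date' column name is taken from the traceback.
df_stamp = pd.DataFrame({
    "date": pd.date_range("2016-07-01", periods=3, freq="15min"),
    "month": [7, 7, 7],
    "hour": [0, 0, 0],
})

# Failing call on pandas >= 2.0 (axis passed positionally):
#   data_stamp = df_stamp.drop(['date'], 1).values

# Option 1: name the columns to drop explicitly.
data_stamp = df_stamp.drop(columns=["date"]).values

# Option 2: keep the labels argument and pass axis as a keyword.
data_stamp_axis = df_stamp.drop(["date"], axis=1).values
```

Either option leaves only the non-date columns in data_stamp, which is what the original line intended.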

I can send a PR with the fix.