gzerveas / mvts_transformer

Multivariate Time Series Transformer, public version
MIT License

M2 / sktime 0.22 / Python 3.11 #56

Open: richarddli opened this issue 10 months ago

richarddli commented 10 months ago

I was able to get the regression example running with the following branches, plus a weird monkey patch on PyTorch: removing the is_causal=is_causal keyword argument on line 315 of transformers.py. (Apparently is_causal is only a hint, but PyTorch kept complaining that is_causal was an unexpected kwarg.)

This took me a few hours to debug, so I thought I'd share this with others in case they're trying out this library. I wouldn't merge this as-is due to the PyTorch monkey patch required.

BUT ... with the patches below, I'm able to run on Apple Silicon with a fully modern Python and dependency stack. (The problem with the pinned versions is that the older releases have no prebuilt wheels for Apple Silicon.)
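A possible way to avoid patching PyTorch itself, sketched under the assumption that the error comes from the repo's custom BatchNorm encoder layer: nn.TransformerEncoder in recent PyTorch releases (2.0 and later, as far as I can tell) passes is_causal to every layer, so a layer whose forward() lacks that parameter fails with "unexpected keyword argument". The class below is a hypothetical stand-in for that custom layer, not the repository's code; the only point it illustrates is that forward() accepts and ignores is_causal.

```python
import torch
from torch import nn


class BatchNormEncoderLayerSketch(nn.Module):
    """Hypothetical stand-in for a custom encoder layer that uses BatchNorm
    instead of LayerNorm. Inputs are (seq_len, batch, d_model)."""

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 64,
                 dropout: float = 0.1) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.BatchNorm1d(d_model)
        self.norm2 = nn.BatchNorm1d(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None,
                is_causal=None):
        # is_causal is accepted and ignored: newer nn.TransformerEncoder passes
        # it to every layer, and a forward() without it raises a TypeError.
        attn_out, _ = self.self_attn(src, src, src, attn_mask=src_mask,
                                     key_padding_mask=src_key_padding_mask)
        src = src + self.dropout(attn_out)
        # BatchNorm1d expects (batch, channels, length), so permute around it
        src = self.norm1(src.permute(1, 2, 0)).permute(2, 0, 1)
        ff_out = self.linear2(self.dropout(torch.relu(self.linear1(src))))
        src = src + self.dropout(ff_out)
        src = self.norm2(src.permute(1, 2, 0)).permute(2, 0, 1)
        return src


layer = BatchNormEncoderLayerSketch(d_model=8, nhead=2)
# enable_nested_tensor=False avoids a fast-path warning for non-standard layers
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=False)
x = torch.randn(10, 4, 8)  # (seq_len, batch, d_model)
print(encoder(x).shape)  # torch.Size([10, 4, 8])
```

With a signature like this, the stock nn.TransformerEncoder can call the layer unchanged, so the edit to PyTorch's own source should no longer be needed.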

richarddli commented 10 months ago

I shared this too soon! With the BeijingPM25 dataset, I get:

```
2023-08-24 15:15:29,239 | INFO : Epoch 0 Validation Summary: epoch: 0.000000 | loss: 31485.421528 |
Traceback (most recent call last):
  File "/Users/rdl/amorphous/mvts_transformer/src/main.py", line 307, in <module>
    main(config)
  File "/Users/rdl/amorphous/mvts_transformer/src/main.py", line 235, in main
    aggr_metrics_val, best_metrics, best_value = validate(val_evaluator, tensorboard_writer, config, best_metrics,
                                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/rdl/amorphous/mvts_transformer/src/running.py", line 222, in validate
    np.savez(pred_filepath, **per_batch)
  File "/Users/rdl/miniconda3/envs/mvts2/lib/python3.11/site-packages/numpy/lib/npyio.py", line 639, in savez
    _savez(file, args, kwds, False)
  File "/Users/rdl/miniconda3/envs/mvts2/lib/python3.11/site-packages/numpy/lib/npyio.py", line 740, in _savez
    val = np.asanyarray(val)
          ^^^^^^^^^^^^^^^^^^
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (79,) + inhomogeneous part.
```

richarddli commented 10 months ago

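The proximate cause of the error above is that the last array in the lists of predictions and targets does not have the same shape as the other arrays in those lists (presumably because the final validation batch is smaller than the rest), so np.asanyarray cannot stack them into a single array.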
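A minimal reproduction of that cause, assuming the predictions are collected as a list of per-batch NumPy arrays whose first (batch) dimension shrinks on the final batch. The batch sizes and the try_asanyarray helper are made up for illustration; the helper just mimics the np.asanyarray call that np.savez makes on each value, as seen in the traceback.

```python
import numpy as np


def try_asanyarray(batches):
    """Mimic what np.savez does with each value: call np.asanyarray on it.

    Returns the resulting array, or None if NumPy rejects the ragged list.
    """
    try:
        return np.asanyarray(batches)
    except ValueError as err:  # NumPy >= 1.24 raises instead of building an object array
        print(f"ValueError: {err}")
        return None


# Hypothetical validation predictions: three full batches of 64, one final batch of 15.
full = [np.zeros((64, 1)) for _ in range(3)]
ragged = full + [np.zeros((15, 1))]

print(try_asanyarray(full).shape)  # (3, 64, 1)
try_asanyarray(ragged)  # prints the "inhomogeneous shape" ValueError on NumPy >= 1.24
```

Older NumPy releases only warned here and silently built an object array, which may be why the pinned dependency versions did not hit this error.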

I've implemented a hack that drops the last array. I ran the AppEnergy dataset with and without the patch, and the accuracy doesn't seem to be dramatically affected. If anyone has insight into the correct fix, I'd be grateful!
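One alternative to dropping the last batch, sketched under the assumption that every value in per_batch is a list of arrays that differ only along the first (batch) axis: concatenate each list along axis 0 before calling np.savez. The stack_batches helper and the sample data are hypothetical; entries that are not arrays with a batch axis (per-batch scalar metrics, for instance) would need separate handling.

```python
import io

import numpy as np


def stack_batches(per_batch):
    """Concatenate each list of per-batch arrays along the batch axis (axis 0).

    Only axis 0 may differ between batches, so a smaller final batch is fine.
    """
    return {key: np.concatenate(arrays, axis=0) for key, arrays in per_batch.items()}


rng = np.random.default_rng(0)
# Hypothetical validation run: three full batches of 64 samples and a final batch of 16.
batch_sizes = [64, 64, 64, 16]
per_batch = {
    "predictions": [rng.normal(size=(n, 1)) for n in batch_sizes],
    "targets": [rng.normal(size=(n, 1)) for n in batch_sizes],
}

stacked = stack_batches(per_batch)
buffer = io.BytesIO()  # stands in for pred_filepath so nothing is written to disk
np.savez(buffer, **stacked)
buffer.seek(0)
loaded = np.load(buffer)
print(loaded["predictions"].shape)  # (208, 1)
```

This keeps every validation sample in the saved file instead of silently discarding the final batch.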

In the meantime, I'm implementing my own custom data classes ...