tallamjr / astronet

Efficient Deep Learning for Real-time Classification of Astronomical Transients and Multivariate Time-series
Apache License 2.0

[BENCHMARK] Compare results to `dl-4-tsc` on the MTS dataset #52

Closed tallamjr closed 3 years ago

tallamjr commented 3 years ago

dl-4-tsc by Fawaz et al. compares deep learning classifiers on 128 univariate and 12 multivariate datasets.

As a first step, it would be useful to compare the results produced here against the multivariate results, to gauge performance on simpler multivariate data.

Shown below is a snapshot of the results in that repository as of https://github.com/hfawaz/dl-4-tsc/commit/3ee62e16e118e4f5cfa86d01661846dfa75febfa

The following table contains the accuracy of each implemented model on the MTS archive, averaged over 10 runs, with the standard deviation in parentheses.

| Datasets | MLP | FCN | ResNet | Encoder | MCNN | t-LeNet | MCDCNN | Time-CNN | TWIESN |
|---|---|---|---|---|---|---|---|---|---|
| AUSLAN | 93.3(0.5) | 97.5(0.4) | 97.4(0.3) | 93.8(0.5) | 1.1(0.0) | 1.1(0.0) | 85.4(2.7) | 72.6(3.5) | 72.4(1.6) |
| ArabicDigits | 96.9(0.2) | 99.4(0.1) | 99.6(0.1) | 98.1(0.1) | 10.0(0.0) | 10.0(0.0) | 95.9(0.2) | 95.8(0.3) | 85.3(1.4) |
| CMUsubject16 | 60.0(16.9) | 100.0(0.0) | 99.7(1.1) | 98.3(2.4) | 53.1(4.4) | 51.0(5.3) | 51.4(5.0) | 97.6(1.7) | 89.3(6.8) |
| CharacterTrajectories | 96.9(0.2) | 99.0(0.1) | 99.0(0.2) | 97.1(0.2) | 5.4(0.8) | 6.7(0.0) | 93.8(1.7) | 96.0(0.8) | 92.0(1.3) |
| ECG | 74.8(16.2) | 87.2(1.2) | 86.7(1.3) | 87.2(0.8) | 67.0(0.0) | 67.0(0.0) | 50.0(17.9) | 84.1(1.7) | 73.7(2.3) |
| JapaneseVowels | 97.6(0.2) | 99.3(0.2) | 99.2(0.3) | 97.6(0.6) | 9.2(2.5) | 23.8(0.0) | 94.4(1.4) | 95.6(1.0) | 96.5(0.7) |
| KickvsPunch | 61.0(12.9) | 54.0(13.5) | 51.0(8.8) | 61.0(9.9) | 54.0(9.7) | 50.0(10.5) | 56.0(8.4) | 62.0(6.3) | 67.0(14.2) |
| Libras | 78.0(1.0) | 96.4(0.7) | 95.4(1.1) | 78.3(0.9) | 6.7(0.0) | 6.7(0.0) | 65.1(3.9) | 63.7(3.3) | 79.4(1.3) |
| NetFlow | 55.0(26.1) | 89.1(0.4) | 62.7(23.4) | 77.7(0.5) | 77.9(0.0) | 72.3(17.6) | 63.0(18.2) | 89.0(0.9) | 94.5(0.4) |
| UWave | 90.1(0.3) | 93.4(0.3) | 92.6(0.4) | 90.8(0.4) | 12.5(0.0) | 12.5(0.0) | 84.5(1.6) | 85.9(0.7) | 75.4(6.3) |
| Wafer | 89.4(0.0) | 98.2(0.5) | 98.9(0.4) | 98.6(0.2) | 89.4(0.0) | 89.4(0.0) | 65.8(38.1) | 94.8(2.1) | 94.9(0.6) |
| WalkvsRun | 70.0(15.8) | 100.0(0.0) | 100.0(0.0) | 100.0(0.0) | 75.0(0.0) | 60.0(24.2) | 45.0(25.8) | 100.0(0.0) | 94.4(9.1) |
| Average_Rank | 5.208333 | 2.000000 | 2.875000 | 3.041667 | 7.583333 | 8.000000 | 6.833333 | 4.625000 | 4.833333 |
| Wins | 0 | 5 | 3 | 0 | 0 | 0 | 0 | 0 | 2 |
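
For comparing astronet results against this table, summary rows like Average_Rank and Wins could be computed along the following lines. This is a minimal sketch, not the dl-4-tsc implementation: the function names `average_ranks` and `strict_wins` are hypothetical, and both the tie handling (tied models share the mean of their positions) and the "a win requires a unique best score" rule are assumptions, so the output is not guaranteed to reproduce the figures above.

```python
from collections import defaultdict


def average_ranks(acc):
    """Mean rank per model over all datasets (rank 1 = highest accuracy).

    acc maps dataset -> {model: accuracy}.
    ASSUMPTION: tied models share the mean of the positions they occupy.
    """
    totals = defaultdict(float)
    for scores in acc.values():
        ordered = sorted(scores.items(), key=lambda kv: -kv[1])
        i = 0
        while i < len(ordered):
            # Extend j to cover every model tied with position i.
            j = i
            while j + 1 < len(ordered) and ordered[j + 1][1] == ordered[i][1]:
                j += 1
            shared_rank = ((i + 1) + (j + 1)) / 2
            for model, _ in ordered[i:j + 1]:
                totals[model] += shared_rank
            i = j + 1
    return {model: total / len(acc) for model, total in totals.items()}


def strict_wins(acc):
    """Count datasets on which a model is the unique best (ASSUMPTION)."""
    wins = defaultdict(int)
    for scores in acc.values():
        best = max(scores.values())
        leaders = [m for m, s in scores.items() if s == best]
        if len(leaders) == 1:
            wins[leaders[0]] += 1
    return dict(wins)


# Three rows and three columns copied from the table above.
acc = {
    "AUSLAN": {"MLP": 93.3, "FCN": 97.5, "ResNet": 97.4},
    "ArabicDigits": {"MLP": 96.9, "FCN": 99.4, "ResNet": 99.6},
    "WalkvsRun": {"MLP": 70.0, "FCN": 100.0, "ResNet": 100.0},
}
ranks = average_ranks(acc)
wins = strict_wins(acc)
```

Adding an astronet column to `acc` would then rank it alongside the published models on the same datasets.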

The MTS data was downloaded from http://www.mustafabaydogan.com/files/viewcategory/20-data-sets.html and then processed using dl-4-tsc/utils/utils.py with this change:

diff --git a/utils/utils.py b/utils/utils.py
index 0ef692b..c0ae7ab 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -219,8 +219,8 @@ def transform_to_same_length(x, n_var, max_length):

 def transform_mts_to_ucr_format():
-    mts_root_dir = '/mnt/Other/mtsdata/'
-    mts_out_dir = '/mnt/nfs/casimir/archives/mts_archive/'
+    mts_root_dir = '/Users/tallamjr/github/tallamjr/origin/astronet/data/mtsdata/'
+    mts_out_dir = '/Users/tallamjr/github/tallamjr/origin/astronet/data/transformed-mtsdata/'
     for dataset_name in MTS_DATASET_NAMES:
         # print('dataset_name',dataset_name)

The transformation is then run with:

$ python main.py transform_mts_to_ucr_format

N.B. Empty folders named after each dataset found in mtsdata need to be created in transformed-mtsdata before running the above command.
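
The empty output folders could be created with a short script rather than by hand. This is a minimal sketch that assumes each dataset is a sub-directory of mtsdata; the helper name `create_empty_dataset_dirs` is hypothetical, and the demo runs against a temporary directory rather than the real data paths.

```python
import tempfile
from pathlib import Path


def create_empty_dataset_dirs(src_root, out_root):
    """Create one empty folder in out_root per dataset folder in src_root.

    ASSUMPTION: every dataset lives in its own sub-directory of src_root.
    """
    src_root, out_root = Path(src_root), Path(out_root)
    created = []
    for entry in sorted(src_root.iterdir()):
        if entry.is_dir():
            target = out_root / entry.name
            # exist_ok makes the script safe to re-run.
            target.mkdir(parents=True, exist_ok=True)
            created.append(target)
    return created


# Demo on a throwaway directory tree mirroring the layout described above.
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "mtsdata"
    out = Path(tmp) / "transformed-mtsdata"
    for name in ("ECG", "Libras"):
        (src / name).mkdir(parents=True)
    created_names = [p.name for p in create_empty_dataset_dirs(src, out)]
```

For the real data, the two arguments would be the `mts_root_dir` and `mts_out_dir` paths set in the diff above.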

With these datasets included, it may be worth refactoring how data is loaded in astronet.t2.train.py and astronet.t2.opt.hypertrain.py, since the if/else block would now need to list many more datasets. This logic may be better placed in astronet.t2.utils.py.
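
One way such a refactor might look is a name-to-loader mapping in place of the if/else chain. This is only a sketch under stated assumptions: `load_mts`, `DATASET_LOADERS` and `load_dataset` are hypothetical names, not existing astronet functions, and the placeholder loader returns a string where a real one would return the training arrays.

```python
from functools import partial

# Dataset names as they appear in the results table above.
MTS_NAMES = [
    "AUSLAN", "ArabicDigits", "CMUsubject16", "CharacterTrajectories",
    "ECG", "JapaneseVowels", "KickvsPunch", "Libras", "NetFlow",
    "UWave", "Wafer", "WalkvsRun",
]


def load_mts(name):
    """Hypothetical placeholder: a real loader would read the transformed data."""
    return f"mts:{name}"


# One entry per dataset; loaders for other datasets can be added the same way.
DATASET_LOADERS = {name: partial(load_mts, name) for name in MTS_NAMES}


def load_dataset(name):
    """Look up and call the loader for a dataset, failing loudly if unknown."""
    try:
        loader = DATASET_LOADERS[name]
    except KeyError:
        raise ValueError(f"Unknown dataset: {name!r}") from None
    return loader()
```

With this shape, both the training and hyperparameter scripts could call `load_dataset(name)` and a new dataset would only need one new entry in the mapping.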

tallamjr commented 3 years ago

Closed with #59