Closed ShamimSasani closed 1 year ago
Hello Shamim,
Thank you for your interest in the open implementation of EEGSym. I am sorry for the error that you encountered; I have updated the repository, and the code should no longer raise errors.
The update solves the problem you reported, along with others that I found. The new code also creates a random signal and tests the whole process with the pretrained weights for fine-tuning.
I have updated the requirements and tested the code for python==3.10.
I am sorry for the late response.
Can you confirm that there are no further issues?
Kind regards, Sergio Pérez-Velasco
Hi Serpeve, thank you for your response and for updating the code. I will check it today and get back to you to confirm that the errors are fixed. Can you give me a direct email so I can contact you with further questions as well? Here is mine: shamim.sasani@gmail.com
Best regards, Shamim
I have checked the new example code presented in the README.md file. Here are the results:
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 384, 16, 1) 0 []
]
tf.expand_dims (TFOpLambda) (None, 1, 384, 16, 0 ['input_1[0][0]']
1)
tf.compat.v1.gather (TFOpLambd (None, 1, 384, 7, 1 0 ['tf.expand_dims[0][0]']
a) )
tf.compat.v1.gather_2 (TFOpLam (None, 1, 384, 2, 1 0 ['tf.expand_dims[0][0]']
bda) )
tf.compat.v1.gather_1 (TFOpLam (None, 1, 384, 7, 1 0 ['tf.expand_dims[0][0]']
bda) )
concatenate (Concatenate) (None, 1, 384, 9, 1 0 ['tf.compat.v1.gather[0][0]',
) 'tf.compat.v1.gather_2[0][0]']
concatenate_1 (Concatenate) (None, 1, 384, 9, 1 0 ['tf.compat.v1.gather_1[0][0]',
) 'tf.compat.v1.gather_2[0][0]']
concatenate_2 (Concatenate) (None, 2, 384, 9, 1 0 ['concatenate[0][0]',
) 'concatenate_1[0][0]']
conv3d (Conv3D) (None, 2, 384, 9, 2 408 ['concatenate_2[0][0]']
4)
conv3d_1 (Conv3D) (None, 2, 384, 9, 2 792 ['concatenate_2[0][0]']
4)
conv3d_2 (Conv3D) (None, 2, 384, 9, 2 1560 ['concatenate_2[0][0]']
4)
batch_normalization (BatchNorm (None, 2, 384, 9, 2 96 ['conv3d[0][0]']
alization) 4)
batch_normalization_1 (BatchNo (None, 2, 384, 9, 2 96 ['conv3d_1[0][0]']
rmalization) 4)
batch_normalization_2 (BatchNo (None, 2, 384, 9, 2 96 ['conv3d_2[0][0]']
rmalization) 4)
activation (Activation) (None, 2, 384, 9, 2 0 ['batch_normalization[0][0]']
4)
activation_1 (Activation) (None, 2, 384, 9, 2 0 ['batch_normalization_1[0][0]']
4)
activation_2 (Activation) (None, 2, 384, 9, 2 0 ['batch_normalization_2[0][0]']
4)
dropout (Dropout) (None, 2, 384, 9, 2 0 ['activation[0][0]']
4)
dropout_1 (Dropout) (None, 2, 384, 9, 2 0 ['activation_1[0][0]']
4)
dropout_2 (Dropout) (None, 2, 384, 9, 2 0 ['activation_2[0][0]']
4)
concatenate_3 (Concatenate) (None, 2, 384, 9, 7 0 ['dropout[0][0]',
2) 'dropout_1[0][0]',
'dropout_2[0][0]']
add (Add) (None, 2, 384, 9, 7 0 ['concatenate_3[0][0]',
2) 'concatenate_2[0][0]']
average_pooling3d (AveragePool (None, 2, 192, 9, 7 0 ['add[0][0]']
ing3D) 2)
conv3d_3 (Conv3D) (None, 2, 192, 1, 7 648 ['average_pooling3d[0][0]']
2)
batch_normalization_3 (BatchNo (None, 2, 192, 1, 7 288 ['conv3d_3[0][0]']
rmalization) 2)
activation_3 (Activation) (None, 2, 192, 1, 7 0 ['batch_normalization_3[0][0]']
2)
dropout_3 (Dropout) (None, 2, 192, 1, 7 0 ['activation_3[0][0]']
2)
add_1 (Add) (None, 2, 192, 9, 7 0 ['average_pooling3d[0][0]',
2) 'dropout_3[0][0]']
conv3d_6 (Conv3D) (None, 2, 192, 9, 2 6936 ['add_1[0][0]']
4)
conv3d_7 (Conv3D) (None, 2, 192, 9, 2 13848 ['add_1[0][0]']
4)
conv3d_8 (Conv3D) (None, 2, 192, 9, 2 27672 ['add_1[0][0]']
4)
batch_normalization_6 (BatchNo (None, 2, 192, 9, 2 96 ['conv3d_6[0][0]']
rmalization) 4)
batch_normalization_7 (BatchNo (None, 2, 192, 9, 2 96 ['conv3d_7[0][0]']
rmalization) 4)
batch_normalization_8 (BatchNo (None, 2, 192, 9, 2 96 ['conv3d_8[0][0]']
rmalization) 4)
activation_4 (Activation) (None, 2, 192, 9, 2 0 ['batch_normalization_6[0][0]']
4)
activation_5 (Activation) (None, 2, 192, 9, 2 0 ['batch_normalization_7[0][0]']
4)
activation_6 (Activation) (None, 2, 192, 9, 2 0 ['batch_normalization_8[0][0]']
4)
dropout_4 (Dropout) (None, 2, 192, 9, 2 0 ['activation_4[0][0]']
4)
dropout_5 (Dropout) (None, 2, 192, 9, 2 0 ['activation_5[0][0]']
4)
dropout_6 (Dropout) (None, 2, 192, 9, 2 0 ['activation_6[0][0]']
4)
concatenate_4 (Concatenate) (None, 2, 192, 9, 7 0 ['dropout_4[0][0]',
2) 'dropout_5[0][0]',
'dropout_6[0][0]']
add_2 (Add) (None, 2, 192, 9, 7 0 ['concatenate_4[0][0]',
2) 'add_1[0][0]']
average_pooling3d_1 (AveragePo (None, 2, 96, 9, 72 0 ['add_2[0][0]']
oling3D) )
conv3d_9 (Conv3D) (None, 2, 96, 1, 72 648 ['average_pooling3d_1[0][0]']
)
batch_normalization_9 (BatchNo (None, 2, 96, 1, 72 288 ['conv3d_9[0][0]']
rmalization) )
activation_7 (Activation) (None, 2, 96, 1, 72 0 ['batch_normalization_9[0][0]']
)
dropout_7 (Dropout) (None, 2, 96, 1, 72 0 ['activation_7[0][0]']
)
add_3 (Add) (None, 2, 96, 9, 72 0 ['average_pooling3d_1[0][0]',
) 'dropout_7[0][0]']
conv3d_12 (Conv3D) (None, 2, 96, 9, 36 41508 ['add_3[0][0]']
)
conv3d_15 (Conv3D) (None, 2, 96, 9, 36 2592 ['add_3[0][0]']
)
batch_normalization_12 (BatchN (None, 2, 96, 9, 36 144 ['conv3d_12[0][0]']
ormalization) )
batch_normalization_15 (BatchN (None, 2, 96, 9, 36 144 ['conv3d_15[0][0]']
ormalization) )
activation_8 (Activation) (None, 2, 96, 9, 36 0 ['batch_normalization_12[0][0]']
)
activation_9 (Activation) (None, 2, 96, 9, 36 0 ['batch_normalization_15[0][0]']
)
dropout_8 (Dropout) (None, 2, 96, 9, 36 0 ['activation_8[0][0]']
)
dropout_9 (Dropout) (None, 2, 96, 9, 36 0 ['activation_9[0][0]']
)
add_4 (Add) (None, 2, 96, 9, 36 0 ['dropout_8[0][0]',
) 'dropout_9[0][0]']
average_pooling3d_2 (AveragePo (None, 2, 48, 9, 36 0 ['add_4[0][0]']
oling3D) )
conv3d_14 (Conv3D) (None, 2, 48, 1, 36 11664 ['average_pooling3d_2[0][0]']
)
batch_normalization_14 (BatchN (None, 2, 48, 1, 36 144 ['conv3d_14[0][0]']
ormalization) )
activation_10 (Activation) (None, 2, 48, 1, 36 0 ['batch_normalization_14[0][0]']
)
dropout_10 (Dropout) (None, 2, 48, 1, 36 0 ['activation_10[0][0]']
)
add_5 (Add) (None, 2, 48, 9, 36 0 ['average_pooling3d_2[0][0]',
) 'dropout_10[0][0]']
conv3d_16 (Conv3D) (None, 2, 48, 9, 36 10404 ['add_5[0][0]']
)
conv3d_19 (Conv3D) (None, 2, 48, 9, 36 1296 ['add_5[0][0]']
)
batch_normalization_16 (BatchN (None, 2, 48, 9, 36 144 ['conv3d_16[0][0]']
ormalization) )
batch_normalization_19 (BatchN (None, 2, 48, 9, 36 144 ['conv3d_19[0][0]']
ormalization) )
activation_11 (Activation) (None, 2, 48, 9, 36 0 ['batch_normalization_16[0][0]']
)
activation_12 (Activation) (None, 2, 48, 9, 36 0 ['batch_normalization_19[0][0]']
)
dropout_11 (Dropout) (None, 2, 48, 9, 36 0 ['activation_11[0][0]']
)
dropout_12 (Dropout) (None, 2, 48, 9, 36 0 ['activation_12[0][0]']
)
add_6 (Add) (None, 2, 48, 9, 36 0 ['dropout_11[0][0]',
) 'dropout_12[0][0]']
average_pooling3d_3 (AveragePo (None, 2, 24, 9, 36 0 ['add_6[0][0]']
oling3D) )
conv3d_18 (Conv3D) (None, 2, 24, 1, 36 11664 ['average_pooling3d_3[0][0]']
)
batch_normalization_18 (BatchN (None, 2, 24, 1, 36 144 ['conv3d_18[0][0]']
ormalization) )
activation_13 (Activation) (None, 2, 24, 1, 36 0 ['batch_normalization_18[0][0]']
)
dropout_13 (Dropout) (None, 2, 24, 1, 36 0 ['activation_13[0][0]']
)
add_7 (Add) (None, 2, 24, 9, 36 0 ['average_pooling3d_3[0][0]',
) 'dropout_13[0][0]']
conv3d_20 (Conv3D) (None, 2, 24, 9, 18 2610 ['add_7[0][0]']
)
conv3d_23 (Conv3D) (None, 2, 24, 9, 18 648 ['add_7[0][0]']
)
batch_normalization_20 (BatchN (None, 2, 24, 9, 18 72 ['conv3d_20[0][0]']
ormalization) )
batch_normalization_23 (BatchN (None, 2, 24, 9, 18 72 ['conv3d_23[0][0]']
ormalization) )
activation_14 (Activation) (None, 2, 24, 9, 18 0 ['batch_normalization_20[0][0]']
)
activation_15 (Activation) (None, 2, 24, 9, 18 0 ['batch_normalization_23[0][0]']
)
dropout_14 (Dropout) (None, 2, 24, 9, 18 0 ['activation_14[0][0]']
)
dropout_15 (Dropout) (None, 2, 24, 9, 18 0 ['activation_15[0][0]']
)
add_8 (Add) (None, 2, 24, 9, 18 0 ['dropout_14[0][0]',
) 'dropout_15[0][0]']
average_pooling3d_4 (AveragePo (None, 2, 12, 9, 18 0 ['add_8[0][0]']
oling3D) )
conv3d_22 (Conv3D) (None, 2, 12, 1, 18 2916 ['average_pooling3d_4[0][0]']
)
batch_normalization_22 (BatchN (None, 2, 12, 1, 18 72 ['conv3d_22[0][0]']
ormalization) )
activation_16 (Activation) (None, 2, 12, 1, 18 0 ['batch_normalization_22[0][0]']
)
dropout_16 (Dropout) (None, 2, 12, 1, 18 0 ['activation_16[0][0]']
)
add_9 (Add) (None, 2, 12, 9, 18 0 ['average_pooling3d_4[0][0]',
) 'dropout_16[0][0]']
conv3d_24 (Conv3D) (None, 2, 12, 9, 18 1296 ['add_9[0][0]']
)
batch_normalization_24 (BatchN (None, 2, 12, 9, 18 72 ['conv3d_24[0][0]']
ormalization) )
activation_17 (Activation) (None, 2, 12, 9, 18 0 ['batch_normalization_24[0][0]']
)
dropout_17 (Dropout) (None, 2, 12, 9, 18 0 ['activation_17[0][0]']
)
add_10 (Add) (None, 2, 12, 9, 18 0 ['add_9[0][0]',
) 'dropout_17[0][0]']
average_pooling3d_5 (AveragePo (None, 2, 6, 9, 18) 0 ['add_10[0][0]']
oling3D)
conv3d_25 (Conv3D) (None, 1, 6, 1, 18) 5832 ['average_pooling3d_5[0][0]']
batch_normalization_25 (BatchN (None, 1, 6, 1, 18) 72 ['conv3d_25[0][0]']
ormalization)
activation_18 (Activation) (None, 1, 6, 1, 18) 0 ['batch_normalization_25[0][0]']
dropout_18 (Dropout) (None, 1, 6, 1, 18) 0 ['activation_18[0][0]']
add_11 (Add) (None, 2, 6, 9, 18) 0 ['average_pooling3d_5[0][0]',
'dropout_18[0][0]']
conv3d_26 (Conv3D) (None, 1, 6, 1, 18) 5832 ['add_11[0][0]']
batch_normalization_26 (BatchN (None, 1, 6, 1, 18) 72 ['conv3d_26[0][0]']
ormalization)
activation_19 (Activation) (None, 1, 6, 1, 18) 0 ['batch_normalization_26[0][0]']
dropout_19 (Dropout) (None, 1, 6, 1, 18) 0 ['activation_19[0][0]']
add_12 (Add) (None, 2, 6, 9, 18) 0 ['add_11[0][0]',
'dropout_19[0][0]']
conv3d_27 (Conv3D) (None, 1, 6, 1, 18) 648 ['add_12[0][0]']
batch_normalization_27 (BatchN (None, 1, 6, 1, 18) 72 ['conv3d_27[0][0]']
ormalization)
activation_20 (Activation) (None, 1, 6, 1, 18) 0 ['batch_normalization_27[0][0]']
dropout_20 (Dropout) (None, 1, 6, 1, 18) 0 ['activation_20[0][0]']
conv3d_28 (Conv3D) (None, 1, 1, 1, 18) 1944 ['dropout_20[0][0]']
batch_normalization_28 (BatchN (None, 1, 1, 1, 18) 72 ['conv3d_28[0][0]']
ormalization)
activation_21 (Activation) (None, 1, 1, 1, 18) 0 ['batch_normalization_28[0][0]']
dropout_21 (Dropout) (None, 1, 1, 1, 18) 0 ['activation_21[0][0]']
add_13 (Add) (None, 1, 6, 1, 18) 0 ['dropout_20[0][0]',
'dropout_21[0][0]']
conv3d_29 (Conv3D) (None, 1, 1, 1, 36) 216 ['add_13[0][0]']
batch_normalization_29 (BatchN (None, 1, 1, 1, 36) 144 ['conv3d_29[0][0]']
ormalization)
activation_22 (Activation) (None, 1, 1, 1, 36) 0 ['batch_normalization_29[0][0]']
dropout_22 (Dropout) (None, 1, 1, 1, 36) 0 ['activation_22[0][0]']
conv3d_30 (Conv3D) (None, 1, 1, 1, 36) 1296 ['dropout_22[0][0]']
batch_normalization_30 (BatchN (None, 1, 1, 1, 36) 144 ['conv3d_30[0][0]']
ormalization)
activation_23 (Activation) (None, 1, 1, 1, 36) 0 ['batch_normalization_30[0][0]']
dropout_23 (Dropout) (None, 1, 1, 1, 36) 0 ['activation_23[0][0]']
add_14 (Add) (None, 1, 1, 1, 36) 0 ['dropout_22[0][0]',
'dropout_23[0][0]']
conv3d_31 (Conv3D) (None, 1, 1, 1, 36) 1296 ['add_14[0][0]']
batch_normalization_31 (BatchN (None, 1, 1, 1, 36) 144 ['conv3d_31[0][0]']
ormalization)
activation_24 (Activation) (None, 1, 1, 1, 36) 0 ['batch_normalization_31[0][0]']
dropout_24 (Dropout) (None, 1, 1, 1, 36) 0 ['activation_24[0][0]']
add_15 (Add) (None, 1, 1, 1, 36) 0 ['add_14[0][0]',
'dropout_24[0][0]']
conv3d_32 (Conv3D) (None, 1, 1, 1, 36) 1296 ['add_15[0][0]']
batch_normalization_32 (BatchN (None, 1, 1, 1, 36) 144 ['conv3d_32[0][0]']
ormalization)
activation_25 (Activation) (None, 1, 1, 1, 36) 0 ['batch_normalization_32[0][0]']
dropout_25 (Dropout) (None, 1, 1, 1, 36) 0 ['activation_25[0][0]']
add_16 (Add) (None, 1, 1, 1, 36) 0 ['add_15[0][0]',
'dropout_25[0][0]']
conv3d_33 (Conv3D) (None, 1, 1, 1, 36) 1296 ['add_16[0][0]']
batch_normalization_33 (BatchN (None, 1, 1, 1, 36) 144 ['conv3d_33[0][0]']
ormalization)
activation_26 (Activation) (None, 1, 1, 1, 36) 0 ['batch_normalization_33[0][0]']
dropout_26 (Dropout) (None, 1, 1, 1, 36) 0 ['activation_26[0][0]']
add_17 (Add) (None, 1, 1, 1, 36) 0 ['add_16[0][0]',
'dropout_26[0][0]']
flatten (Flatten) (None, 36) 0 ['add_17[0][0]']
dense (Dense) (None, 2) 74 ['flatten[0][0]']
==================================================================================================
Total params: 162,152
Trainable params: 160,496
Non-trainable params: 1,656
__________________________________________________________________________________________________
Epoch 1/500
18/18 [==============================] - 204s 11s/step - loss: 1.6567 - accuracy: 0.5083 - val_loss: 0.7143 - val_accuracy: 0.4550
Epoch 2/500
18/18 [==============================] - 190s 10s/step - loss: 1.6622 - accuracy: 0.4833 - val_loss: 0.7552 - val_accuracy: 0.5000
Epoch 3/500
18/18 [==============================] - 189s 10s/step - loss: 1.6357 - accuracy: 0.4783 - val_loss: 0.8154 - val_accuracy: 0.5100
Epoch 4/500
18/18 [==============================] - 187s 10s/step - loss: 1.5270 - accuracy: 0.5050 - val_loss: 0.8799 - val_accuracy: 0.5250
Epoch 5/500
18/18 [==============================] - 185s 10s/step - loss: 1.4210 - accuracy: 0.5000 - val_loss: 0.9373 - val_accuracy: 0.5200
Epoch 6/500
18/18 [==============================] - 183s 10s/step - loss: 1.4004 - accuracy: 0.4967 - val_loss: 0.9849 - val_accuracy: 0.5250
Epoch 7/500
18/18 [==============================] - 181s 10s/step - loss: 1.3900 - accuracy: 0.4883 - val_loss: 1.0248 - val_accuracy: 0.5250
Epoch 8/500
18/18 [==============================] - 177s 9s/step - loss: 1.3331 - accuracy: 0.4850 - val_loss: 1.0522 - val_accuracy: 0.5200
Epoch 9/500
18/18 [==============================] - 176s 9s/step - loss: 1.3708 - accuracy: 0.4733 - val_loss: 1.0724 - val_accuracy: 0.5150
Epoch 10/500
18/18 [==============================] - 176s 9s/step - loss: 1.2359 - accuracy: 0.4883 - val_loss: 1.0799 - val_accuracy: 0.5050
Epoch 11/500
18/18 [==============================] - 175s 9s/step - loss: 1.2221 - accuracy: 0.5000 - val_loss: 1.0771 - val_accuracy: 0.5050
Epoch 12/500
18/18 [==============================] - 174s 9s/step - loss: 1.0934 - accuracy: 0.5233 - val_loss: 1.0807 - val_accuracy: 0.5000
Epoch 13/500
18/18 [==============================] - 175s 9s/step - loss: 1.1810 - accuracy: 0.4900 - val_loss: 1.0726 - val_accuracy: 0.5000
Epoch 14/500
18/18 [==============================] - 175s 9s/step - loss: 1.1158 - accuracy: 0.5117 - val_loss: 1.0626 - val_accuracy: 0.5000
Epoch 15/500
18/18 [==============================] - 177s 9s/step - loss: 1.0876 - accuracy: 0.4967 - val_loss: 1.0460 - val_accuracy: 0.5000
Epoch 16/500
18/18 [==============================] - 176s 9s/step - loss: 0.9918 - accuracy: 0.5300 - val_loss: 1.0278 - val_accuracy: 0.5000
Epoch 17/500
18/18 [==============================] - 183s 10s/step - loss: 0.9795 - accuracy: 0.5250 - val_loss: 1.0102 - val_accuracy: 0.4950
Epoch 18/500
18/18 [==============================] - 186s 10s/step - loss: 0.9753 - accuracy: 0.5283 - val_loss: 0.9909 - val_accuracy: 0.4900
Epoch 19/500
18/18 [==============================] - 188s 10s/step - loss: 0.9106 - accuracy: 0.5283 - val_loss: 0.9738 - val_accuracy: 0.5050
Epoch 20/500
18/18 [==============================] - 193s 10s/step - loss: 0.9821 - accuracy: 0.5217 - val_loss: 0.9567 - val_accuracy: 0.4900
Epoch 21/500
18/18 [==============================] - 182s 10s/step - loss: 0.9361 - accuracy: 0.4883 - val_loss: 0.9362 - val_accuracy: 0.4900
Epoch 22/500
18/18 [==============================] - 183s 10s/step - loss: 0.9371 - accuracy: 0.5017 - val_loss: 0.9185 - val_accuracy: 0.4800
Epoch 23/500
18/18 [==============================] - 179s 10s/step - loss: 0.8806 - accuracy: 0.5133 - val_loss: 0.9018 - val_accuracy: 0.4800
Epoch 24/500
18/18 [==============================] - 177s 9s/step - loss: 0.8786 - accuracy: 0.4967 - val_loss: 0.8869 - val_accuracy: 0.4800
Epoch 25/500
18/18 [==============================] - 175s 9s/step - loss: 0.8668 - accuracy: 0.5033 - val_loss: 0.8689 - val_accuracy: 0.4850
Epoch 26/500
19/18 [==============================] - ETA: -2s - loss: 0.8193 - accuracy: 0.5233Restoring model weights from the end of the best epoch: 1.
18/18 [==============================] - 175s 9s/step - loss: 0.8193 - accuracy: 0.5233 - val_loss: 0.8558 - val_accuracy: 0.4900
Epoch 26: early stopping
7/7 [==============================] - 11s 2s/step
So far it produced no errors; this is the best it got before early stopping kicked in.
I am now training the model with a real but small dataset. I have removed callbacks=[early_stopping] and reduced the epochs to 200 for faster results. Is this going to lower the expected accuracy you reported in the paper? The other hyperparameters are the same as you set them.
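On the early-stopping behavior seen in the log above: the callback restored the weights from epoch 1 because val_loss never improved afterwards, and training stopped 25 epochs later. Without the callback, training simply runs all epochs and keeps the final weights, which may be worse on validation data. The mechanism can be sketched in plain Python (the patience value is inferred from the log and may not match the repository's setting):

```python
# Minimal sketch of what Keras's EarlyStopping with restore_best_weights
# does: stop after `patience` epochs without a val_loss improvement and
# keep the weights from the best epoch. Epochs are 0-indexed here.
def early_stop(val_losses, patience=25):
    best_epoch, best_loss, wait = 0, float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch  # restore best, stop here
    return best_epoch, len(val_losses) - 1
```

Applied to the val_loss values in the run above (lowest at epoch 1, never beaten), this stops at epoch 26 and restores the epoch-1 weights, matching the "Restoring model weights from the end of the best epoch: 1" message.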
Looking forward to hearing from you. Best regards, Shamim.
Hi, I have some problems creating the model with the pretrained weights, as follows: when I copy and paste the "Example of Use" without any change, I first get an error saying that model_hyparams is not defined. I then defined it as model_hyparams = dict(), and the code continued. However, on reaching model.load_weights('EEGSym_pretrained_weights_{}_electrode.h5'.format(ncha)) I got an error. How can I fix this?
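A likely cause worth noting: load_weights can only succeed if the freshly built model has exactly the architecture the weights were saved for, so an empty model_hyparams only works if every default matches the pretrained configuration. A sketch of the idea (the key names below are hypothetical placeholders, not EEGSym's actual signature; the values are read off the model summary, which shows a 384-sample, 16-electrode input):

```python
# Hypothetical sketch: the hyperparameters must reproduce the configuration
# the pretrained weights were saved with, otherwise load_weights() raises a
# layer/shape mismatch error. Key names here are illustrative placeholders.
ncha = 16  # number of electrodes, matching the '{}_electrode' weight file

model_hyparams = dict(
    ncha=ncha,        # channels: 16, as in the summary's (None, 384, 16, 1)
    fs=128,           # sampling rate assumed for the pretrained model
    input_time=3000,  # window length in ms (128 Hz * 3 s = 384 samples)
)

# model = EEGSym(**model_hyparams)  # build with a matching configuration
# model.load_weights(
#     'EEGSym_pretrained_weights_{}_electrode.h5'.format(ncha))
```

If the error persists with matching hyperparameters, comparing the built model's summary against the expected one is a quick way to spot which layer diverges.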
Another question: why, in the model summary, is the model's Input_layer 4-dimensional, with "None" in axis=0? Is that correct? If it is, how can I put my x_train in this form?
https://colab.research.google.com/drive/1b0LI54aT7p4sWQQ87EgDKEa0-8Hgts_9?usp=sharing
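For reference, the leading None in (None, 384, 16, 1) is just the batch dimension, which Keras leaves unspecified; it is filled by however many trials are passed to fit or predict, so x_train only needs to match the remaining axes. A quick check with numpy (the array names and trial count are illustrative):

```python
import numpy as np

n_trials, n_samples, n_channels = 100, 384, 16

# Suppose the epoched EEG comes as (trials, samples, channels)
x_train = np.random.randn(n_trials, n_samples, n_channels)

# Add the trailing singleton "feature" axis the model expects
x_train = x_train[..., np.newaxis]  # -> (100, 384, 16, 1)

# The model's (None, 384, 16, 1) matches everything after the batch axis,
# and the batch axis (None) becomes n_trials at fit/predict time
assert x_train.shape == (n_trials, n_samples, n_channels, 1)
```

If the epoching tool instead returns (trials, channels, samples), transpose with np.transpose(x, (0, 2, 1)) before adding the trailing axis.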
Looking forward to hearing from you. Shamim.