aehrc / LAAT

A Label Attention Model for ICD Coding from Clinical Text

Reproducing Joint-LAAT result #5

Closed · Abhinav43 closed 2 years ago

Abhinav43 commented 2 years ago

Hi, thank you for the paper. I am trying to reproduce the Joint-LAAT results using the following parameters:

python3 -m src.run \
    --problem_name mimic-iii_2_full \
    --max_seq_length 4000 \
    --n_epoch 50 \
    --patience 5 \
    --batch_size 8 \
    --optimiser adamw \
    --lr 0.001 \
    --dropout 0.3 \
    --level_projection_size 128 \
    --main_metric micro_f1 \
    --embedding_mode word2vec \
    --embedding_file data/embeddings/word2vec_sg0_100.model \
    --attention_mode label \
    --d_a 512 \
    RNN  \
    --rnn_model LSTM \
    --n_layers 1 \
    --bidirectional 1 \
    --hidden_size 512 

With these parameters, I get the following output:

00:32:03 INFO ======== Results at level_0 ========
00:32:03 INFO Results on Test set at epoch #8 with Loss 0.02354:
[MICRO] accuracy: 0.53093       auc: 0.98799    precision: 0.7335       recall: 0.65783 f1: 0.69361     P@1: 0  P@5: 0    P@8: 0  P@10: 0 P@15: 0
[MACRO] accuracy: 0.19297       auc: 0.93447    precision: 0.28883      recall: 0.25282 f1: 0.26963     P@1: 0.95878      P@5: 0.89359    P@8: 0.82681    P@10: 0.77921   P@15: 0.66222

00:32:03 INFO ======== Results at level_1 ========
00:32:03 INFO Results on Test set at epoch #8 with Loss 0.00541:
[MICRO] accuracy: 0.40399       auc: 0.98824    precision: 0.63812      recall: 0.52405 f1: 0.57549     P@1: 0  P@5: 0    P@8: 0  P@10: 0 P@15: 0
[MACRO] accuracy: 0.06239       auc: 0.91987    precision: 0.09447      recall: 0.08848 f1: 0.09138     P@1: 0.91163      P@5: 0.81109    P@8: 0.73873    P@10: 0.69256   P@15: 0.59085

These numbers look like the LAAT results reported in the paper. How can I reproduce the Joint-LAAT results, which have a macro AUC of about 92 and a macro F1 of 10.7?

tienthanhdhcn commented 2 years ago

Thanks @Abhinav43. The numbers reported in the paper are averages over several random runs (detailed in Section 4.3). Also, if you round the AUC to 3 digits you should get 0.920.
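
For illustration, a minimal sketch of the averaging and rounding described above. Only the first value in each list comes from the log earlier in this issue; the additional per-run values are hypothetical placeholders, not results from the paper.

    # Minimal sketch: averaging level_1 macro metrics over several runs and
    # rounding to 3 digits. Only the first value in each list is taken from
    # the log above; the others are hypothetical placeholders for extra runs.
    from statistics import mean

    macro_auc_runs = [0.91987, 0.921, 0.919]  # placeholder values after the first
    macro_f1_runs  = [0.09138, 0.105, 0.110]  # placeholder values after the first

    print(f"mean macro AUC: {mean(macro_auc_runs):.3f}")
    print(f"mean macro F1:  {mean(macro_f1_runs):.3f}")

    # Rounding a single run's macro AUC to 3 digits, as described above:
    print(f"{0.91987:.3f}")  # -> 0.920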