macabdul9 / CASA-Dialogue-Act-Classifier

PyTorch implementation of the paper "Dialogue Act Classification with Context-Aware Self-Attention", with a generic dataset class and a PyTorch Lightning trainer
MIT License

Inference #12

Open · glicerico opened this issue 3 years ago

glicerico commented 3 years ago

@macabdul9 sorry for the large PR, but I had to accumulate improvements to achieve proper inference. Summary of changes:

PolKul commented 3 years ago

@glicerico, I've just created a new fork from yours. I made some changes to the code related to running on GPU, added annotations to the classes, and added evaluation of the model on a test dataset. The problem is that, when using your trained checkpoint, my eval method gives quite poor results. May I ask you to try it on your side and see whether the problem is with your trained checkpoint or with my code? My fork is here: https://github.com/PolKul/CASA-Dialogue-Act-Classifier.git
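
For reference, the eval method is essentially a standard predict-and-report loop. A simplified sketch of what it does (the batch keys and model signature here are assumptions, not the exact code from my fork):

import torch
from sklearn.metrics import classification_report

def evaluate(model, test_loader, device="cuda"):
    # Collect predictions over the test set and print a classification report.
    model.eval()
    model.to(device)
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch in test_loader:
            # Assumed batch layout: token ids, attention mask, gold labels.
            logits = model(batch["input_ids"].to(device),
                           batch["attention_mask"].to(device))
            y_pred.extend(logits.argmax(dim=-1).cpu().tolist())
            y_true.extend(batch["label"].tolist())
    print(classification_report(y_true, y_pred))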

Here is the result of running eval on your checkpoint "epoch=29-val_accuracy=0.751411.ckpt":

              precision    recall  f1-score   support

           0       0.59      0.81      0.68       360
           1       0.00      0.00      0.00       328
           2       0.00      0.00      0.00        19
           3       0.00      0.00      0.00         0
           4       0.00      0.00      0.00         7
           5       0.00      0.00      0.00        17
           6       0.00      0.00      0.00       208
           7       0.00      0.00      0.00         7
           8       0.00      0.00      0.00        27
           9       0.00      0.00      0.00         3
          10       0.00      0.00      0.00         3
          11       0.00      0.00      0.00       765
          12       0.00      0.00      0.00        21
          13       0.00      0.00      0.00        76
          14       0.00      0.00      0.00         1
          15       0.00      0.00      0.00        23
          16       0.00      0.00      0.00        21
          17       0.00      0.00      0.00        28
          18       0.00      0.00      0.00         9
          19       0.00      0.00      0.00         2
          20       0.37      0.09      0.14        81
          21       0.00      0.00      0.00        16
          22       0.00      0.00      0.00         5
          23       0.00      0.00      0.00         7
          24       0.00      0.00      0.00        23
          25       0.00      0.00      0.00        10
          26       0.00      0.00      0.00         6
          27       0.00      0.00      0.00        26
          28       0.00      0.00      0.00         6
          29       0.00      0.00      0.00        73
          30       0.00      0.00      0.00         0
          31       0.00      0.00      0.00        12
          32       0.00      0.00      0.00        16
          33       0.00      0.00      0.00         2
          34       0.00      0.00      0.00        55
          35       0.00      0.00      0.00         1
          36       0.00      0.00      0.00        84
          37       0.01      0.33      0.01        36
          38       0.20      0.11      0.14      1317
          39       0.00      0.00      0.00       718
          40       0.00      0.00      0.00         1
          41       0.00      0.00      0.00         0
          42       0.00      0.00      0.00        94

    accuracy                           0.10      4514
   macro avg       0.03      0.03      0.02      4514
weighted avg       0.11      0.10      0.10      4514

It shows an accuracy of only 10%...

glicerico commented 3 years ago

Hey @PolKul , as mentioned in one of the issues, that checkpoint was trained before the classes problem was solved, so it is probably using the wrong labels. Do you have a newer checkpoint, like the epoch=5-val_accuracy=0.779101.ckpt I uploaded before (and removed because my Dropbox was full)?

glicerico commented 3 years ago

@PolKul I uploaded it here again... please try your evaluation with this checkpoint and let me know. I'll probably remove the file in a day or two, unless someone wants to host it somewhere else :) https://www.dropbox.com/s/e88ymjfej80zabs/epoch%3D28-val_accuracy%3D0.746056.ckpt?dl=0

glicerico commented 3 years ago

Oh, @PolKul , I just noticed that you are using your own class-label numbering, so it's expected that the predictions won't match. In the following line: https://github.com/PolKul/CASA-Dialogue-Act-Classifier/blob/32214d64d556505424b1efe54905371e7f417dcb/predict.py#L130 you assign an arbitrary number to each tag, based on enumerate and the order in which you defined the class labels in dataset.py. The checkpoint, however, was trained using the sorted list of tags from the training set, as suggested by @Christopher-Thornton: https://github.com/glicerico/CASA-Dialogue-Act-Classifier/blob/92400edff9e0ab724d545d4495346e5eae4cd77e/dataset/dataset.py#L18

So you should probably leave the classes as proposed in my pull request, or train a model with the label order you prefer :)
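
To make the mismatch concrete, here is a toy illustration (hypothetical tags, not the actual list from the training set):

# Toy illustration: the same tag gets a different index depending on whether
# you enumerate the tags in definition order or in sorted order (the order
# the checkpoint was trained with).
definition_order = ["sd", "qy", "b"]
sorted_order = sorted(definition_order)

enumerate_map = {tag: i for i, tag in enumerate(definition_order)}
sorted_map = {tag: i for i, tag in enumerate(sorted_order)}

print(enumerate_map)  # {'sd': 0, 'qy': 1, 'b': 2}
print(sorted_map)     # {'b': 0, 'qy': 1, 'sd': 2}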

PolKul commented 3 years ago

@glicerico, thank you for your review. However, my question was more about the eval() method of the DialogClassifier class. As you can see, it doesn't use my annotated classes (the act_label_names list) in any way and still produces really bad results (0.1 F1 score). To avoid confusion, you can add the same eval method to your branch and try running it. Let me know if you see any better statistics from it.

glicerico commented 3 years ago

You're right, I see that you only use act_label_names for printing (and that printing is incorrect, since the classes in act_label_names are numbered differently from the predictions). The other point I made above still stands: that model was trained while there were still problems with the class numbering. Sorry, but I would prefer not to check out your branch, figure it out, and run it until you have ruled out the possible causes we've already identified :)

PolKul commented 3 years ago

Sorry, but I don't see where the problem with act_label_names is supposed to be.

to print incorrectly, as the classes in act_label_names are numbered differently from the predictions

That is a dictionary with the structure ["name", "act_tag", "example"]. The code below finds a "name" by its "act_tag":

# Look up the human-readable name for each predicted act tag.
for utterance, prediction in zip(utterances, predicted_acts):
    for index, act_tag in enumerate(act_label_names['act_tag']):
        if act_tag == prediction:
            print(f"{prediction}({utterance})-> {act_label_names['name'][index]}")

Or do you mean that "prediction" is incorrectly labeled?

glicerico commented 3 years ago

After your last comment, I don't see a problem with act_label_names either. I am talking about the two posts prior to that:

Hey @PolKul , as mentioned in one of the issues, that checkpoint was trained before the classes problem was solved, so it is probably using the wrong labels. Do you have a newer checkpoint, like the epoch=5-val_accuracy=0.779101.ckpt I uploaded before (and removed because my Dropbox was full)?

@PolKul I uploaded it here again... please try your evaluation with this checkpoint and let me know. I'll probably remove the file in a day or two, unless someone wants to host it somewhere else :) https://www.dropbox.com/s/e88ymjfej80zabs/epoch%3D28-val_accuracy%3D0.746056.ckpt?dl=0

glicerico commented 3 years ago

Or do you mean that "prediction" is incorrectly labeled?

I mean that "prediction" is labeled differently.

PolKul commented 3 years ago

I can confirm that both "epoch=28-val_accuracy=0.746056.ckpt" and "epoch=29-val_accuracy=0.751411.ckpt" give the same (bad) results, with an F1 score of 0.1.

It would be interesting to see the results of your eval()...

glicerico commented 3 years ago

Hi @PolKul , these are the results I got using the best checkpoint I trained, with unfrozen RoBERTa weights. I invite you to use that checkpoint; I uploaded it here: https://www.dropbox.com/s/1zj4vq59z9h6re3/epoch%3D5-val_accuracy%3D0.779101.ckpt?dl=0 Please let me know once you've downloaded it, so I can free up space in my Dropbox account.

Eval on Test dataset
-------------------------------------
100%|██████████| 64/64 [24:00<00:00, 22.50s/it]
/home/andres/src/miniconda3/envs/CASA/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
              precision    recall  f1-score   support

           0       0.78      0.84      0.81       360
           1       0.08      0.05      0.06        19
           3       0.33      0.14      0.20         7
           4       1.00      0.06      0.11        17
           5       0.74      0.46      0.57       208
           6       0.00      0.00      0.00         7
           7       0.41      0.26      0.32        27
           8       0.00      0.00      0.00         3
           9       0.00      0.00      0.00         3
          10       0.79      0.92      0.85       765
          11       0.50      0.05      0.09        21
          12       0.68      0.84      0.75        76
          13       0.00      0.00      0.00         1
          14       1.00      0.04      0.08        23
          15       0.67      0.86      0.75        21
          16       0.47      0.32      0.38        28
          17       0.67      0.44      0.53         9
          18       1.00      1.00      1.00         2
          19       0.84      0.59      0.70        81
          20       0.33      0.19      0.24        16
          21       0.75      0.60      0.67         5
          22       0.00      0.00      0.00         7
          23       0.82      0.61      0.70        23
          24       0.33      0.10      0.15        10
          25       0.50      0.33      0.40         6
          26       0.81      0.81      0.81        26
          27       0.33      0.17      0.22         6
          28       0.79      0.62      0.69        73
          30       0.11      0.08      0.10        12
          31       0.65      0.81      0.72        16
          32       1.00      1.00      1.00         2
          33       0.73      0.75      0.74        55
          34       0.00      0.00      0.00         1
          35       0.69      0.80      0.74        84
          36       0.00      0.00      0.00        36
          37       0.81      0.85      0.83      1317
          38       0.69      0.72      0.71       718
          39       0.00      0.00      0.00         1

    accuracy                           0.76      4092
   macro avg       0.51      0.40      0.42      4092
weighted avg       0.75      0.76      0.74      4092
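
As the warning at the top says, the zero-precision rows come from labels that were never predicted; passing zero_division to sklearn's classification_report silences it. A minimal sketch:

# Silences the UndefinedMetricWarning for labels with no predicted samples.
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred, zero_division=0))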

PolKul commented 3 years ago

Hi @glicerico, thanks for the checkpoint and the eval. I've just updated my repo from your latest inference branch and it worked! Not sure what was wrong with my previous code though... Anyway, thank you for your assistance.

minarainbow commented 3 years ago

Hi @PolKul , these are the results I got using the best checkpoint I trained, with unfrozen RoBERTa weights. [...] (quoting glicerico's comment and eval table above)

Dear @glicerico, may I ask you to re-upload the checkpoint? Somehow I don't get the same results (and inference is very slow when I use your checkpoint; do you know why?)

glicerico commented 3 years ago

@minarainbow , you can find the checkpoint at https://www.dropbox.com/s/egiv70dwl1ikrbq/epoch%3D5-val_accuracy%3D0.779101.ckpt?dl=0; I'll remove it from there in a couple of days.
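
For anyone grabbing it, loading the checkpoint for inference should be a one-liner in Lightning. A sketch (DialogActClassifier stands in for whatever LightningModule subclass you use; adjust to the actual class in the repo):

# Restore the trained weights from the Lightning checkpoint.
model = DialogActClassifier.load_from_checkpoint("epoch=5-val_accuracy=0.779101.ckpt")
model.eval()  # disable dropout for inference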