glicerico opened this issue 3 years ago
@glicerico, I've just created a new fork from yours. I made some changes to the code related to running on GPU, added annotations to the classes, and added evaluation of the model on a test dataset. The problem is that when I use your trained checkpoint, my eval method gives quite poor results. May I ask you to try it on your side and see whether the problem is with your trained checkpoint or with my code? My fork is here: https://github.com/PolKul/CASA-Dialogue-Act-Classifier.git
Here is the result of running Eval on your checkpoint "epoch=29-val_accuracy=0.751411.ckpt":
precision recall f1-score support
0 0.59 0.81 0.68 360
1 0.00 0.00 0.00 328
2 0.00 0.00 0.00 19
3 0.00 0.00 0.00 0
4 0.00 0.00 0.00 7
5 0.00 0.00 0.00 17
6 0.00 0.00 0.00 208
7 0.00 0.00 0.00 7
8 0.00 0.00 0.00 27
9 0.00 0.00 0.00 3
10 0.00 0.00 0.00 3
11 0.00 0.00 0.00 765
12 0.00 0.00 0.00 21
13 0.00 0.00 0.00 76
14 0.00 0.00 0.00 1
15 0.00 0.00 0.00 23
16 0.00 0.00 0.00 21
17 0.00 0.00 0.00 28
18 0.00 0.00 0.00 9
19 0.00 0.00 0.00 2
20 0.37 0.09 0.14 81
21 0.00 0.00 0.00 16
22 0.00 0.00 0.00 5
23 0.00 0.00 0.00 7
24 0.00 0.00 0.00 23
25 0.00 0.00 0.00 10
26 0.00 0.00 0.00 6
27 0.00 0.00 0.00 26
28 0.00 0.00 0.00 6
29 0.00 0.00 0.00 73
30 0.00 0.00 0.00 0
31 0.00 0.00 0.00 12
32 0.00 0.00 0.00 16
33 0.00 0.00 0.00 2
34 0.00 0.00 0.00 55
35 0.00 0.00 0.00 1
36 0.00 0.00 0.00 84
37 0.01 0.33 0.01 36
38 0.20 0.11 0.14 1317
39 0.00 0.00 0.00 718
40 0.00 0.00 0.00 1
41 0.00 0.00 0.00 0
42 0.00 0.00 0.00 94
accuracy 0.10 4514
macro avg 0.03 0.03 0.02 4514
weighted avg 0.11 0.10 0.10 4514
It shows an accuracy of only 10%...
Hey @PolKul, as mentioned in one of the issues, that checkpoint was trained before the class-numbering problem was solved, so it is probably using the wrong labels.
Do you have a newer checkpoint, like epoch=5-val_accuracy=0.779101.ckpt, which I uploaded before (and removed because my Dropbox was full)?
@PolKul, I uploaded it here again... Please try your evaluation with this checkpoint and let me know. I'll probably remove the file in a day or two, unless someone wants to host it somewhere else :) https://www.dropbox.com/s/e88ymjfej80zabs/epoch%3D28-val_accuracy%3D0.746056.ckpt?dl=0
Oh, @PolKul, I just noticed that you are using your own class-label numbering... so it's expected that the predictions won't match.
In the following line:
https://github.com/PolKul/CASA-Dialogue-Act-Classifier/blob/32214d64d556505424b1efe54905371e7f417dcb/predict.py#L130
you give an arbitrary number to each tag, based on enumerate and the order in which you defined the class labels in dataset.py. The checkpoint was trained using the sorted list of tags from the training set, as suggested by @Christopher-Thornton:
https://github.com/glicerico/CASA-Dialogue-Act-Classifier/blob/92400edff9e0ab724d545d4495346e5eae4cd77e/dataset/dataset.py#L18
So, you should probably leave the classes as proposed in my pull request, or train a model with the label order that you prefer :)
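To illustrate the mismatch, here is a minimal sketch with a hypothetical handful of act tags (the real list lives in dataset.py), comparing the sorted numbering used at training time with the enumerate-over-declaration-order numbering used in predict.py:

```python
# Hypothetical subset of act tags; the real list is defined in dataset.py.
tags = ["sd", "b", "sv", "aa", "%"]

# Numbering used at training time: index into the sorted tag list.
train_ids = {tag: i for i, tag in enumerate(sorted(tags))}

# Numbering used in predict.py: enumerate over the declaration order.
predict_ids = {tag: i for i, tag in enumerate(tags)}

print(train_ids)    # {'%': 0, 'aa': 1, 'b': 2, 'sd': 3, 'sv': 4}
print(predict_ids)  # {'sd': 0, 'b': 1, 'sv': 2, 'aa': 3, '%': 4}
```

Since e.g. "sd" maps to 3 in one scheme and 0 in the other, class indices produced by a model trained with one mapping cannot be decoded with the other.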
@glicerico, thank you for your review. However, my question was more about the eval() method of the DialogClassifier class. As you can see, it doesn't use my annotated classes (the act_label_names list) in any way, and it still produces really bad results (0.1 F1 score). To avoid confusion, you can add the same eval method to your branch and try running it. Let me know whether you see better statistics from it.
You're right, I see that you only use act_label_names for printing (and it prints incorrectly, as the classes in act_label_names are numbered differently from the predictions).
The other point I made above still remains: that model was trained when there were some problems with class numbering.
Sorry, but I would prefer not to check out your branch, figure it out, and run it until you have explored all the possible reasons yourself :)
Sorry, but I don't see where you see the problem with act_label_names.
to print incorrectly, as the classes in act_label_names are numbered differently from the predictions
act_label_names is a dictionary with the following structure: ["name", "act_tag", "example"]. The code below finds a "name" by "act_tag":
for utterance, prediction in zip(utterances, predicted_acts):
    for index, act_tag in enumerate(act_label_names['act_tag']):
        if act_tag == prediction:
            print(f"{prediction}({utterance})-> {act_label_names['name'][index]}")
Or you mean that "prediction" is incorrectly labeled?
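As an aside, the nested loop above can be replaced by a dictionary built once. A minimal sketch, assuming act_label_names has the ["name", "act_tag", "example"] structure described (the entries below are hypothetical):

```python
# Hypothetical entries matching the described structure of act_label_names.
act_label_names = {
    "name": ["Statement-non-opinion", "Acknowledge (Backchannel)"],
    "act_tag": ["sd", "b"],
    "example": ["Me, I'm in the legal department.", "Uh-huh."],
}

# Build the tag -> name mapping once, instead of re-scanning per prediction.
tag_to_name = dict(zip(act_label_names["act_tag"], act_label_names["name"]))

utterances = ["Uh-huh.", "I'm in the legal department."]
predicted_acts = ["b", "sd"]
for utterance, prediction in zip(utterances, predicted_acts):
    print(f"{prediction}({utterance})-> {tag_to_name[prediction]}")
```

This is only a style point; it does not change which name is printed for a given tag.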
After your last comment, I don't see a problem with act_label_names.
I am talking about the two posts prior to that:
Hey @PolKul , as commented in one of the issues, that checkpoint was trained before the classes problem was solved, so it probably is using the wrong labels. Do you have a newer checkpoint, like epoch=5-val_accuracy=0.779101.ckpt that I uploaded before (and removed bc my dropbox was full)?
@PolKul I uploaded it here again... please try your evaluation with this checkpoint and let me know. I'll probably remove the file in a day or 2, unless someone wants to host it somewhere else :) https://www.dropbox.com/s/e88ymjfej80zabs/epoch%3D28-val_accuracy%3D0.746056.ckpt?dl=0
Or you mean that "prediction" is incorrectly labeled?
I mean that the prediction is labeled differently.
I confirm that both "epoch=28-val_accuracy=0.746056.ckpt" and "epoch=29-val_accuracy=0.751411.ckpt" give the same (bad) results, with an F1 score of 0.1.
It would be interesting to see the results of your eval()...
Hi @PolKul, these are the results I got using the best checkpoint I trained, with unfrozen RoBERTa weights. I invite you to use that checkpoint; I uploaded it here: https://www.dropbox.com/s/1zj4vq59z9h6re3/epoch%3D5-val_accuracy%3D0.779101.ckpt?dl=0 Please let me know when you've got it, so that my Dropbox account isn't completely full.
Eval on Test dataset
-------------------------------------
100%|██████████| 64/64 [24:00<00:00, 22.50s/it]
/home/andres/src/miniconda3/envs/CASA/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, msg_start, len(result))
precision recall f1-score support
0 0.78 0.84 0.81 360
1 0.08 0.05 0.06 19
3 0.33 0.14 0.20 7
4 1.00 0.06 0.11 17
5 0.74 0.46 0.57 208
6 0.00 0.00 0.00 7
7 0.41 0.26 0.32 27
8 0.00 0.00 0.00 3
9 0.00 0.00 0.00 3
10 0.79 0.92 0.85 765
11 0.50 0.05 0.09 21
12 0.68 0.84 0.75 76
13 0.00 0.00 0.00 1
14 1.00 0.04 0.08 23
15 0.67 0.86 0.75 21
16 0.47 0.32 0.38 28
17 0.67 0.44 0.53 9
18 1.00 1.00 1.00 2
19 0.84 0.59 0.70 81
20 0.33 0.19 0.24 16
21 0.75 0.60 0.67 5
22 0.00 0.00 0.00 7
23 0.82 0.61 0.70 23
24 0.33 0.10 0.15 10
25 0.50 0.33 0.40 6
26 0.81 0.81 0.81 26
27 0.33 0.17 0.22 6
28 0.79 0.62 0.69 73
30 0.11 0.08 0.10 12
31 0.65 0.81 0.72 16
32 1.00 1.00 1.00 2
33 0.73 0.75 0.74 55
34 0.00 0.00 0.00 1
35 0.69 0.80 0.74 84
36 0.00 0.00 0.00 36
37 0.81 0.85 0.83 1317
38 0.69 0.72 0.71 718
39 0.00 0.00 0.00 1
accuracy 0.76 4092
macro avg 0.51 0.40 0.42 4092
weighted avg 0.75 0.76 0.74 4092
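Regarding the UndefinedMetricWarning in the log above: sklearn raises it for classes that receive no predicted samples, where precision is undefined. A minimal sketch of silencing it with the zero_division parameter (toy labels, not the actual eval data):

```python
from sklearn.metrics import classification_report

# Toy labels where class 2 is never predicted, which would normally
# trigger UndefinedMetricWarning for its precision.
y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 1, 1, 0, 1, 0]

# zero_division=0 suppresses the warning and reports the undefined
# metrics as 0.0, matching what the warning says happens by default.
print(classification_report(y_true, y_pred, zero_division=0))
```

The scores themselves are unchanged; the parameter only controls the warning and the value substituted for the undefined metrics.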
Hi @glicerico, thanks for the checkpoint and eval. I've just updated the repo from your latest inference branch, and it worked! Not sure what was wrong with my previous code, though... Anyway, thank you for your assistance.
Hi @PolKul , these are the results I got using the best checkpoint I trained, with unfrozen Roberta weights. I invite you to use that checkpoint, I uploaded it here: https://www.dropbox.com/s/1zj4vq59z9h6re3/epoch%3D5-val_accuracy%3D0.779101.ckpt?dl=0 Please let me know when you get it, so I don't have my dropbox account completely full.
Dear @glicerico, may I ask you to re-upload the checkpoint? Somehow I don't get the results. (Also, my inference is very slow when using yours; do you know why?)
@minarainbow, you can find the checkpoint at https://www.dropbox.com/s/egiv70dwl1ikrbq/epoch%3D5-val_accuracy%3D0.779101.ckpt?dl=0, I'll remove it from there in a couple of days.
@macabdul9 sorry for the large PR, but I had to accumulate improvements to achieve proper inference. Summary of changes:
Removed utterances tagged +, as those are continuations of interrupted utterances. Unless these are somehow joined back to their initial utterance, I believe they are useless... See discussion here.
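Filtering out those continuation utterances could look like this minimal sketch (hypothetical rows and column names; the actual change is in the PR's dataset code):

```python
# Hypothetical rows; in the real dataset these come from the corpus files.
rows = [
    {"act_tag": "sd", "utterance": "I was saying"},
    {"act_tag": "+",  "utterance": "that it rained"},
    {"act_tag": "b",  "utterance": "Uh-huh."},
    {"act_tag": "+",  "utterance": "all day."},
]

# Drop continuation utterances: the "+" tag carries no standalone act.
kept = [row for row in rows if row["act_tag"] != "+"]
print(len(kept))  # 2
```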