@GilbertHeywood thanks for recommending this, I have updated the readme. The command provided for reuters doesn't work on nuswide_vector because this dataset uses vectorized input, so the graph encoder fails. You will have to use 'mlp' for the encoder argument:
python main.py -dataset nuswide_vector -batch_size 32 -d_model 512 -d_inner_hid 512 -n_layers_enc 2 -n_layers_dec 2 -n_head 4 -epoch 50 -dropout 0.2 -dec_dropout 0.2 -lr 0.0002 -encoder 'mlp' -decoder 'graph' -label_mask 'prior'
There was, however, a bug in the mlp encoder for this dataset, so please pull my latest commit.
Thanks for your quick reply! After pulling the latest commit and updating the command, I got the following results, which are quite poor. Do you have any idea why?
Yes, this is likely because the test script evaluates the shown metrics using the default threshold of 0.5. However, this is not the ideal threshold for most datasets/metrics. As explained in the paper, and as is done in many MLC papers, we find the best threshold for each metric on the validation set. Note that this is only needed for the metrics that require a threshold: ACC, HA, ebF1, miF1, maF1.
We selected from the following thresholds: [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95]
The reason I don't find the best threshold at every epoch is that looping through all the thresholds and computing the metrics is very time-consuming.
I have left the threshold-finding code out for now (I plan to add it later, but it's part of a bigger repository, so it doesn't quite fit in right now). However, please pull the latest commit, which selects the best epoch based on the loss. You can then find the best threshold for each metric from that epoch.
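To make the sweep concrete, here is a minimal sketch of how it could be done; val_probs and val_targets are assumed to be numpy arrays of predicted probabilities and binary ground-truth labels, and the sklearn-based metric definitions are stand-ins rather than the exact code used for the paper:

import numpy as np
from sklearn.metrics import f1_score, accuracy_score

# Candidate thresholds from the list above.
THRESHOLDS = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10,
              0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60,
              0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95]

def best_thresholds(val_probs, val_targets):
    # Metric names follow the paper; the sklearn definitions here are approximations.
    metrics = {
        'ACC':  lambda y, p: accuracy_score(y, p),                                # subset accuracy
        'HA':   lambda y, p: (y == p).mean(),                                     # Hamming accuracy
        'ebF1': lambda y, p: f1_score(y, p, average='samples', zero_division=0),  # example-based F1
        'miF1': lambda y, p: f1_score(y, p, average='micro', zero_division=0),
        'maF1': lambda y, p: f1_score(y, p, average='macro', zero_division=0),
    }
    best = {}
    for name, fn in metrics.items():
        scores = [fn(val_targets, (val_probs >= t).astype(int)) for t in THRESHOLDS]
        best[name] = THRESHOLDS[int(np.argmax(scores))]
    return best

The threshold is picked on the validation set and then reused on the test set, so nothing is tuned on test data.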
Thanks for your explanation. I think I found a bug in utils/evals.py, inside the function compute_metrics: if you set 'graph' as the decoder, the output is almost all 1s. After changing the check to
if args.decoder in ['mlp','pmlp','rnn_b','star','dual_linear','graph']:
It looks ok for now.
That's true, thank you for pointing that out. I never actually used this evals code, since I did the threshold tuning with separate code, so thanks for catching it. I have since updated it.
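For anyone hitting the same issue before pulling the update, the gist of the fix is that 'graph' now goes through the same binarization path as the other decoders. A standalone sketch of that idea (not the actual compute_metrics code; applying a sigmoid before thresholding is an assumption about what that branch guards):

import torch

def binarize_predictions(all_predictions, decoder, threshold=0.5):
    # Decoders whose raw outputs need to be squashed to [0, 1] before thresholding.
    if decoder in ['mlp', 'pmlp', 'rnn_b', 'star', 'dual_linear', 'graph']:
        all_predictions = torch.sigmoid(all_predictions)
    # Binary predictions used by the threshold-dependent metrics (ACC, HA, ebF1, miF1, maF1).
    return (all_predictions >= threshold).int()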
Thank you too! You are the most responsive repo owner I have come across on GitHub so far.
Haha, well thank you. Feel free to let me know if you have any questions. Happy to help.
Hi, would you mind also providing the commands for datasets other than reuters? I encountered some errors when running other datasets, such as nuswide_vector:
The command line I used is just:
Thanks!