NTMC-Community / MatchZoo-py

Facilitating the design, comparison and sharing of deep text matching models.
Apache License 2.0

How to get the actual rank using trainer.predict()? #136

Open littlewine opened 4 years ago

littlewine commented 4 years ago

Describe the Question

I am trying to get the rank out of a trained model (using trainer). However, when I call trainer.predict() I get back a NumPy array of shape num_qids x 1. The number of query ids that .predict returns depends on the dataloader dl passed to trainer.predict(dl).

In other words, as I understand it, I get one score (probably the first metric I defined in metrics?) for each query id. However, what I need is a ranked list of documents for each query id rather than a single score.

How can I get that? I could find no solution through the tutorials.

My code looks like:


    trainer.run()

    # Evaluation
    print('Validation results:')
    print(trainer.evaluate(valid_dl))
    print('Test results:')
    print(trainer.evaluate(test_dl))

    val_preds = trainer.predict(valid_dl)
    test_preds = trainer.predict(test_dl)

    val_preds.shape
    >> Out[18]: (150, 1)
    valid_dl.label.shape
    >> Out[19]: (150,)

arita37 commented 4 years ago

Having more details on evaluate would be useful.

shimengfeng commented 4 years ago

I think that when you use .predict, you get a list of scores. You can sort those scores with np.argsort to get their order, and then use that order to look up the corresponding documents. This is my understanding, and hopefully it is helpful.
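The argsort approach described above can be sketched as follows. This is a minimal illustration, not MatchZoo's own API: `rank_documents`, `query_ids` and `doc_ids` are hypothetical names, and the sketch assumes you can obtain the query and document ids in exactly the same row order as the scores returned by .predict (for example from the data pack behind the dataloader, if it was not resampled or shuffled).

```python
import numpy as np


def rank_documents(query_ids, doc_ids, scores):
    """Group scores by query id and return doc ids sorted by descending score.

    Assumes the three inputs are aligned row for row (hypothetical helper).
    """
    query_ids = np.asarray(query_ids)
    doc_ids = np.asarray(doc_ids)
    # .predict returns shape (n, 1); flatten it to (n,)
    scores = np.asarray(scores, dtype=float).reshape(-1)
    if not (len(query_ids) == len(doc_ids) == len(scores)):
        raise ValueError("query_ids, doc_ids and scores must have equal length")

    rankings = {}
    for qid in np.unique(query_ids).tolist():
        idx = np.flatnonzero(query_ids == qid)      # rows belonging to this query
        order = np.argsort(-scores[idx], kind="stable")  # highest score first
        rankings[qid] = doc_ids[idx][order].tolist()
    return rankings


# Toy data standing in for the ids and the output of trainer.predict(dl)
query_ids = ["q1", "q1", "q2", "q2", "q2"]
doc_ids = ["d1", "d2", "d3", "d4", "d5"]
scores = np.array([[0.2], [0.9], [0.5], [0.1], [0.7]])

rankings = rank_documents(query_ids, doc_ids, scores)
print(rankings)
```

Negating the scores before np.argsort gives a descending order, and the stable sort keeps tied documents in their original relative order.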

littlewine commented 4 years ago

Hi, thanks for your message. You are correct. The problem was that I was using a MatchZoo dataset creation function that sampled positives and negatives and shuffled them, so the rows I fed in no longer lined up with the predictions that came out.
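To illustrate the kind of mismatch described here, the toy sketch below shuffles the rows before "prediction" and then restores the original order. It is not MatchZoo code: `realign_scores` and `fed_row_index` are hypothetical names, and it only covers a pure shuffle. When positives and negatives are resampled, rows can be repeated or dropped, so the simpler fix is probably to build the prediction dataloader without sampling or shuffling.

```python
import numpy as np


def realign_scores(scores_in_fed_order, fed_row_index):
    """Put scores back in original row order (hypothetical helper).

    fed_row_index[i] is the original row number of the i-th row fed to the model.
    """
    scores = np.asarray(scores_in_fed_order, dtype=float).reshape(-1)
    fed_row_index = np.asarray(fed_row_index)
    aligned = np.empty_like(scores)
    aligned[fed_row_index] = scores  # scatter each score back to its original row
    return aligned


# Scores as they would look in the original row order
original_scores = np.array([0.2, 0.9, 0.5, 0.1])

# A shuffle applied before prediction: row 2 is fed first, then row 0, and so on
fed_row_index = np.array([2, 0, 3, 1])
scores_in_fed_order = original_scores[fed_row_index].reshape(-1, 1)

# Comparing scores_in_fed_order directly against the original rows would be wrong;
# realigning with the recorded index restores the correspondence.
aligned = realign_scores(scores_in_fed_order, fed_row_index)
print(aligned)
```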

On Tue, Jun 16, 2020, 17:32 shimengfeng notifications@github.com wrote:

I think when you use .predict, you will get a list of scores. You can sort on those scores and then get the order of it using np.argsort. Then you can get the corresponding document using the order you obtained. This is my understanding and hopefully it is helpful.
