guillaume-be / rust-bert

Rust native ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)
https://docs.rs/crate/rust-bert
Apache License 2.0
2.67k stars, 216 forks

Is multilabel prediction correct? #429

Closed: Mnwa closed this issue 1 year ago

Mnwa commented 1 year ago

Hi! When I call predict_multilabel with a threshold higher than all of the returned scores, I expected the method to return an empty vector of labels for each input. For example, I expected

model.predict_multilabel(&["foo", "bar"], threshold)?

to return

vec![vec![], vec![]]

But the actual result is

vec![]

This makes the output very hard to map back to the inputs: if the second input returns labels and the first does not, how can I tell which input the returned labels belong to?

The model's outputs are filtered here: https://github.com/guillaume-be/rust-bert/blob/main/src/pipelines/sequence_classification.rs#L840
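The behaviour the issue asks for can be sketched in plain Rust: threshold each input's scores independently, so the outer vector always has exactly one entry per input, empty when no class clears the threshold. This is a minimal sketch, not rust-bert's implementation. Label is a local stand-in whose fields (text, score, id, sentence) are assumed to resemble the crate's label type, and filter_by_threshold, along with its score layout (scores[i][j] is the score of class j for input i), is hypothetical.

```rust
/// Local stand-in for a classification label (hypothetical: not imported
/// from rust-bert, fields chosen to resemble its label type).
#[derive(Debug, Clone, PartialEq)]
struct Label {
    text: String,
    score: f64,
    id: i64,
    sentence: usize, // index of the input text this label belongs to
}

/// Hypothetical helper: `scores[i][j]` is the score of class `j` for input `i`.
/// Returns exactly one (possibly empty) vector of labels per input, so the
/// output always lines up position-by-position with the inputs.
fn filter_by_threshold(
    scores: &[Vec<f64>],
    class_names: &[&str],
    threshold: f64,
) -> Vec<Vec<Label>> {
    scores
        .iter()
        .enumerate()
        .map(|(sentence, row)| {
            row.iter()
                .enumerate()
                // Keep only the classes whose score clears the threshold.
                .filter(|&(_, &score)| score >= threshold)
                .map(|(id, &score)| Label {
                    text: class_names[id].to_string(),
                    score,
                    id: id as i64,
                    sentence,
                })
                // An input with no passing class still yields an empty Vec.
                .collect::<Vec<Label>>()
        })
        .collect()
}
```

For example, with scores [[0.2, 0.1], [0.9, 0.3]] and a threshold of 0.95 this returns two empty vectors; with a threshold of 0.5 only the second slot holds a label, so positions still line up with the inputs.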

guillaume-be commented 1 year ago

Hello @Mnwa,

In your example you pass 2 input texts for classification, so the output will be a vector containing 2 vectors (one per input sequence). Each sub-vector contains the labels for its input text, so I believe the shape of the output is correct:

Vec          // 1 element per input sequence
|_____ Vec   // 1 element per class above threshold for text 1
|_____ Vec   // 1 element per class above threshold for text 2

If the vector of labels for a given input text is not empty, each label in it contains the full label information, which allows you to perform the mapping.

I hope this helps.
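The mapping described in the reply can be sketched as follows, assuming each label records the index of the input text it came from. To our understanding rust-bert's label type carries such an index in a sentence field, but check the crate documentation for your version. Label is again a local stand-in, and regroup_by_sentence is a hypothetical helper: it flattens whatever outer grouping was returned and rebuilds exactly one slot per input, so inputs with no labels above the threshold keep an empty slot.

```rust
/// Local stand-in for a classification label (hypothetical: not imported
/// from rust-bert, fields chosen to resemble its label type).
#[derive(Debug, Clone, PartialEq)]
struct Label {
    text: String,
    score: f64,
    id: i64,
    sentence: usize, // index of the input text this label belongs to
}

/// Hypothetical helper: ignores the outer grouping of `output` and rebuilds
/// exactly `num_inputs` groups, placing each label by its `sentence` index.
fn regroup_by_sentence(output: Vec<Vec<Label>>, num_inputs: usize) -> Vec<Vec<Label>> {
    // Start with one empty slot per input text.
    let mut grouped: Vec<Vec<Label>> = vec![Vec::new(); num_inputs];
    for label in output.into_iter().flatten() {
        // Skip labels whose index is out of range instead of panicking.
        if let Some(slot) = grouped.get_mut(label.sentence) {
            slot.push(label);
        }
    }
    grouped
}
```

Because the regrouping relies only on each label's own index, it gives the same result whether or not the returned outer vector skips inputs that had no labels.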