guillaume-be / rust-bert

Rust native ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)
https://docs.rs/crate/rust-bert
Apache License 2.0

Please expose tokenizer params on models where `forward_t` is exposed #431

Open HarryCaveMan opened 1 year ago

HarryCaveMan commented 1 year ago

If I want to use the SequenceClassification pipeline for something like reranking, I am (sort of) able to do so via the exposed `forward_t` method. The problem is that I first need to encode the inputs using the model's tokenizer. I can get a reference to the tokenizer via `get_tokenizer`, but if I want to pass tokenizer params (i.e. `max_len` and `device`) to `tokenizer.tokenize`, I cannot retrieve them from the `SequenceClassificationModel`: they are private fields, and there are no getter methods for them like there are for the tokenizer itself.

Alternatively, you could add a method that wraps calls to `SequenceClassificationModel.tokenizer.tokenize` and passes these parameters in from the model instance.