guillaume-be / rust-bert

Rust-native, ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2, ...)
https://docs.rs/crate/rust-bert
Apache License 2.0

How to run custom RobertaForSequenceClassification model #289

antonioualex opened this issue 1 year ago

antonioualex commented 1 year ago

Hi guys, I'm trying to load a custom model with RobertaForSequenceClassification, but I don't know how to run a prediction. I assume I have to use the forward_t method, but I'm not sure how, or whether that is even the right approach. All I want to do is pass in a text and receive a prediction. Below you can see my code.

    let config_resource = LocalResource {
        local_path: PathBuf::from("config.json"),
    };
    let vocab_resource = LocalResource {
        local_path: PathBuf::from("vocab.json"),
    };
    let merges_resource = LocalResource {
        local_path: PathBuf::from("merges.txt"),
    };
    let weights_resource = LocalResource {
        local_path: PathBuf::from("rust_model.ot"),
    };

    let config_path = config_resource.get_local_path().unwrap();
    let vocab_path = vocab_resource.get_local_path().unwrap();
    let merges_path = merges_resource.get_local_path().unwrap();
    let weights_path = weights_resource.get_local_path().unwrap();

    let vocab = RobertaVocab::from_file(vocab_path.to_str().unwrap()).unwrap();
    let merges = BpePairVocab::from_file(merges_path.to_str().unwrap()).unwrap();

    let tokenizer: RobertaTokenizer = RobertaTokenizer::from_existing_vocab_and_merges(
        vocab,
        merges,
        true, // lower_case
        true, // add_prefix_space
    );

    let Input = ["my_text"];

    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let config = RobertaConfig::from_file(config_path);
    let roberta = RobertaForSequenceClassification::new(&vs.root(), &config);

    vs.load(weights_path).unwrap();

    //TODO: predict Input variable
guillaume-be commented 1 year ago

Hello @antonioualex,

For sequence classification, the easiest approach is to use the pipelines, which take care of tokenization, batching and padding for you. Please check the example at https://github.com/guillaume-be/rust-bert/blob/master/examples/sequence_classification.rs, which illustrates how to do this with a set of defaults. You can then update the configuration to point to a custom model instead (see for example https://github.com/guillaume-be/rust-bert/blob/master/examples/sentiment_analysis_fnet.rs).
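For a custom RoBERTa classifier, the configuration could look roughly like the minimal sketch below. It assumes a rust-bert version (around 0.20) where pipeline resources are boxed ResourceProvider values and merges_resource is optional; other releases use a Resource enum or a ModelResource wrapper for the weights, so adjust the fields to match the linked examples for your version.

    use std::path::PathBuf;

    use rust_bert::pipelines::common::ModelType;
    use rust_bert::pipelines::sequence_classification::{
        SequenceClassificationConfig, SequenceClassificationModel,
    };
    use rust_bert::resources::LocalResource;

    fn main() {
        // Point the pipeline at the local RoBERTa files instead of the default model.
        let classification_config = SequenceClassificationConfig {
            model_type: ModelType::Roberta,
            model_resource: Box::new(LocalResource {
                local_path: PathBuf::from("rust_model.ot"),
            }),
            config_resource: Box::new(LocalResource {
                local_path: PathBuf::from("config.json"),
            }),
            vocab_resource: Box::new(LocalResource {
                local_path: PathBuf::from("vocab.json"),
            }),
            merges_resource: Some(Box::new(LocalResource {
                local_path: PathBuf::from("merges.txt"),
            })),
            ..Default::default()
        };

        // The pipeline handles tokenization, padding and batching internally.
        let model = SequenceClassificationModel::new(classification_config)
            .expect("failed to build the sequence classification model");

        // predict returns one Label (class name, id and score) per input text.
        let output = model.predict(&["my_text"]);
        println!("{:?}", output);
    }

With this, predict applies the tokenizer, runs the forward pass and returns the highest-scoring label for each input, so no manual call to forward_t is needed.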

Please let me know if this helps.
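For completeness, if one did want to call forward_t directly on the model built in the original snippet, the missing prediction step could look roughly like the sketch below. It continues from the tokenizer, roberta and device variables defined in the question, and assumes a rust-bert/tch combination where forward_t takes Option<&Tensor> arguments and returns an output struct with a logits field, and where Tensor::of_slice is available; older or newer releases may differ slightly.

    use rust_tokenizers::tokenizer::{Tokenizer, TruncationStrategy};
    use tch::{Kind, Tensor};

    // 1. Tokenize the input text and build a batched input tensor.
    let tokenized = tokenizer.encode("my_text", None, 128, &TruncationStrategy::LongestFirst, 0);
    let input_tensor = Tensor::of_slice(&tokenized.token_ids)
        .unsqueeze(0) // add a batch dimension
        .to(device);

    // 2. Run the forward pass without tracking gradients.
    let output = tch::no_grad(|| {
        roberta.forward_t(Some(&input_tensor), None, None, None, None, false)
    });

    // 3. Turn the logits into class probabilities and pick the most likely class.
    let probabilities = output.logits.softmax(-1, Kind::Float);
    let predicted_class = probabilities.argmax(-1, false);
    predicted_class.print();

The sequence classification pipeline performs these same steps internally (plus padding and mapping class ids to the label names from the config), which is why it is the recommended route.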