guillaume-be / rust-bert

Rust native ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)
https://docs.rs/crate/rust-bert
Apache License 2.0

Downloading a model to a local Directory #432

Closed: kj3moraes closed this issue 10 months ago

kj3moraes commented 11 months ago

Hello there, I recently started using rust-bert and I wanted to know what the right way is to download a model to a local directory.

Currently, I am doing something like this:

    let mut model_cache_path = PathBuf::from(MODELS_CACHE_DIR);
    model_cache_path.push("albert");

    let model = if model_cache_path.exists() {
        println!("Creating from local cached model");
        SentenceEmbeddingsBuilder::local(model_cache_path)
            .create_model()
            .expect("Failed to create the AlBERT model")
    } else {
        fs::create_dir_all(&model_cache_path).unwrap();

        // Define the url where the model will be obtained and the directory
        // that it will be cached to
        let remote_resource = RemoteResource::new(
            SentenceEmbeddingsConfigResources::PARAPHRASE_ALBERT_SMALL_V2.1,
            model_cache_path.to_str().unwrap(),
        );

        // Get the model from a remote resource
        let model = SentenceEmbeddingsBuilder::remote(
            SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2,
        )
        .create_model()
        .expect("Failed to create the AlBERT model");

        // Save the model to the cache here
        // ??

        model
    };

This code does not save anything to my MODELS_CACHE_DIR. Using RemoteResource to specify modules_config, transformer_config, etc. does not download those files to that specific directory; it only declares the paths.

blmarket commented 11 months ago

The model is downloaded lazily - you need to use the model at least once to trigger the download. Minimal example code to download the model:

    #[test]
    fn test_load_using_bert() -> anyhow::Result<()> {
        let model = SentenceEmbeddingsModel::new(SentenceEmbeddingsConfig::from(
            SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2,
        ))?;
        dbg!(model.encode(&["This is a test sentence."])?);
        Ok(())
    }
kj3moraes commented 11 months ago

Fair enough. Is there a way to specify at model initialization that, when the model is downloaded, it should go to my_cache_dir instead of ~/.cache/.rustbert?

guillaume-be commented 11 months ago

Models get downloaded to the directory specified by the RUSTBERT_CACHE environment variable, if it is set.
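
For example, here is a minimal sketch relying on that behaviour (the cache path in the comment is just a placeholder, and the download still only happens the first time the model is actually used):

    use rust_bert::pipelines::sentence_embeddings::{
        SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
    };

    // Run with the cache location overridden in the environment, e.g.:
    //   RUSTBERT_CACHE=/path/to/models_cache cargo run
    fn main() -> anyhow::Result<()> {
        // With RUSTBERT_CACHE set, the remote resources are downloaded into
        // that directory on first use instead of the default ~/.cache/.rustbert.
        let model = SentenceEmbeddingsBuilder::remote(
            SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2,
        )
        .create_model()?;

        // Using the model triggers the lazy download on the first run.
        let embeddings = model.encode(&["This is a test sentence."])?;
        println!("Computed {} sentence embeddings", embeddings.len());
        Ok(())
    }

Subsequent runs with the same RUSTBERT_CACHE value reuse the files already in that directory, so nothing is downloaded again.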

kj3moraes commented 11 months ago

Ok, thanks - that solves it.