guillaume-be / rust-bert

Rust-native, ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2, ...)
https://docs.rs/crate/rust-bert
Apache License 2.0

RemoteResource doesn't allow loading safetensors models #447

Status: Open. zaytsev opened this issue 7 months ago.

zaytsev commented 7 months ago

The RemoteResource resource provider doesn't preserve the original file name or extension; it returns the path of the cached download as-is:

    let cached_path = CACHE
        .cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?;
    Ok(cached_path)

However, tch-rs only reads a model file in the Safetensors format when its path ends in a .safetensors extension; any other path is handed to one of the other loaders:

    fn named_tensors<T: AsRef<std::path::Path>>(
        &self,
        path: T,
    ) -> Result<HashMap<String, Tensor>, TchError> {
        let named_tensors = match path.as_ref().extension().and_then(|x| x.to_str()) {
            Some("bin") | Some("pt") => Tensor::loadz_multi_with_device(&path, self.device),
            Some("safetensors") => Tensor::read_safetensors(path),
            Some(_) | None => Tensor::load_multi_with_device(&path, self.device),
        };
        Ok(named_tensors?.into_iter().collect())
    }
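A minimal, std-only sketch of the mismatch and one possible workaround, assuming the cached file's name carries no extension (the name 3f2a9c1e below is made up). Loader, loader_for and with_safetensors_extension are hypothetical names, not rust-bert or tch-rs APIs: loader_for only mirrors the match quoted above, and the workaround exposes the cached file under a sibling name ending in .safetensors before it would be passed to the loader.

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Which branch the extension-based dispatch quoted above would take.
#[derive(Debug, PartialEq)]
enum Loader {
    Zip,         // "bin" | "pt"  -> Tensor::loadz_multi_with_device
    Safetensors, // "safetensors" -> Tensor::read_safetensors
    Default,     // anything else -> Tensor::load_multi_with_device
}

/// Mirrors the `match` on the file extension in `named_tensors`.
fn loader_for(path: &Path) -> Loader {
    match path.extension().and_then(|x| x.to_str()) {
        Some("bin") | Some("pt") => Loader::Zip,
        Some("safetensors") => Loader::Safetensors,
        Some(_) | None => Loader::Default,
    }
}

/// Hypothetical workaround: make the cached file reachable under a sibling
/// path that appends ".safetensors" to the full cached file name (so the
/// original name is kept intact). Tries a hard link first, then a copy.
fn with_safetensors_extension(cached: &Path) -> io::Result<PathBuf> {
    let mut name = cached
        .file_name()
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "path has no file name"))?
        .to_os_string();
    name.push(".safetensors");
    let target = cached.with_file_name(name);
    if !target.exists() && fs::hard_link(cached, &target).is_err() {
        fs::copy(cached, &target)?;
    }
    Ok(target)
}

fn main() -> io::Result<()> {
    // Hypothetical cache entry whose name has no ".safetensors" extension.
    let dir = std::env::temp_dir().join("rust_bert_issue_447_sketch");
    fs::create_dir_all(&dir)?;
    let cached = dir.join("3f2a9c1e");
    fs::write(&cached, b"placeholder bytes, not a real model")?;

    // As returned by the cache, the file would not reach the Safetensors reader.
    assert_eq!(loader_for(&cached), Loader::Default);

    // Under the appended name, the dispatch selects the Safetensors branch.
    let renamed = with_safetensors_extension(&cached)?;
    assert_eq!(loader_for(&renamed), Loader::Safetensors);
    println!("{} -> {:?}", renamed.display(), loader_for(&renamed));
    Ok(())
}
```

Appending to the file name rather than calling with_extension avoids replacing any dot-separated suffix the cache may already use, so two distinct cache entries cannot map to the same .safetensors path. A fix inside RemoteResource itself (for example, keeping the extension from the URL) would make such a workaround unnecessary.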