rust-ml / linfa

A Rust machine learning framework.
Apache License 2.0

Exporting & Loading Trained Model? #290

Closed: Bastian1110 closed this issue 1 year ago

Bastian1110 commented 1 year ago

Is there a way to save a trained model (GaussianNb) and then load it?

I'm just learning Rust, and I've managed to implement a Gaussian Naive Bayes classifier. Is there any way to use the "predict" method without having to train the whole model again? I know that in libraries like Sklearn you can export models and then load them from .pkl files. Is there something similar in linfa?

Thank you so much!

YuhanLiin commented 1 year ago

The models should implement Serde's `Serialize` and `Deserialize` traits, so you can serialize them using a crate like ciborium.
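To make the idea concrete: conceptually, a fitted Gaussian Naive Bayes classifier is just a set of numbers (a prior per class plus a mean and variance per class and feature). Saving it means writing those numbers out, and loading it means reading them back, after which predict needs no retraining. The std-only sketch below does that round trip by hand; `GaussianNbParams`, `ClassParams`, `to_bytes` and `from_bytes` are hypothetical names, and the byte layout is made up for illustration. It is not linfa's GaussianNb type nor the CBOR format. Once a linfa model implements Serde's traits (linfa crates generally expose this behind a `serde` Cargo feature; check the crate's Cargo.toml), ciborium would replace the hand-written encoding.

```rust
use std::convert::TryInto;
use std::f64::consts::TAU; // TAU = 2 * pi

/// Hypothetical stand-in for a fitted Gaussian Naive Bayes model (not linfa's type).
#[derive(Debug, PartialEq)]
struct GaussianNbParams {
    n_features: usize,
    classes: Vec<ClassParams>,
}

/// Per-class parameters: label, log prior, per-feature mean and variance.
#[derive(Debug, PartialEq)]
struct ClassParams {
    label: u64,
    log_prior: f64,
    mean: Vec<f64>,
    var: Vec<f64>,
}

/// Cursor that hands out consecutive 8-byte chunks of a byte slice.
struct Reader<'a> { bytes: &'a [u8], pos: usize }

impl<'a> Reader<'a> {
    fn next8(&mut self) -> Option<[u8; 8]> {
        let chunk: [u8; 8] = self.bytes.get(self.pos..self.pos + 8)?.try_into().ok()?;
        self.pos += 8;
        Some(chunk)
    }
}

impl GaussianNbParams {
    /// Little-endian layout: n_features, n_classes, then per class:
    /// label, log_prior, all means, all variances.
    fn to_bytes(&self) -> Vec<u8> {
        let mut out = Vec::new();
        out.extend_from_slice(&(self.n_features as u64).to_le_bytes());
        out.extend_from_slice(&(self.classes.len() as u64).to_le_bytes());
        for c in &self.classes {
            out.extend_from_slice(&c.label.to_le_bytes());
            out.extend_from_slice(&c.log_prior.to_le_bytes());
            for v in c.mean.iter().chain(&c.var) {
                out.extend_from_slice(&v.to_le_bytes());
            }
        }
        out
    }

    /// Inverse of `to_bytes`; returns None if the bytes are truncated.
    fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let mut r = Reader { bytes, pos: 0 };
        let n_features = u64::from_le_bytes(r.next8()?) as usize;
        let n_classes = u64::from_le_bytes(r.next8()?);
        let mut classes = Vec::new();
        for _ in 0..n_classes {
            let label = u64::from_le_bytes(r.next8()?);
            let log_prior = f64::from_le_bytes(r.next8()?);
            let mean = (0..n_features).map(|_| r.next8().map(f64::from_le_bytes)).collect::<Option<Vec<f64>>>()?;
            let var = (0..n_features).map(|_| r.next8().map(f64::from_le_bytes)).collect::<Option<Vec<f64>>>()?;
            classes.push(ClassParams { label, log_prior, mean, var });
        }
        Some(GaussianNbParams { n_features, classes })
    }

    /// Pick the class maximizing log prior + sum of per-feature Gaussian log-densities.
    fn predict(&self, x: &[f64]) -> Option<u64> {
        self.classes
            .iter()
            .map(|c| {
                let log_lik: f64 = x.iter().zip(&c.mean).zip(&c.var)
                    .map(|((xi, mu), var)| -0.5 * ((TAU * var).ln() + (xi - mu).powi(2) / var))
                    .sum();
                (c.label, c.log_prior + log_lik)
            })
            .max_by(|a, b| a.1.total_cmp(&b.1))
            .map(|(label, _)| label)
    }
}

fn main() {
    let model = GaussianNbParams {
        n_features: 2,
        classes: vec![
            ClassParams { label: 0, log_prior: 0.5f64.ln(), mean: vec![0.0, 0.0], var: vec![1.0, 1.0] },
            ClassParams { label: 1, log_prior: 0.5f64.ln(), mean: vec![5.0, 5.0], var: vec![1.0, 1.0] },
        ],
    };
    // "Save" (a real program would also call std::fs::write("model.bin", &bytes)).
    let bytes = model.to_bytes();
    // "Load" and predict without retraining.
    let loaded = GaussianNbParams::from_bytes(&bytes).expect("valid model bytes");
    assert_eq!(loaded, model);
    println!("{:?}", loaded.predict(&[4.5, 5.2])); // Some(1)
}
```

Reading the file back with `std::fs::read` and passing the bytes to `from_bytes` completes the save/load cycle across program runs.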

Bastian1110 commented 1 year ago

I'm having a little problem, hehe. I'm trying to serialize the model (using the linfa_svm example), but I don't know if I'm using the correct syntax, because I get this error on the line where I use cbor!:

    the trait bound `MultiClassModel<ndarray::ArrayBase<ndarray::data_repr::OwnedRepr<f64>, ndarray::dimension::dim::Dim<[usize; 2]>>, usize>: serde::ser::Serialize` is not satisfied
    the following other types implement trait `serde::ser::Serialize`:

This is the code I'm using (linfa_svm/examples/winequality_multi.rs):

    let model = train
        .one_vs_all()?
        .into_iter()
        .map(|(l, x)| (l, params.fit(&x).unwrap()))
        .collect::<MultiClassModel<_, _>>();
    let pred = model.predict(&valid);

    // Trying to serialize the model (this line produces the error above)
    let save_model = cbor!(model).unwrap();

Could you give me a more detailed example? I would really appreciate it!

YuhanLiin commented 1 year ago

It seems MultiClassModel and linfa-bayes have no Serde support. Weird. We'll need to add that.
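Until MultiClassModel gains Serde support, one possible workaround, assuming the per-class binary models themselves implement Serde's traits, is to save the (label, model) pairs rather than the combined model. The code in the question already builds MultiClassModel by collecting such pairs, so after deserializing a Vec of them you could collect them again. The std-only sketch below shows that pattern with hypothetical `BinaryModel` and `OneVsAll` types standing in for linfa's per-class model and MultiClassModel; the linear score is a placeholder, not an SVM decision function.

```rust
use std::iter::FromIterator;

/// Hypothetical stand-in for one trained one-vs-rest binary model:
/// the concrete, serializable part you would actually save.
#[derive(Debug, Clone, PartialEq)]
struct BinaryModel {
    weights: Vec<f64>,
    bias: f64,
}

impl BinaryModel {
    /// Placeholder decision score: higher means "more likely this class".
    fn score(&self, x: &[f64]) -> f64 {
        self.weights.iter().zip(x).map(|(w, xi)| w * xi).sum::<f64>() + self.bias
    }
}

/// Hypothetical one-vs-all combiner, playing the role of MultiClassModel.
struct OneVsAll<L> {
    models: Vec<(L, BinaryModel)>,
}

/// Lets the combiner be rebuilt with `.collect()`, like MultiClassModel in the question.
impl<L> FromIterator<(L, BinaryModel)> for OneVsAll<L> {
    fn from_iter<I: IntoIterator<Item = (L, BinaryModel)>>(iter: I) -> Self {
        OneVsAll { models: iter.into_iter().collect() }
    }
}

impl<L: Clone> OneVsAll<L> {
    /// Predict the label whose binary model gives the highest score.
    fn predict(&self, x: &[f64]) -> Option<L> {
        self.models
            .iter()
            .max_by(|a, b| a.1.score(x).total_cmp(&b.1.score(x)))
            .map(|(label, _)| label.clone())
    }
}

fn main() {
    // This Vec is what you would serialize (and later deserialize)...
    let saved: Vec<(usize, BinaryModel)> = vec![
        (0, BinaryModel { weights: vec![-1.0, 0.0], bias: 0.0 }),
        (1, BinaryModel { weights: vec![1.0, 0.0], bias: 0.0 }),
    ];
    // ...and after loading it, the multi-class model is rebuilt with collect().
    let model: OneVsAll<usize> = saved.into_iter().collect();
    println!("{:?}", model.predict(&[2.0, 0.0])); // Some(1)
}
```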

Bastian1110 commented 1 year ago

Oh! I will try with normal SVM then, thank you!

Bastian1110 commented 1 year ago

One last question. I've managed to serialize the (non-multiclass) SVM model to CBOR using ciborium, and I've also managed to deserialize it in another file into a ciborium Value. The last step would be converting that Value back into an Svm. Any idea how to do this?

Code for creating and exporting the model:

    use std::fs;
    use std::path::Path;
    use ciborium::{cbor, value::Value};
    use linfa::prelude::*;
    use linfa_svm::Svm;

    let model = Svm::<_, bool>::params().pos_neg_weights(50000., 5000.).gaussian_kernel(80.0).fit(&train)?;

    // Serialize the trained model into a ciborium Value, then encode it as CBOR bytes
    let value_model: Value = cbor!(model).unwrap();
    let mut vec_model: Vec<u8> = Vec::new();
    ciborium::ser::into_writer(&value_model, &mut vec_model).expect("CBOR encoding failed");

    // Export it to a .cbor file
    let path: &Path = Path::new("./model.cbor");
    fs::write(path, vec_model).unwrap();

Attempt to use the trained model in another .rs program:

    use std::fs::File;
    use std::io::Read;
    use ciborium::value::Value;
    use linfa_svm::Svm;

    let mut file = File::open("./model.cbor").unwrap();
    let mut data: Vec<u8> = Vec::new();
    file.read_to_end(&mut data).unwrap();

    let model_value: Value = ciborium::de::from_reader::<Value, _>(&data[..]).unwrap();
    let model: Svm<_, bool> = model_value.deserialized().unwrap(); // Error: type annotations needed

But I keep getting an error when converting from ciborium::Value to Svm. rust-analyzer suggests "consider specifying the generic argument: `::<Svm<_, bool>>`". I guess I have to pass the Svm serde deserializer somehow, but I don't know how to do that.

I know this has nothing to do with linfa, but I really think that exporting and importing the models can be very useful.

Thank you!

Bastian1110 commented 1 year ago

My bad: it turns out the example SVM model is an Svm<f64, bool>, not an Svm<_, bool>. I only changed the line to:

    let model: Svm<f64, bool> = model_value.deserialized().unwrap();

And it works, super cool!
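The fix works because `Value::deserialized` is generic over the type it produces, and with `Svm<_, bool>` apparently nothing else in that second program tells the compiler what the float parameter `_` should be. The same thing happens with any generic return type; this std-only sketch uses `str::parse` and a hypothetical generic `Model<F>` in place of `Svm<F, bool>`.

```rust
use std::str::FromStr;

/// Hypothetical generic model standing in for Svm<F, bool>:
/// the float type F is a type parameter.
#[derive(Debug, PartialEq)]
struct Model<F> {
    weight: F,
}

/// "Deserialize" a Model<F> from text by parsing its single weight.
impl<F: FromStr> FromStr for Model<F> {
    type Err = F::Err;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(Model { weight: s.trim().parse()? })
    }
}

fn main() {
    // Does not compile: nothing pins down F, so rustc reports
    // "type annotations needed" (the same situation as `Svm<_, bool>`).
    // let m: Model<_> = "0.5".parse().unwrap();

    // Naming the concrete float type resolves it, like `Svm<f64, bool>`.
    let m: Model<f64> = "0.5".parse().unwrap();
    // The turbofish form is equivalent.
    let n = "0.5".parse::<Model<f64>>().unwrap();
    assert_eq!(m, n);
    println!("{:?}", m); // Model { weight: 0.5 }
}
```

The turbofish variant corresponds to writing `model_value.deserialized::<Svm<f64, bool>>()`.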

coolstudio1678 commented 7 months ago

I'm using cbor!(model) and it shows errors. How do I resolve this?

    the trait bound `MultiClassModel<ndarray::ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>, Option<&str>>: serde::ser::Serialize` is not satisfied
    the following other types implement trait `serde::ser::Serialize`:
      bool
      char
      isize
      i8
      i16
      i32
      i64
      i128
      and 196 others
    lib.rs(222, 42): Actual error occurred here
    lib.rs(222, 9): required by a bound introduced by this call
    ser.rs(435, 35): required by a bound in `value::ser::<impl ciborium::Value>::serialized`