davechallis / rust-xgboost

Rust bindings for XGBoost.
MIT License

help on safe implementation of XGBoosterGetModelRaw? #4

Open jonathanstrong opened 5 years ago

jonathanstrong commented 5 years ago

hey -

I was trying to implement the other side of the load/save from buffer, XGBoosterGetModelRaw, but I'm stuck, and thought maybe you could help.

This is where I am (in Booster impl):

    /// Returns a `Vec<u8>` with model weights.
    pub fn to_vec(&self) -> XGBResult<Vec<u8>> {
        debug!("Writing Booster to_vec");
        let mut out_len = 0; 
        let mut out_dptr = ptr::null_mut();
        xgb_call!(xgboost_sys::XGBoosterGetModelRaw(self.handle, &mut out_len, out_dptr))?;
        // let bytes: &[u8] = unsafe {
        //     let length: u64 = *(out_len as *const _);
        //     std::slice::from_raw_parts(out_dptr as *const _, length as usize)
        // };
        // let mut out: Vec<u8> = vec![0u8; bytes.len()];
        // out[..].copy_from_slice(bytes);
        // Ok(out)
    }

The commented out section is because I can't get past calling the xgboost function.

I have tried a wide variety of pointer arrangements for the call to XGBoosterGetModelRaw, but I get a SIGSEGV no matter what.

The xgboost api is defined as:

    /*!
     * \brief save model into binary raw bytes, return header of the array
     * user must copy the result out, before next xgboost call
     * \param handle handle
     * \param out_len the argument to hold the output length
     * \param out_dptr the argument to hold the output data pointer
     * \return 0 when success, -1 when failure happens
     */
    XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
                                     bst_ulong *out_len,
                                     const char **out_dptr);

Any ideas? Thanks!

davechallis commented 5 years ago

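I think it might be because out_dptr is a double pointer in the C API (`const char **`). Your code passes the null pointer itself as that argument, so XGBoost ends up writing the data address through a null pointer, which would explain the SIGSEGV. Declaring out_dptr with ptr::null() instead of null_mut() and passing a mutable reference to it (`&mut out_dptr`) to the C API should work. Something like the following (though I haven't tested it):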

    /// Returns a `Vec<u8>` with model weights.
    pub fn to_vec(&self) -> XGBResult<Vec<u8>> {
        debug!("Writing Booster to_vec");
        let mut out_len = 0;
        let mut out_dptr = ptr::null();
        xgb_call!(xgboost_sys::XGBoosterGetModelRaw(self.handle, &mut out_len, &mut out_dptr))?;
        // out_dptr points at out_len raw bytes owned by XGBoost (not an array of
        // C strings), so copy them out before making any other XGBoost call.
        let bytes = unsafe { slice::from_raw_parts(out_dptr as *const u8, out_len as usize) };
        Ok(bytes.to_vec())
    }
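
For anyone who wants to try the pointer handling without linking XGBoost, here is a minimal, self-contained sketch. `mock_get_model_raw` and `MOCK_MODEL` are hypothetical stand-ins made up for illustration: they only mimic the out-parameter shape of `XGBoosterGetModelRaw` (without the handle), and using `u64` for `bst_ulong` is an assumption. The sketch passes `&mut` references for both out-parameters and copies the result as one raw byte buffer, since a binary model dump can contain zero bytes and so should not be read as C strings.

```rust
use std::os::raw::{c_char, c_int};
use std::{ptr, slice};

// Stand-in for the bytes XGBoost would own. The embedded 0 bytes are deliberate:
// the buffer is binary data, not a NUL-terminated C string.
static MOCK_MODEL: [u8; 6] = [0x62, 0x00, 0x69, 0x6e, 0x00, 0x66];

/// Hypothetical mock with the same out-parameter shape as
/// `XGBoosterGetModelRaw`. Not part of xgboost_sys.
unsafe extern "C" fn mock_get_model_raw(
    out_len: *mut u64,
    out_dptr: *mut *const c_char,
) -> c_int {
    // The callee writes through both pointers, so neither may be null.
    unsafe {
        *out_len = MOCK_MODEL.len() as u64;
        *out_dptr = MOCK_MODEL.as_ptr() as *const c_char;
    }
    0
}

/// Calls the mock and immediately copies the bytes into an owned `Vec<u8>`.
fn model_to_vec() -> Result<Vec<u8>, c_int> {
    let mut out_len: u64 = 0;
    // A single const pointer; the callee receives its address (`const char **`).
    let mut out_dptr: *const c_char = ptr::null();

    let rc = unsafe { mock_get_model_raw(&mut out_len, &mut out_dptr) };
    if rc != 0 {
        return Err(rc);
    }
    if out_dptr.is_null() {
        return Ok(Vec::new());
    }

    // Borrow the callee-owned buffer as bytes, then copy it so the result
    // stays valid after the next call into the library.
    let bytes = unsafe { slice::from_raw_parts(out_dptr as *const u8, out_len as usize) };
    Ok(bytes.to_vec())
}
```

Calling `model_to_vec()` here returns all six bytes, including the two zero bytes; reading the same buffer through `CStr::from_ptr` would stop at the first zero byte.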