diff --git a/src/booster.rs b/src/booster.rs index bd1732b..0d97db4 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -1,5 +1,6 @@ use libc::{c_char, c_double, c_longlong, c_void}; use std; +use std::convert::TryInto; use std::ffi::CString; use serde_json::Value; @@ -32,6 +33,20 @@ impl Booster { Ok(Booster::new(handle)) } + /// Init from model string. + pub fn from_string(model_description: &str) -> Result { + let cstring = CString::new(model_description).unwrap(); + let mut out_num_iterations = 0; + let mut handle = std::ptr::null_mut(); + lgbm_call!(lightgbm_sys::LGBM_BoosterLoadModelFromString( + cstring.as_ptr() as *const c_char, + &mut out_num_iterations, + &mut handle + ))?; + + Ok(Booster::new(handle)) + } + /// Create a new Booster model with given Dataset and parameters. /// /// Example @@ -210,6 +225,48 @@ impl Booster { ))?; Ok(()) } + + /// Save model to string. This returns the same content that `save_file` writes into a file. + pub fn save_string(&self) -> Result { + // get nessesary buffer size + let mut out_size = 0_i64; + lgbm_call!(lightgbm_sys::LGBM_BoosterSaveModelToString( + self.handle, + 0_i32, + -1_i32, + 0_i32, + 0, + &mut out_size as *mut _, + std::ptr::null_mut() as *mut i8 + ))?; + + // write data to buffer and convert + let mut buffer = vec![ + 0u8; + out_size + .try_into() + .map_err(|_| Error::new("size negative"))? + ]; + lgbm_call!(lightgbm_sys::LGBM_BoosterSaveModelToString( + self.handle, + 0_i32, + -1_i32, + 0_i32, + buffer.len() as c_longlong, + &mut out_size as *mut _, + buffer.as_mut_ptr() as *mut c_char + ))?; + + if buffer.pop() != Some(0) { + // this should never happen, unless lightgbm has a bug + panic!("write out of bounds happened in lightgbm call"); + } + + let cstring = CString::new(buffer).map_err(|e| Error::new(e.to_string()))?; + cstring + .into_string() + .map_err(|_| Error::new("can't convert model string to unicode")) + } } impl Drop for Booster { @@ -300,8 +357,28 @@ mod tests { let _ = fs::remove_file("./test/test_save_file.output"); } + #[test] + fn save_string() { + let params = _default_params(); + let bst = _train_booster(¶ms); + let filename = "./test/test_save_string.output"; + assert_eq!(bst.save_file(&filename), Ok(())); + assert!(Path::new(&filename).exists()); + let booster_file_content = fs::read_to_string(&filename).unwrap(); + let _ = fs::remove_file("./test/test_save_file.output"); + + assert!(!booster_file_content.is_empty()); + assert_eq!(Ok(booster_file_content), bst.save_string()) + } + #[test] fn from_file() { let _ = Booster::from_file(&"./test/test_from_file.input"); } + + #[test] + fn from_string() { + let model_string = fs::read_to_string("./test/test_from_file.input").unwrap(); + Booster::from_string(&model_string).unwrap(); + } }