Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/booster.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use libc::{c_char, c_double, c_longlong, c_void};
use std;
use std::convert::TryInto;
use std::ffi::CString;

use serde_json::Value;
Expand Down Expand Up @@ -32,6 +33,20 @@ impl Booster {
Ok(Booster::new(handle))
}

/// Init from model string.
pub fn from_string(model_description: &str) -> Result<Self> {
let cstring = CString::new(model_description).unwrap();
let mut out_num_iterations = 0;
let mut handle = std::ptr::null_mut();
lgbm_call!(lightgbm_sys::LGBM_BoosterLoadModelFromString(
cstring.as_ptr() as *const c_char,
&mut out_num_iterations,
&mut handle
))?;

Ok(Booster::new(handle))
}

/// Create a new Booster model with given Dataset and parameters.
///
/// Example
Expand Down Expand Up @@ -210,6 +225,48 @@ impl Booster {
))?;
Ok(())
}

/// Save model to string. This returns the same content that `save_file` writes into a file.
pub fn save_string(&self) -> Result<String> {
// get nessesary buffer size
let mut out_size = 0_i64;
lgbm_call!(lightgbm_sys::LGBM_BoosterSaveModelToString(
self.handle,
0_i32,
-1_i32,
0_i32,
0,
&mut out_size as *mut _,
std::ptr::null_mut() as *mut i8
))?;

// write data to buffer and convert
let mut buffer = vec![
0u8;
out_size
.try_into()
.map_err(|_| Error::new("size negative"))?
];
lgbm_call!(lightgbm_sys::LGBM_BoosterSaveModelToString(
self.handle,
0_i32,
-1_i32,
0_i32,
buffer.len() as c_longlong,
&mut out_size as *mut _,
buffer.as_mut_ptr() as *mut c_char
))?;

if buffer.pop() != Some(0) {
// this should never happen, unless lightgbm has a bug
panic!("write out of bounds happened in lightgbm call");
}

let cstring = CString::new(buffer).map_err(|e| Error::new(e.to_string()))?;
cstring
.into_string()
.map_err(|_| Error::new("can't convert model string to unicode"))
}
}

impl Drop for Booster {
Expand Down Expand Up @@ -300,8 +357,28 @@ mod tests {
let _ = fs::remove_file("./test/test_save_file.output");
}

#[test]
fn save_string() {
let params = _default_params();
let bst = _train_booster(&params);
let filename = "./test/test_save_string.output";
assert_eq!(bst.save_file(&filename), Ok(()));
assert!(Path::new(&filename).exists());
let booster_file_content = fs::read_to_string(&filename).unwrap();
let _ = fs::remove_file("./test/test_save_file.output");

assert!(!booster_file_content.is_empty());
assert_eq!(Ok(booster_file_content), bst.save_string())
}

#[test]
fn from_file() {
let _ = Booster::from_file(&"./test/test_from_file.input");
}

#[test]
fn from_string() {
let model_string = fs::read_to_string("./test/test_from_file.input").unwrap();
Booster::from_string(&model_string).unwrap();
}
}