Skip to content

Commit b1d0d7e

Browse files
authored
Merge pull request #1 from jonathanstrong/master
Add `Booster::load_buffer` (safe interface to `XGBoosterLoadModelFromBuffer`)
2 parents 55cd90c + 99635d5 commit b1d0d7e

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

src/booster.rs

+30
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,16 @@ impl Booster {
106106
Ok(Booster { handle })
107107
}
108108

109+
/// Load a Booster directly from a buffer.
110+
pub fn load_buffer(bytes: &[u8]) -> XGBResult<Self> {
111+
debug!("Loading Booster from buffer (length = {})", bytes.len());
112+
113+
let mut handle = ptr::null_mut();
114+
xgb_call!(xgboost_sys::XGBoosterCreate(ptr::null(), 0, &mut handle))?;
115+
xgb_call!(xgboost_sys::XGBoosterLoadModelFromBuffer(handle, bytes.as_ptr() as *const _, bytes.len() as u64))?;
116+
Ok(Booster { handle })
117+
}
118+
109119
/// Convenience function for creating/training a new Booster.
110120
///
111121
/// This does the following:
@@ -692,6 +702,26 @@ mod tests {
692702
assert_eq!(attr, Some("bar".to_owned()));
693703
}
694704

705+
#[test]
706+
fn save_and_load_from_buffer() {
707+
let mut booster = load_test_booster();
708+
let attr = booster.get_attribute("foo").expect("Getting attribute failed");
709+
assert_eq!(attr, None);
710+
711+
booster.set_attribute("foo", "bar").expect("Setting attribute failed");
712+
let attr = booster.get_attribute("foo").expect("Getting attribute failed");
713+
assert_eq!(attr, Some("bar".to_owned()));
714+
715+
let mut dir = tempfile::tempdir().expect("create temp dir");
716+
let path = dir.path().join("test-xgboost-model");
717+
booster.save(&path).expect("saving booster");
718+
drop(booster);
719+
let bytes = std::fs::read(&path).expect("read saved booster file");
720+
let booster = Booster::load_buffer(&bytes[..]).expect("load booster from buffer");
721+
let attr = booster.get_attribute("foo").expect("Getting attribute failed");
722+
assert_eq!(attr, Some("bar".to_owned()));
723+
}
724+
695725
#[test]
696726
fn get_attribute_names() {
697727
let mut booster = load_test_booster();

0 commit comments

Comments
 (0)