Skip to content

Commit

Permalink
add save_file_size api
Browse files Browse the repository at this point in the history
  • Loading branch information
leofidus committed Feb 14, 2022
1 parent fdac515 commit 7502f39
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/booster.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use libc::{c_char, c_double, c_longlong, c_void};
use std;
use std::convert::TryInto;
use std::ffi::CString;

use serde_json::Value;
Expand Down Expand Up @@ -210,6 +211,24 @@ impl Booster {
))?;
Ok(())
}

/// Returns the size the model would have if saved using `save_file`, without having to write the file
pub fn save_file_size(&self) -> Result<u64> {
let mut out_size = 0_i64;
lgbm_call!(lightgbm_sys::LGBM_BoosterSaveModelToString(
self.handle,
0_i32,
-1_i32,
0_i32,
0,
&mut out_size as *mut _,
std::ptr::null_mut() as *mut i8
))?;
// subtract 1 because the file doesn't contain the final null character
(out_size - 1)
.try_into()
.map_err(|_| Error::new("size negative"))
}
}

impl Drop for Booster {
Expand Down Expand Up @@ -300,6 +319,18 @@ mod tests {
let _ = fs::remove_file("./test/test_save_file.output");
}

#[test]
fn save_file_size() {
let params = _default_params();
let bst = _train_booster(&params);
let filename = "./test/test_save_file_size.output";
assert_eq!(bst.save_file(filename), Ok(()));
let file_size = Path::new(filename).metadata().unwrap().len();
assert!(file_size > 0);
assert_eq!(bst.save_file_size(), Ok(file_size));
let _ = fs::remove_file(filename);
}

#[test]
fn from_file() {
let _ = Booster::from_file(&"./test/test_from_file.input");
Expand Down

0 comments on commit 7502f39

Please sign in to comment.