@@ -106,6 +106,16 @@ impl Booster {
106
106
Ok ( Booster { handle } )
107
107
}
108
108
109
+ /// Load a Booster directly from a buffer.
110
+ pub fn load_buffer ( bytes : & [ u8 ] ) -> XGBResult < Self > {
111
+ debug ! ( "Loading Booster from buffer (length = {})" , bytes. len( ) ) ;
112
+
113
+ let mut handle = ptr:: null_mut ( ) ;
114
+ xgb_call ! ( xgboost_sys:: XGBoosterCreate ( ptr:: null( ) , 0 , & mut handle) ) ?;
115
+ xgb_call ! ( xgboost_sys:: XGBoosterLoadModelFromBuffer ( handle, bytes. as_ptr( ) as * const _, bytes. len( ) as u64 ) ) ?;
116
+ Ok ( Booster { handle } )
117
+ }
118
+
109
119
/// Convenience function for creating/training a new Booster.
110
120
///
111
121
/// This does the following:
@@ -692,6 +702,26 @@ mod tests {
692
702
assert_eq ! ( attr, Some ( "bar" . to_owned( ) ) ) ;
693
703
}
694
704
705
+ #[ test]
706
+ fn save_and_load_from_buffer ( ) {
707
+ let mut booster = load_test_booster ( ) ;
708
+ let attr = booster. get_attribute ( "foo" ) . expect ( "Getting attribute failed" ) ;
709
+ assert_eq ! ( attr, None ) ;
710
+
711
+ booster. set_attribute ( "foo" , "bar" ) . expect ( "Setting attribute failed" ) ;
712
+ let attr = booster. get_attribute ( "foo" ) . expect ( "Getting attribute failed" ) ;
713
+ assert_eq ! ( attr, Some ( "bar" . to_owned( ) ) ) ;
714
+
715
+ let mut dir = tempfile:: tempdir ( ) . expect ( "create temp dir" ) ;
716
+ let path = dir. path ( ) . join ( "test-xgboost-model" ) ;
717
+ booster. save ( & path) . expect ( "saving booster" ) ;
718
+ drop ( booster) ;
719
+ let bytes = std:: fs:: read ( & path) . expect ( "read saved booster file" ) ;
720
+ let booster = Booster :: load_buffer ( & bytes[ ..] ) . expect ( "load booster from buffer" ) ;
721
+ let attr = booster. get_attribute ( "foo" ) . expect ( "Getting attribute failed" ) ;
722
+ assert_eq ! ( attr, Some ( "bar" . to_owned( ) ) ) ;
723
+ }
724
+
695
725
#[ test]
696
726
fn get_attribute_names ( ) {
697
727
let mut booster = load_test_booster ( ) ;
0 commit comments