Skip to content

Commit

Permalink
More functions on dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Ten0 committed Aug 8, 2023
1 parent dbc9f81 commit 64a9559
Showing 1 changed file with 115 additions and 3 deletions.
118 changes: 115 additions & 3 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,19 @@ impl Dataset {
/// ```
pub fn from_mat(data: &[f64], n_rows: usize, label: &[f32]) -> Result<Self> {
let data_length = data.len();
if data.len() % n_rows != 0 {
if (data_length != 0 || n_rows != 0) && data_length % n_rows != 0 {
return Err(Error::new(format!(
"data len is not multiple of n_rows ({n_rows}), but all rows \
should have the same number of features",
)));
}
let feature_length = data_length / n_rows;
let feature_length = if data_length == 0 && n_rows == 0 {
0
} else {
data_length / n_rows
};

let nrow = data_length
let nrow = n_rows
.try_into()
.map_err(|_| Error::new("number of rows doesn't fit into an i32"))?;
let ncol = feature_length
Expand Down Expand Up @@ -239,6 +243,52 @@ impl Dataset {
}
Self::from_mat(feature_values, label_values)
}

pub fn n_rows(&self) -> Result<usize> {
let mut result = 0_i32;
lgbm_call!(lightgbm_sys::LGBM_DatasetGetNumData(
self.handle,
&mut result
))?;
result
.try_into()
.map_err(|_| Error::new("dataset length negative"))
}

pub fn n_features(&self) -> Result<usize> {
let mut result = 0_i32;
lgbm_call!(lightgbm_sys::LGBM_DatasetGetNumFeature(
self.handle,
&mut result
))?;
result
.try_into()
.map_err(|_| Error::new("feature count negative"))
}

pub fn set_weights(&mut self, weights: &[f32]) -> Result<()> {
let n_rows = self.n_rows()?;
if n_rows != weights.len() {
return Err(Error::new(format!(
"got {} weights, but dataset has {} records",
weights.len(),
n_rows
)));
}
let field_name = CString::new("weight").unwrap();
let len = weights
.len()
.try_into()
.map_err(|_| Error::new("weights len doesn't fit into an i32"))?;
lgbm_call!(lightgbm_sys::LGBM_DatasetSetField(
self.handle,
field_name.as_ptr() as *const c_char,
weights.as_ptr() as *const c_void,
len,
lightgbm_sys::C_API_DTYPE_FLOAT32 as i32,
))?;
Ok(())
}
}

#[cfg(test)]
Expand Down Expand Up @@ -287,4 +337,66 @@ mod tests {
let df_dataset = Dataset::from_dataframe(df, String::from("label"));
assert!(df_dataset.is_ok());
}

#[test]
fn get_dataset_properties() {
let data = &[
[1.0, 0.1, 0.2, 0.1],
[0.7, 0.4, 0.5, 0.1],
[0.9, 0.8, 0.5, 0.1],
[0.2, 0.2, 0.8, 0.7],
[0.1, 0.7, 1.0, 0.9],
];
let label = &[0.0, 0.0, 0.0, 1.0, 1.0];
let dataset = Dataset::from_mat(
&data.iter().flatten().copied().collect::<Vec<_>>(),
data.len(),
label,
)
.unwrap();
assert_eq!(dataset.n_rows(), Ok(5));
assert_eq!(dataset.n_features(), Ok(4));
}

#[test]
fn set_weights() {
let data = &[
[1.0, 0.1, 0.2, 0.1],
[0.7, 0.4, 0.5, 0.1],
[0.9, 0.8, 0.5, 0.1],
[0.2, 0.2, 0.8, 0.7],
[0.1, 0.7, 1.0, 0.9],
];
let label = &[0.0, 0.0, 0.0, 1.0, 1.0];
let mut dataset = Dataset::from_mat(
&data.iter().flatten().copied().collect::<Vec<_>>(),
data.len(),
label,
)
.unwrap();
let weights = &[0.5, 1.0, 2.0, 0.5, 0.5];
dataset.set_weights(weights).unwrap();
}

#[test]
fn set_weights_wrong_len() {
let data = &[
[1.0, 0.1, 0.2, 0.1],
[0.7, 0.4, 0.5, 0.1],
[0.9, 0.8, 0.5, 0.1],
[0.2, 0.2, 0.8, 0.7],
[0.1, 0.7, 1.0, 0.9],
];
let label = &[0.0, 0.0, 0.0, 1.0, 1.0];
let mut dataset = Dataset::from_mat(
&data.iter().flatten().copied().collect::<Vec<_>>(),
data.len(),
label,
)
.unwrap();
let weights_short = &[0.5, 1.0, 2.0, 0.5];
let weights_long = &[0.5, 1.0, 2.0, 0.5, 0.1, 0.1];
assert!(dataset.set_weights(weights_short).is_err());
assert!(dataset.set_weights(weights_long).is_err());
}
}

0 comments on commit 64a9559

Please sign in to comment.