Skip to content

Commit

Permalink
Merge pull request #3 from cmccomb/master
Browse files Browse the repository at this point in the history
Checking for non-zero size of x-matrix
  • Loading branch information
cmccomb authored Feb 9, 2024
2 parents d9467d3 + 65b1bc4 commit a36ea10
Showing 1 changed file with 148 additions and 108 deletions.
256 changes: 148 additions & 108 deletions src/train_and_predict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,69 +56,90 @@ pub mod train_and_predict_functions {
y: Array,
algorithm: ImmutableString,
) -> Result<Model, Box<EvalAltResult>> {
let algorithm_string = algorithm.as_str();
let xvec = smartcorelib::linalg::basic::matrix::DenseMatrix::from_2d_vec(
&x.into_iter()
.map(|observation| {
array_to_vec_float(&mut observation.clone().into_array().unwrap())
})
.collect::<Vec<Vec<FLOAT>>>(),
);
match algorithm_string {
"linear" => {
let yvec = y
.clone()
.into_iter()
.map(|el| el.as_float().unwrap())
.collect::<Vec<FLOAT>>();
match LinearRegression::fit(&xvec, &yvec, LinearRegressionParameters::default()) {
Ok(model) => Ok(Model {
saved_model: bincode::serialize(&model).unwrap(),
model_type: algorithm_string.to_string(),
}),
Err(e) => {
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
// Make x array
let array_as_vec_vec_float = &x
.into_iter()
.map(|observation| {
crate::train_and_predict_functions::array_to_vec_float(
&mut observation.clone().into_array().unwrap(),
)
})
.collect::<Vec<Vec<FLOAT>>>();

// Check if x array is empty
if array_as_vec_vec_float.len() == 0 {
Err(EvalAltResult::ErrorArrayBounds(0, 0, Position::NONE).into())
} else {
let algorithm_string = algorithm.as_str();
let xvec = smartcorelib::linalg::basic::matrix::DenseMatrix::from_2d_vec(
array_as_vec_vec_float,
);
match algorithm_string {
"linear" => {
let yvec = y
.clone()
.into_iter()
.map(|el| el.as_float().unwrap())
.collect::<Vec<FLOAT>>();
match LinearRegression::fit(&xvec, &yvec, LinearRegressionParameters::default())
{
Ok(model) => Ok(Model {
saved_model: bincode::serialize(&model).unwrap(),
model_type: algorithm_string.to_string(),
}),
Err(e) => Err(EvalAltResult::ErrorArithmetic(
format!("{e}"),
Position::NONE,
)
.into()),
}
}
}
"lasso" => {
let yvec = y
.clone()
.into_iter()
.map(|el| el.as_float().unwrap())
.collect::<Vec<FLOAT>>();
match Lasso::fit(&xvec, &yvec, LassoParameters::default()) {
Ok(model) => Ok(Model {
saved_model: bincode::serialize(&model).unwrap(),
model_type: algorithm_string.to_string(),
}),
Err(e) => {
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
"lasso" => {
let yvec = y
.clone()
.into_iter()
.map(|el| el.as_float().unwrap())
.collect::<Vec<FLOAT>>();
match Lasso::fit(&xvec, &yvec, LassoParameters::default()) {
Ok(model) => Ok(Model {
saved_model: bincode::serialize(&model).unwrap(),
model_type: algorithm_string.to_string(),
}),
Err(e) => Err(EvalAltResult::ErrorArithmetic(
format!("{e}"),
Position::NONE,
)
.into()),
}
}
}
"logistic" => {
let yvec = y
.clone()
.into_iter()
.map(|el| el.as_int().unwrap())
.collect::<Vec<INT>>();
match LogisticRegression::fit(&xvec, &yvec, LogisticRegressionParameters::default())
{
Ok(model) => Ok(Model {
saved_model: bincode::serialize(&model).unwrap(),
model_type: algorithm_string.to_string(),
}),
Err(e) => {
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
"logistic" => {
let yvec = y
.clone()
.into_iter()
.map(|el| el.as_int().unwrap())
.collect::<Vec<INT>>();
match LogisticRegression::fit(
&xvec,
&yvec,
LogisticRegressionParameters::default(),
) {
Ok(model) => Ok(Model {
saved_model: bincode::serialize(&model).unwrap(),
model_type: algorithm_string.to_string(),
}),
Err(e) => Err(EvalAltResult::ErrorArithmetic(
format!("{e}"),
Position::NONE,
)
.into()),
}
}
&_ => Err(EvalAltResult::ErrorArithmetic(
format!("{} is not a recognized model type.", algorithm_string),
Position::NONE,
)
.into()),
}
&_ => Err(EvalAltResult::ErrorArithmetic(
format!("{} is not a recognized model type.", algorithm_string),
Position::NONE,
)
.into()),
}
}

Expand All @@ -136,59 +157,78 @@ pub mod train_and_predict_functions {
/// ```
#[rhai_fn(name = "predict", return_raw, pure)]
pub fn predict_with_model(x: &mut Array, model: Model) -> Result<Array, Box<EvalAltResult>> {
let xvec = DenseMatrix::from_2d_vec(
&x.into_iter()
.map(|observation| {
array_to_vec_float(&mut observation.clone().into_array().unwrap())
})
.collect::<Vec<Vec<FLOAT>>>(),
);
let algorithm_string = model.model_type.as_str();
match algorithm_string {
"linear" => {
let model_ready: LinearRegression<FLOAT, FLOAT, DenseMatrix<FLOAT>, Vec<FLOAT>> =
bincode::deserialize(&*model.saved_model).unwrap();
return match model_ready.predict(&xvec) {
Ok(y) => Ok(y
.into_iter()
.map(|observation| Dynamic::from_float(observation))
.collect::<Vec<Dynamic>>()),
Err(e) => {
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
}
};
}
"lasso" => {
let model_ready: Lasso<FLOAT, FLOAT, DenseMatrix<FLOAT>, Vec<FLOAT>> =
bincode::deserialize(&*model.saved_model).unwrap();
return match model_ready.predict(&xvec) {
Ok(y) => Ok(y
.into_iter()
.map(|observation| Dynamic::from_float(observation))
.collect::<Vec<Dynamic>>()),
Err(e) => {
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
}
};
}
"logistic" => {
let model_ready: LogisticRegression<FLOAT, INT, DenseMatrix<FLOAT>, Vec<INT>> =
bincode::deserialize(&*model.saved_model).unwrap();
return match model_ready.predict(&xvec) {
Ok(y) => Ok(y
.into_iter()
.map(|observation| Dynamic::from_int(observation))
.collect::<Vec<Dynamic>>()),
Err(e) => {
Err(EvalAltResult::ErrorArithmetic(format!("{e}"), Position::NONE).into())
}
};
// Make x array
let array_as_vec_vec_float = &x
.into_iter()
.map(|observation| {
crate::train_and_predict_functions::array_to_vec_float(
&mut observation.clone().into_array().unwrap(),
)
})
.collect::<Vec<Vec<FLOAT>>>();

// Check if x array is empty
if array_as_vec_vec_float.len() == 0 {
Err(EvalAltResult::ErrorArrayBounds(0, 0, Position::NONE).into())
} else {
let xvec = DenseMatrix::from_2d_vec(array_as_vec_vec_float);
let algorithm_string = model.model_type.as_str();
match algorithm_string {
"linear" => {
let model_ready: LinearRegression<
FLOAT,
FLOAT,
DenseMatrix<FLOAT>,
Vec<FLOAT>,
> = bincode::deserialize(&*model.saved_model).unwrap();
return match model_ready.predict(&xvec) {
Ok(y) => Ok(y
.into_iter()
.map(|observation| Dynamic::from_float(observation))
.collect::<Vec<Dynamic>>()),
Err(e) => Err(EvalAltResult::ErrorArithmetic(
format!("{e}"),
Position::NONE,
)
.into()),
};
}
"lasso" => {
let model_ready: Lasso<FLOAT, FLOAT, DenseMatrix<FLOAT>, Vec<FLOAT>> =
bincode::deserialize(&*model.saved_model).unwrap();
return match model_ready.predict(&xvec) {
Ok(y) => Ok(y
.into_iter()
.map(|observation| Dynamic::from_float(observation))
.collect::<Vec<Dynamic>>()),
Err(e) => Err(EvalAltResult::ErrorArithmetic(
format!("{e}"),
Position::NONE,
)
.into()),
};
}
"logistic" => {
let model_ready: LogisticRegression<FLOAT, INT, DenseMatrix<FLOAT>, Vec<INT>> =
bincode::deserialize(&*model.saved_model).unwrap();
return match model_ready.predict(&xvec) {
Ok(y) => Ok(y
.into_iter()
.map(|observation| Dynamic::from_int(observation))
.collect::<Vec<Dynamic>>()),
Err(e) => Err(EvalAltResult::ErrorArithmetic(
format!("{e}"),
Position::NONE,
)
.into()),
};
}
&_ => Err(EvalAltResult::ErrorArithmetic(
format!("{} is not a recognized model type.", algorithm_string),
Position::NONE,
)
.into()),
}
&_ => Err(EvalAltResult::ErrorArithmetic(
format!("{} is not a recognized model type.", algorithm_string),
Position::NONE,
)
.into()),
}
}
}

0 comments on commit a36ea10

Please sign in to comment.