Skip to content

Commit 420be04

Browse files
committed
Use fonction constraint value to determine best result
1 parent 91d7267 commit 420be04

File tree

7 files changed

+56
-30
lines changed

7 files changed

+56
-30
lines changed

crates/ego/src/solver/egor_impl.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ impl<SB: SurrogateBuilder + DeserializeOwned, C: CstrFn> EgorSolver<SB, C> {
8383
&cstr_tol,
8484
&sampling,
8585
None,
86-
find_best_result_index(y_data, &cstr_tol),
86+
find_best_result_index(y_data, &c_data, &cstr_tol),
8787
&fcstrs,
8888
);
8989
x_dat
@@ -375,6 +375,7 @@ where
375375
state.best_index.unwrap(),
376376
y_data.nrows() - add_count as usize,
377377
&y_data,
378+
&c_data,
378379
&new_state.cstr_tol,
379380
);
380381
new_state.prev_best_index = state.best_index;

crates/ego/src/solver/egor_solver.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ where
236236
let c_data = self.eval_fcstrs(problem, &x_data);
237237

238238
let mut initial_state = state
239-
.data((x_data, y_data.clone(), c_data))
239+
.data((x_data, y_data.clone(), c_data.clone()))
240240
.clusterings(clusterings)
241241
.theta_inits(theta_inits)
242242
.sampling(sampling);
@@ -251,7 +251,7 @@ where
251251
.unwrap_or(Array1::from_elem(self.config.n_cstr, DEFAULT_CSTR_TOL));
252252
initial_state.target_cost = self.config.target;
253253

254-
let best_index = find_best_result_index(&y_data, &initial_state.cstr_tol);
254+
let best_index = find_best_result_index(&y_data, &c_data, &initial_state.cstr_tol);
255255
initial_state.best_index = Some(best_index);
256256
initial_state.prev_best_index = Some(best_index);
257257
initial_state.last_best_iter = 0;

crates/ego/src/solver/egor_state.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -408,11 +408,11 @@ where
408408
/// assert!(state.is_best());
409409
/// ```
410410
fn update(&mut self) {
411-
if let Some((x_data, y_data, _c_data)) = self.data.as_ref() {
411+
if let Some((x_data, y_data, c_data)) = self.data.as_ref() {
412412
let best_index = self
413413
.best_index
414414
// TODO: use cdata in find_best_result_index
415-
.unwrap_or_else(|| find_best_result_index(y_data, &self.cstr_tol));
415+
.unwrap_or_else(|| find_best_result_index(y_data, c_data, &self.cstr_tol));
416416

417417
let param = x_data.row(best_index).to_owned();
418418
std::mem::swap(&mut self.prev_best_param, &mut self.best_param);

crates/ego/src/solver/trego.rs

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ impl<SB: SurrogateBuilder + DeserializeOwned, C: CstrFn> EgorSolver<SB, C> {
9999
best_index,
100100
y_data.nrows() - 1,
101101
&y_data,
102+
&c_data,
102103
&new_state.cstr_tol,
103104
);
104105
new_state = new_state.data((x_data, y_data, c_data));

crates/ego/src/utils/find_result.rs

+39-21
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2, Zip};
1+
use ndarray::{concatenate, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2, Zip};
22
use ndarray_stats::QuantileExt;
33

44
use crate::utils::sort_axis::*;
@@ -46,39 +46,54 @@ pub fn find_best_result_index_from<F: Float>(
4646
current_index: usize, /* current best index */
4747
offset_index: usize,
4848
ydata: &ArrayBase<impl Data<Elem = F>, Ix2>, /* the whole data so far */
49+
cdata: &ArrayBase<impl Data<Elem = F>, Ix2>, /* the whole function cstrs data so far */
4950
cstr_tol: &Array1<F>,
5051
) -> usize {
51-
let new_ydata = ydata.slice(s![offset_index.., ..]);
52+
// let new_ydata = ydata.slice(s![offset_index.., ..]);
53+
54+
let alldata = concatenate![Axis(1), ydata.to_owned(), cdata.to_owned()];
55+
let alltols = concatenate![Axis(0), cstr_tol.to_owned(), Array1::zeros(cdata.ncols())];
56+
57+
let new_ydata = alldata.slice(s![offset_index.., ..]);
5258

5359
let best = ydata.row(current_index);
5460
let min = new_ydata
5561
.outer_iter()
5662
.enumerate()
5763
.fold((usize::MAX, best), |a, b| {
58-
std::cmp::min_by(a, b, |(i, u), (j, v)| cstr_min((*i, u), (*j, v), cstr_tol))
64+
std::cmp::min_by(a, b, |(i, u), (j, v)| cstr_min((*i, u), (*j, v), &alltols))
5965
});
6066
match min {
6167
(usize::MAX, _) => current_index,
6268
(index, _) => offset_index + index,
6369
}
6470
}
6571

66-
/// Find best (eg minimal) cost value (y_data\[0\]) with valid constraints (y_data\[1..\] < cstr_tol).
67-
/// y_data containing ns samples [objective, cstr_1, ... cstr_nc] is given as a matrix (ns, nc + 1)
72+
/// Find best (eg minimal) cost value (y_data\[0\]) with valid constraints, meaning
73+
/// * y_data\[1..\] < cstr_tol
74+
/// * c_data[..] < 0
75+
///
76+
/// y_data containing ns samples [objective, cstr_1, ... cstr_nc] is given as a matrix (ns, nc + 1)
77+
/// c_data containing [fcstr_1, ... fcstr1_nfc] where fcstr_i is the value of function constraints at x_i
6878
pub fn find_best_result_index<F: Float>(
6979
y_data: &ArrayBase<impl Data<Elem = F>, Ix2>,
80+
c_data: &ArrayBase<impl Data<Elem = F>, Ix2>,
7081
cstr_tol: &Array1<F>,
7182
) -> usize {
72-
if y_data.ncols() > 1 {
83+
if y_data.ncols() > 1 || c_data.ncols() > 0 {
84+
// Merge metamodelised constraints and function constraints
85+
let alldata = concatenate![Axis(1), y_data.to_owned(), c_data.to_owned()];
86+
let alltols = concatenate![Axis(0), cstr_tol.to_owned(), Array1::zeros(c_data.ncols())];
87+
7388
// Compute sum of violated constraints
74-
let cstrs = y_data.slice(s![.., 1..]);
89+
let cstrs = &alldata.slice(s![.., 1..]);
7590
let mut c_obj = Array2::zeros((y_data.nrows(), 2));
7691

7792
Zip::from(c_obj.rows_mut())
7893
.and(cstrs.rows())
79-
.and(y_data.slice(s![.., 0]))
94+
.and(alldata.slice(s![.., 0]))
8095
.for_each(|mut c_obj_row, c_row, obj| {
81-
let c_sum = zip(c_row, cstr_tol)
96+
let c_sum = zip(c_row, &alltols)
8297
.filter(|(c, ctol)| *c > ctol)
8398
.fold(F::zero(), |acc, (c, ctol)| acc + (*c - *ctol).abs());
8499
c_obj_row.assign(&array![c_sum, *obj]);
@@ -106,13 +121,13 @@ pub fn find_best_result_index<F: Float>(
106121
let mut index = 0;
107122

108123
// sort regardoing minimal objective
109-
let perm = y_data.sort_axis_by(Axis(0), |i, j| y_data[[i, 0]] < y_data[[j, 0]]);
110-
let y_sort = y_data.to_owned().permute_axis(Axis(0), &perm);
124+
let perm = alldata.sort_axis_by(Axis(0), |i, j| alldata[[i, 0]] < alldata[[j, 0]]);
125+
let y_sort = alldata.to_owned().permute_axis(Axis(0), &perm);
111126

112127
// Take the first one which do not violate constraints
113128
for (i, row) in y_sort.axis_iter(Axis(0)).enumerate() {
114129
let success =
115-
zip(row.slice(s![1..]), cstr_tol).fold(true, |acc, (c, tol)| acc && c < tol);
130+
zip(row.slice(s![1..]), &alltols).fold(true, |acc, (c, tol)| acc && c < tol);
116131

117132
if success {
118133
index = i;
@@ -170,28 +185,29 @@ mod tests {
170185
fn test_find_best_obj() {
171186
// respect constraint (0, 1, 2) and minimize obj (1)
172187
let ydata = array![[1.0, -0.15], [-1.0, -0.01], [2.0, -0.2], [-3.0, 2.0]];
188+
let cdata = array![[], [], [], []];
173189
let cstr_tol = Array1::from_elem(4, 0.1);
174-
assert_abs_diff_eq!(1, find_best_result_index(&ydata, &cstr_tol));
190+
assert_abs_diff_eq!(1, find_best_result_index(&ydata, &cdata, &cstr_tol));
175191

176192
// respect constraint (0, 1, 2) and minimize obj (2)
177193
let ydata = array![[1.0, -0.15], [-1.0, -0.01], [-2.0, -0.2], [-3.0, 2.0]];
178194
let cstr_tol = Array1::from_elem(4, 0.1);
179-
assert_abs_diff_eq!(2, find_best_result_index(&ydata, &cstr_tol));
195+
assert_abs_diff_eq!(2, find_best_result_index(&ydata, &cdata, &cstr_tol));
180196

181197
// all out of tolerance => minimize constraint overshoot sum (0)
182198
let ydata = array![[1.0, 0.15], [-1.0, 0.3], [2.0, 0.2], [-3.0, 2.0]];
183199
let cstr_tol = Array1::from_elem(4, 0.1);
184-
assert_abs_diff_eq!(0, find_best_result_index(&ydata, &cstr_tol));
200+
assert_abs_diff_eq!(0, find_best_result_index(&ydata, &cdata, &cstr_tol));
185201

186202
// all in tolerance => min obj
187203
let ydata = array![[1.0, 0.15], [-1.0, 0.3], [2.0, 0.2], [-3.0, 2.0]];
188204
let cstr_tol = Array1::from_elem(4, 3.0);
189-
assert_abs_diff_eq!(3, find_best_result_index(&ydata, &cstr_tol));
205+
assert_abs_diff_eq!(3, find_best_result_index(&ydata, &cdata, &cstr_tol));
190206

191207
// unconstrained => min obj
192208
let ydata = array![[1.0], [-1.0], [2.0], [-3.0]];
193209
let cstr_tol = Array1::from_elem(4, 0.1);
194-
assert_abs_diff_eq!(3, find_best_result_index(&ydata, &cstr_tol));
210+
assert_abs_diff_eq!(3, find_best_result_index(&ydata, &cdata, &cstr_tol));
195211
}
196212

197213
#[test]
@@ -216,11 +232,12 @@ mod tests {
216232
[-5.50801509642, 1.951629235996e-7, 2.48275059533e-6],
217233
[-5.50801399313, -6.707576982734e-8, 1.03991762046e-6]
218234
];
235+
let c_data = Array2::zeros((y_data.nrows(), 0));
219236
let cstr_tol = Array1::from_vec(vec![1e-6; 2]); // this is the default
220-
let index = find_best_result_index(&y_data, &cstr_tol);
237+
let index = find_best_result_index(&y_data, &c_data, &cstr_tol);
221238
assert_eq!(11, index);
222239
let cstr_tol = Array1::from_vec(vec![2e-6; 2]);
223-
let index = find_best_result_index(&y_data, &cstr_tol);
240+
let index = find_best_result_index(&y_data, &c_data, &cstr_tol);
224241
assert_eq!(17, index);
225242
}
226243

@@ -246,11 +263,12 @@ mod tests {
246263
[-5.50801509642, 1.951629235996e-7, 2.48275059533e-6],
247264
[-5.50801399313, -6.707576982734e-8, 1.03991762046e-6]
248265
];
266+
let c_data = Array2::zeros((y_data.nrows(), 0));
249267
let cstr_tol = Array1::from_vec(vec![2e-6; 2]);
250-
let index = find_best_result_index_from(11, 12, &y_data, &cstr_tol);
268+
let index = find_best_result_index_from(11, 12, &y_data, &c_data, &cstr_tol);
251269
assert_eq!(17, index);
252270
let cstr_tol = Array1::from_vec(vec![2e-6; 2]);
253-
let index = find_best_result_index(&y_data, &cstr_tol);
271+
let index = find_best_result_index(&y_data, &c_data, &cstr_tol);
254272
assert_eq!(17, index);
255273
}
256274
}

crates/moe/src/algorithm.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ impl GpMixtureValidParams<f64> {
103103

104104
let training = if recomb == Recombination::Smooth(None) && self.n_clusters() > 1 {
105105
// Extract 5% of data for validation to find best heaviside factor
106-
// TODO: Use cross-validation ? Performances
106+
// TODO: Better use cross-validation... but performances impact?
107107
let (_, training_data) = extract_part(&data, 5);
108108
training_data
109109
} else {
@@ -176,7 +176,7 @@ impl GpMixtureValidParams<f64> {
176176

177177
if recomb == Recombination::Smooth(None) && self.n_clusters() > 1 {
178178
// Extract 5% of data for validation to find best heaviside factor
179-
// TODO: Use cross-validation ? Performances
179+
// TODO: Better use cross-validation... but performances impact?
180180
let (test, _) = extract_part(&data, 5);
181181
let xtest = test.slice(s![.., ..nx]).to_owned();
182182
let ytest = test.slice(s![.., nx..]).to_owned().remove_axis(Axis(1));

python/src/egor.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ impl Egor {
376376

377377
/// This function gives the best evaluation index given the outputs
378378
/// of the function (objective wrt constraints) under minimization.
379+
/// Caveat: This function does not take into account function constraints values
379380
///
380381
/// # Parameters
381382
/// y_doe (array[ns, 1 + n_cstr]): ns values of objective and constraints
@@ -386,11 +387,14 @@ impl Egor {
386387
#[pyo3(signature = (y_doe))]
387388
fn get_result_index(&self, y_doe: PyReadonlyArray2<f64>) -> usize {
388389
let y_doe = y_doe.as_array();
389-
find_best_result_index(&y_doe, &self.cstr_tol())
390+
// TODO: Make c_doe an optional argument ?
391+
let c_doe = Array2::zeros((y_doe.ncols(), 0));
392+
find_best_result_index(&y_doe, &c_doe, &self.cstr_tol())
390393
}
391394

392395
/// This function gives the best result given inputs and outputs
393396
/// of the function (objective wrt constraints) under minimization.
397+
/// Caveat: This function does not take into account function constraints values
394398
///
395399
/// # Parameters
396400
/// x_doe (array[ns, nx]): ns samples where function has been evaluated
@@ -410,7 +414,9 @@ impl Egor {
410414
) -> OptimResult {
411415
let x_doe = x_doe.as_array();
412416
let y_doe = y_doe.as_array();
413-
let idx = find_best_result_index(&y_doe, &self.cstr_tol());
417+
// TODO: Make c_doe an optional argument ?
418+
let c_doe = Array2::zeros((y_doe.ncols(), 0));
419+
let idx = find_best_result_index(&y_doe, &c_doe, &self.cstr_tol());
414420
let x_opt = x_doe.row(idx).to_pyarray_bound(py).into();
415421
let y_opt = y_doe.row(idx).to_pyarray_bound(py).into();
416422
let x_doe = x_doe.to_pyarray_bound(py).into();

0 commit comments

Comments
 (0)