Skip to content

Commit 22ba1ca

Browse files
Update default min_prob (tensorzero#4337)
1 parent dde9e87 commit 22ba1ca

File tree

2 files changed

+53
-9
lines changed

2 files changed

+53
-9
lines changed

tensorzero-core/src/experimentation/track_and_stop/estimate_optimal_probabilities.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ pub fn estimate_optimal_probabilities(
133133
// TODO: for boolean metrics, set default epsilon to e.g. 0.01. For float metrics, anchor to reward distributions once available.
134134
let epsilon: f64 = epsilon.unwrap_or(0.0);
135135
let variance_floor: f64 = variance_floor.unwrap_or(1e-12);
136-
let min_prob: f64 = min_prob.unwrap_or(1e-6);
136+
// Default min_prob is 0.0; any value below 1e-6 (the default or a user-supplied one) is raised to 1e-6 for numerical stability in the optimization
137+
let min_prob: f64 = min_prob.unwrap_or(0.0).max(1e-6);
137138
let reg0: f64 = reg0.unwrap_or(0.01);
138139

139140
let pull_counts: Vec<u64> = feedback.iter().map(|x| x.count).collect();

tensorzero-core/src/experimentation/track_and_stop/mod.rs

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,11 @@ fn default_update_period_s() -> u64 {
282282
300
283283
}
284284

285+
/// Default minimum probability is 0.0, but it will be treated as 1e-6 internally
286+
/// in the optimization algorithm for numerical stability.
285287
#[expect(clippy::unnecessary_wraps)]
286288
fn default_min_prob() -> Option<f64> {
287-
Some(1e-6)
289+
Some(0.0)
288290
}
289291

290292
impl UninitializedTrackAndStopConfig {
@@ -355,23 +357,26 @@ impl UninitializedTrackAndStopConfig {
355357
}
356358

357359
// Validate min_prob if provided
360+
// Note: min_prob can be 0.0 (or any value below 1e-6); such values are raised to 1e-6 internally for numerical stability
358361
if let Some(min_prob) = self.min_prob {
359362
// Check non-negative
360363
if min_prob < 0.0 {
361364
return Err(Error::new(ErrorDetails::Config {
362-
message: format!("Track-and-Stop min_prob must be >= 0, got {min_prob}"),
365+
message: format!("Track-and-Stop `min_prob` must be >= 0.0, got {min_prob}"),
363366
}));
364367
}
365368

366369
// Check finite
367370
if !min_prob.is_finite() {
368371
return Err(Error::new(ErrorDetails::Config {
369-
message: format!("Track-and-Stop min_prob must be finite, got {min_prob}"),
372+
message: format!("Track-and-Stop `min_prob` must be finite, got {min_prob}"),
370373
}));
371374
}
372375

373376
// Check that min_prob * num_candidate_variants <= 1.0
374377
// Only candidate variants get probability mass, not fallback variants
378+
// Note: This check uses the configured min_prob value (which can be 0.0).
379+
// The actual optimization uses max(min_prob, 1e-6) for numerical stability.
375380
let num_candidate_variants = self.candidate_variants.len();
376381
let min_total_prob = min_prob * (num_candidate_variants as f64);
377382
if min_total_prob > 1.0 + 1e-9 {
@@ -2184,8 +2189,8 @@ mod tests {
21842189

21852190
#[test]
21862191
fn test_min_prob_none_uses_default() {
2187-
// Test that when min_prob is None, the default value from
2188-
// estimate_optimal_probabilities (1e-6) is used
2192+
// Test that when min_prob is None, the default config value (0.0) is used,
2193+
// but the optimization algorithm applies a floor of 1e-6 for numerical stability
21892194
let candidates = vec!["A".to_string(), "B".to_string()];
21902195
let performances = vec![
21912196
create_feedback("A", 20, 0.5, 0.1),
@@ -2198,7 +2203,7 @@ mod tests {
21982203
10,
21992204
0.05,
22002205
0.0,
2201-
None, // min_prob is None, should use default
2206+
None, // min_prob is None, defaults to 0.0 in config, but 1e-6 is applied in optimization
22022207
MetricConfigOptimize::Max,
22032208
)
22042209
.unwrap();
@@ -2207,11 +2212,49 @@ mod tests {
22072212
TrackAndStopState::BanditsOnly {
22082213
sampling_probabilities,
22092214
} => {
2210-
// All probabilities should be >= default min_prob (1e-6)
2215+
// All probabilities should be >= 1e-6 (the floor applied in optimization)
22112216
for (variant_name, &prob) in &sampling_probabilities {
22122217
assert!(
22132218
prob >= 1e-6 - 1e-9,
2214-
"Variant {variant_name} has probability {prob} which is less than default min_prob"
2219+
"Variant {variant_name} has probability {prob} which is less than the optimization floor (1e-6)"
2220+
);
2221+
}
2222+
}
2223+
_ => panic!("Expected BanditsOnly state, got {state:?}"),
2224+
}
2225+
}
2226+
2227+
#[test]
2228+
fn test_min_prob_zero_accepted_and_uses_floor() {
2229+
// Test that min_prob=0.0 is accepted in config but the optimization
2230+
// algorithm applies a floor of 1e-6 for numerical stability
2231+
let candidates = vec!["A".to_string(), "B".to_string()];
2232+
let performances = vec![
2233+
create_feedback("A", 20, 0.5, 0.1),
2234+
create_feedback("B", 20, 0.6, 0.2),
2235+
];
2236+
2237+
let state = TrackAndStopState::new(
2238+
&candidates,
2239+
performances,
2240+
10,
2241+
0.05,
2242+
0.0,
2243+
Some(0.0), // min_prob is explicitly set to 0.0
2244+
MetricConfigOptimize::Max,
2245+
)
2246+
.unwrap();
2247+
2248+
match state {
2249+
TrackAndStopState::BanditsOnly {
2250+
sampling_probabilities,
2251+
} => {
2252+
// All probabilities should be >= 1e-6 (the floor applied in optimization)
2253+
// even though min_prob was set to 0.0
2254+
for (variant_name, &prob) in &sampling_probabilities {
2255+
assert!(
2256+
prob >= 1e-6 - 1e-9,
2257+
"Variant {variant_name} has probability {prob} which is less than the optimization floor (1e-6)"
22152258
);
22162259
}
22172260
}

0 commit comments

Comments
 (0)