From 6bba1cca7c0331ba6597e2e80154b4e1be225ce9 Mon Sep 17 00:00:00 2001 From: atomflunder <80397293+atomflunder@users.noreply.github.com> Date: Thu, 10 Oct 2024 20:21:08 +0200 Subject: [PATCH] Simplify trueskill calculations Actually improves the accuracy slightly, compared to the other trueskill libraries like ts-trueskill. And by slightly it means ~0.0000000000007 difference instead of ~0.0000000000009 --- src/trueskill/factor_graph.rs | 16 +++--- src/trueskill/mod.rs | 92 +++++++++-------------------------- 2 files changed, 30 insertions(+), 78 deletions(-) diff --git a/src/trueskill/factor_graph.rs b/src/trueskill/factor_graph.rs index 5408801..84d820a 100644 --- a/src/trueskill/factor_graph.rs +++ b/src/trueskill/factor_graph.rs @@ -230,8 +230,8 @@ impl SumFactor { pub struct TruncateFactor { id: usize, variable: Rc>, - v_func: Box f64>, - w_func: Box f64>, + v_func: Box f64>, + w_func: Box f64>, draw_margin: f64, } @@ -239,8 +239,8 @@ impl TruncateFactor { pub fn new( id: usize, variable: Rc>, - v_func: Box f64>, - w_func: Box f64>, + v_func: Box f64>, + w_func: Box f64>, draw_margin: f64, ) -> Self { variable.borrow_mut().messages.entry(id).or_default(); @@ -260,10 +260,10 @@ impl TruncateFactor { variable.gaussian / variable.messages[&self.id] }; let pi_sqrt = div.pi.sqrt(); - let arg_1 = div.tau / pi_sqrt; - let arg_2 = self.draw_margin * pi_sqrt; - let v = (self.v_func)(arg_1, arg_2); - let w = (self.w_func)(arg_1, arg_2); + let arg_1 = div.tau; + let arg_2 = self.draw_margin * div.pi; + let v = (self.v_func)(arg_1, arg_2, pi_sqrt); + let w = (self.w_func)(arg_1, arg_2, pi_sqrt); let denom = 1.0 - w; let pi = div.pi / denom; diff --git a/src/trueskill/mod.rs b/src/trueskill/mod.rs index a76b861..8205a3d 100644 --- a/src/trueskill/mod.rs +++ b/src/trueskill/mod.rs @@ -1809,54 +1809,6 @@ fn build_trunc_layer( beta: f64, starting_id: usize, ) -> Vec { - fn v_w(diff: f64, draw_margin: f64) -> f64 { - let x = diff - draw_margin; - let denom = cdf(x, 0.0, 1.0); - - if denom == 0.0 { - -x - } else { - pdf(x, 0.0, 1.0) / denom - } - } - - fn v_d(diff: f64, draw_margin: f64) -> f64 { - let abs_diff = diff.abs(); - let a = draw_margin - abs_diff; - let b = -draw_margin - abs_diff; - let denom = cdf(a, 0.0, 1.0) - cdf(b, 0.0, 1.0); - let numer = pdf(b, 0.0, 1.0) - pdf(a, 0.0, 1.0); - - let lhs = if denom == 0.0 { a } else { numer / denom }; - let rhs = if diff < 0.0 { -1.0 } else { 1.0 }; - - lhs * rhs - } - - fn w_w(diff: f64, draw_margin: f64) -> f64 { - let x = diff - draw_margin; - let v = v_w(diff, draw_margin); - let w = v * (v + x); - if 0.0 < w && w < 1.0 { - return w; - } - - panic!("floating point error"); - } - - #[allow(clippy::suboptimal_flops)] - fn w_d(diff: f64, draw_margin: f64) -> f64 { - let abs_diff = diff.abs(); - let a = draw_margin - abs_diff; - let b = -draw_margin - abs_diff; - let denom = cdf(a, 0.0, 1.0) - cdf(b, 0.0, 1.0); - - assert!(!(denom == 0.0), "floating point error"); - - let v = v_d(abs_diff, draw_margin); - v.mul_add(v, (a * pdf(a, 0.0, 1.0) - b * pdf(b, 0.0, 1.0)) / denom) - } - let mut v = Vec::with_capacity(team_diff_vars.len()); let mut i = starting_id; for (x, team_diff_var) in team_diff_vars.iter().enumerate() { @@ -1865,14 +1817,14 @@ fn build_trunc_layer( .map(|v| v.0.len() as f64) .sum(); let draw_margin = draw_margin(draw_probability, beta, size); - let v_func: Box f64>; - let w_func: Box f64>; + let v_func: Box f64>; + let w_func: Box f64>; if sorted_teams_and_ranks[x].1 == sorted_teams_and_ranks[x + 1].1 { - v_func = Box::new(v_d); - w_func = Box::new(w_d); + v_func = Box::new(v_draw); + w_func = Box::new(w_draw); } else { - v_func = Box::new(v_w); - w_func = Box::new(w_w); + v_func = Box::new(v_non_draw); + w_func = Box::new(w_non_draw); }; v.push(TruncateFactor::new( @@ -2668,26 +2620,26 @@ mod tests { let results = trueskill_multi_team(&teams_and_ranks, &TrueSkillConfig::new()); assert!((results[0][0].rating - 40.876_849_177_315_655).abs() < f64::EPSILON); - assert!((results[0][1].rating - 45.493_394_092_398_44).abs() < f64::EPSILON); + assert!((results[0][1].rating - 45.493_394_092_398_45).abs() < f64::EPSILON); - assert!((results[1][0].rating - 19.608_650_920_845_236).abs() < f64::EPSILON); + assert!((results[1][0].rating - 19.608_650_920_845_23).abs() < f64::EPSILON); assert!((results[1][1].rating - 18.712_463_514_890_54).abs() < f64::EPSILON); assert!((results[1][2].rating - 29.353_112_227_810_637).abs() < f64::EPSILON); assert!((results[1][3].rating - 9.872_175_198_037_164).abs() < f64::EPSILON); - assert!((results[2][0].rating - 48.829_832_201_455_32).abs() < f64::EPSILON); - assert!((results[2][1].rating - 29.812_500_188_903_005).abs() < f64::EPSILON); + assert!((results[2][0].rating - 48.829_832_201_455_31).abs() < f64::EPSILON); + assert!((results[2][1].rating - 29.812_500_188_902_998).abs() < f64::EPSILON); - assert!((results[0][0].uncertainty - 3.839_527_589_355_37).abs() < f64::EPSILON); + assert!((results[0][0].uncertainty - 3.839_527_589_355_369_8).abs() < f64::EPSILON); assert!((results[0][1].uncertainty - 2.933_671_613_522_051).abs() < f64::EPSILON); - assert!((results[1][0].uncertainty - 6.396_044_310_523_897).abs() < f64::EPSILON); + assert!((results[1][0].uncertainty - 6.396_044_310_523_896).abs() < f64::EPSILON); assert!((results[1][1].uncertainty - 5.624_556_429_622_889).abs() < f64::EPSILON); - assert!((results[1][2].uncertainty - 7.673_456_361_986_594).abs() < f64::EPSILON); + assert!((results[1][2].uncertainty - 7.673_456_361_986_593).abs() < f64::EPSILON); assert!((results[1][3].uncertainty - 3.891_408_425_994_520_3).abs() < f64::EPSILON); - assert!((results[2][0].uncertainty - 4.590_018_525_151_38).abs() < f64::EPSILON); - assert!((results[2][1].uncertainty - 1.976_314_792_712_798_2).abs() < f64::EPSILON); + assert!((results[2][0].uncertainty - 4.590_018_525_151_379).abs() < f64::EPSILON); + assert!((results[2][1].uncertainty - 1.976_314_792_712_798).abs() < f64::EPSILON); } #[test] @@ -2714,7 +2666,7 @@ mod tests { let results = trueskill_multi_team(teams_and_ranks, &TrueSkillConfig::new()); - assert!((results[0][0].rating - 41.720_925_460_665).abs() < f64::EPSILON); + assert!((results[0][0].rating - 41.720_925_460_665_01).abs() < f64::EPSILON); assert!((results[1][0].rating - 20.997_268_045_415_94).abs() < f64::EPSILON); assert!((results[2][0].rating - 41.771_076_420_914_83).abs() < f64::EPSILON); @@ -2753,15 +2705,15 @@ mod tests { let results = trueskill_multi_team(teams_and_ranks, &TrueSkillConfig::new()); - assert!((results[0][0].rating - 46.844_398_641_974_195).abs() < f64::EPSILON); + assert!((results[0][0].rating - 46.844_398_641_974_97).abs() < f64::EPSILON); assert!((results[1][0].rating - -21.0).abs() < f64::EPSILON); - assert!((results[2][0].rating - 121.973_594_228_967_41).abs() < f64::EPSILON); - assert!((results[3][0].rating - 3.577_783_039_440_541_7).abs() < f64::EPSILON); + assert!((results[2][0].rating - 121.973_594_228_967_43).abs() < f64::EPSILON); + assert!((results[3][0].rating - 3.577_783_039_440_43).abs() < f64::EPSILON); - assert!((results[0][0].uncertainty - 4.453_979_220_473_841).abs() < f64::EPSILON); + assert!((results[0][0].uncertainty - 4.453_979_220_477_661).abs() < f64::EPSILON); assert!((results[1][0].uncertainty - 1.871_855_882_391_709_3).abs() < f64::EPSILON); - assert!((results[2][0].uncertainty - 0.083_922_196_135_183_51).abs() < f64::EPSILON); - assert!((results[3][0].uncertainty - 1.197_926_990_096_083_4).abs() < f64::EPSILON); + assert!((results[2][0].uncertainty - 0.083_922_196_135_183_55).abs() < f64::EPSILON); + assert!((results[3][0].uncertainty - 1.197_926_990_096_302_3).abs() < f64::EPSILON); } #[test]