Skip to content

Commit

Permalink
speed up Any by ~30% and Weights by ~8% by using lighter random library
Browse files Browse the repository at this point in the history
  • Loading branch information
Fogapod committed Feb 24, 2024
1 parent 92e5753 commit 78c4128
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ include = [

[dependencies]
# random replacements
rand = "0.8"
fastrand = "2.0"

# pattern definition
# excluded a couple unicode features
Expand Down
2 changes: 1 addition & 1 deletion src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ mod tests {
0: {"Original": ()},
0: {"Original": ()},
0: {"Original": ()},
},
}
"#,
)
.err()
Expand Down
94 changes: 73 additions & 21 deletions src/tag_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ use crate::deserialize::SortedMap;

use std::{borrow::Cow, error::Error, fmt::Display};

use rand::seq::SliceRandom;

use crate::{tag::Tag, utils::LiteralString, Match};

/// Same as [`Literal`] with `"$0"` argument: returns entire match.
Expand Down Expand Up @@ -109,15 +107,17 @@ impl Any {
#[cfg_attr(feature = "deserialize", typetag::deserialize)]
impl Tag for Any {
/// Picks one of the inner tags at random and returns whatever it generates for `m`.
///
/// Panics if there are no inner tags to choose from.
fn generate<'a>(&self, m: &Match<'a>) -> Cow<'a, str> {
let mut rng = rand::thread_rng();
// random index into the inner tags
let i = fastrand::usize(..self.0.len());

self.0.choose(&mut rng).expect("empty Any").generate(m)
self.0[i].generate(m)
}
}

/// Reasons why creating [`Weights`] might fail
#[derive(Debug)]
pub enum WeightsError {
/// No items were given: there must be at least one element to choose from
ZeroItems,
/// All weights add up to zero: the sum of all weights must be positive
NonPositiveTotalWeights,
}
Expand All @@ -128,6 +128,7 @@ impl Display for WeightsError {
f,
"{}",
match self {
Self::ZeroItems => "expected at least one element to choose from",
Self::NonPositiveTotalWeights => "weights must add up to a positive number",
}
)
Expand All @@ -141,32 +142,58 @@ impl Display for WeightsError {
derive(serde::Deserialize),
serde(try_from = "SortedMap<u64, Box<dyn Tag>, false>")
)]
pub struct Weights(Vec<(u64, Box<dyn Tag>)>);
pub struct Weights {
/// Tags to choose from; `choices[i]` is paired with `cum_weights[i]`
choices: Vec<Box<dyn Tag>>,
/// Running (cumulative) sums of the weights, in the same order as `choices`
cum_weights: Vec<u64>,
}

impl Weights {
/// Creates [`Weights`] from `(weight, tag)` pairs.
///
/// # Errors
///
/// Returns [`WeightsError`] when the weights cannot be used for a weighted choice,
/// e.g. [`WeightsError::NonPositiveTotalWeights`] when they add up to zero.
pub fn new(items: Vec<(u64, Box<dyn Tag>)>) -> Result<Self, WeightsError> {
if items.iter().fold(0, |sum, (w, _)| sum + w) == 0 {
return Err(WeightsError::NonPositiveTotalWeights);
}
// split the pairs into two parallel vectors: raw weights and the tags they belong to
let (weights, choices) = items.into_iter().unzip();

Ok(Self(items))
// turn raw weights into running sums; this is also where the input gets validated
let cum_weights = Self::cum_weights(weights)?;

Ok(Self {
choices,
cum_weights,
})
}

/// Same as [`Weights::new`], but the created value is wrapped in a [`Box`].
///
/// # Errors
///
/// Fails with the same [`WeightsError`] that [`Weights::new`] would return.
pub fn new_boxed(items: Vec<(u64, Box<dyn Tag>)>) -> Result<Box<Self>, WeightsError> {
    Self::new(items).map(Box::new)
}

/// Converts raw weights into cumulative weights, reusing the input allocation.
///
/// Element `i` of the result is the sum of `weights[0..=i]`, e.g. `[1, 2, 3]` becomes
/// `[1, 3, 6]`. Sums are plain `u64` additions, so a total above `u64::MAX` overflows.
///
/// # Errors
///
/// - [`WeightsError::ZeroItems`] if `weights` is empty
/// - [`WeightsError::NonPositiveTotalWeights`] if every weight is zero
fn cum_weights(mut weights: Vec<u64>) -> Result<Vec<u64>, WeightsError> {
    let mut running_total = 0;
    for weight in weights.iter_mut() {
        running_total += *weight;
        *weight = running_total;
    }

    // the last element holds the grand total; its absence means there was no input at all
    match weights.last().copied() {
        None => Err(WeightsError::ZeroItems),
        Some(0) => Err(WeightsError::NonPositiveTotalWeights),
        Some(_) => Ok(weights),
    }
}

/// Picks the index of a choice at random, with probability proportional to its weight.
///
/// A point is drawn uniformly from `0..total`, where `total` is the last (largest)
/// cumulative weight. Choice `i` owns the half-open interval
/// `cum_weights[i - 1]..cum_weights[i]` (starting at 0 for the first choice), so the
/// selected index is the first one whose cumulative weight is strictly greater than the
/// point. Zero-weight choices own an empty interval and are never selected.
fn random_choice(&self) -> usize {
    // `cum_weights` is validated on construction: it is never empty and its last
    // element (the total weight) is positive, so the range below is never empty
    let total = self.cum_weights[self.cum_weights.len() - 1];
    let random_point = fastrand::u64(0..total);

    // Number of cumulative weights that are <= point. Unlike `binary_search`, this treats
    // an exact hit on `cum_weights[i]` as belonging to the next choice and is well defined
    // when zero weights produce duplicate entries. Always < len because point < total.
    self.cum_weights
        .partition_point(|&cum| cum <= random_point)
}
}

#[cfg_attr(feature = "deserialize", typetag::deserialize)]
impl Tag for Weights {
/// Makes a weighted random pick among the choices and returns whatever the picked tag
/// generates for `m`.
fn generate<'a>(&self, m: &Match<'a>) -> Cow<'a, str> {
let mut rng = rand::thread_rng();

self.0
.choose_weighted(&mut rng, |item| item.0)
.expect("empty Weights")
.1
.generate(m)
self.choices[self.random_choice()].generate(m)
}
}

Expand Down Expand Up @@ -250,15 +277,13 @@ impl Tag for Concat {

#[cfg(test)]
mod tests {
use super::*;
use crate::Match;

use std::borrow::Cow;

use regex_automata::meta::Regex;

use crate::{
tag_impls::{Any, Concat, Literal, Lower, Original, Tag, Upper, Weights},
Match,
};

fn make_match(pattern: &str) -> Match {
let re = Regex::new(".+").unwrap();
let mut caps = re.create_captures();
Expand Down Expand Up @@ -314,6 +339,24 @@ mod tests {
assert!(["bar", "baz"].contains(&selected.as_str()));
}

#[test]
fn weights_cum_weights_errors() {
    // pin the exact variant: a bare `is_err()` would not notice swapped error paths
    assert!(matches!(
        Weights::cum_weights(Vec::new()),
        Err(WeightsError::ZeroItems)
    ));
    assert!(matches!(
        Weights::cum_weights(vec![0, 0, 0, 0, 0]),
        Err(WeightsError::NonPositiveTotalWeights)
    ));
}

#[test]
fn weights_cum_weights() {
    // (raw weights, expected running sums)
    let cases = vec![
        (vec![1, 2, 3, 4, 5], vec![1, 3, 6, 10, 15]),
        (vec![5, 4, 3, 2, 1], vec![5, 9, 12, 14, 15]),
    ];

    for (weights, expected) in cases {
        assert_eq!(Weights::cum_weights(weights).unwrap(), expected);
    }
}

#[test]
fn weights() {
let tag = Weights::new(vec![
Expand All @@ -328,6 +371,15 @@ mod tests {
assert!(vec!["bar", "baz"].contains(&selected.as_str()));
}

#[test]
fn weights_single() {
    // with only one choice available the outcome does not depend on randomness
    let tag = Weights::new(vec![(50, Literal::new_boxed("test"))]).unwrap();

    assert_eq!(apply(&tag, "test").into_owned(), "test");
}

#[test]
fn upper() {
// double wrapped for coverage
Expand Down

0 comments on commit 78c4128

Please sign in to comment.