Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ impl Default for Benchmark {
exercises_per_session: 25,
initial_performance: [0.3, 0.2, 0.25, 0.15, 0.1],
trials_before_stable: 5,
stable_performance: [0.02, 0.03, 0.1, 0.2, 0.65],
stable_performance: [0.03, 0.02, 0.1, 0.2, 0.65],
lapse_rate: 0.1,
},
below_median_profile: StudentProfile {
Expand Down
14 changes: 3 additions & 11 deletions src/data/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,9 @@ impl KeyValueFilter {
value: &str,
filter_type: &FilterType,
) -> bool {
// Check whether the key-value pair is present in the metadata.
let contains_metadata = if metadata.contains_key(key) {
metadata
.get(key)
.unwrap_or(&Vec::new())
.contains(&value.to_string())
} else {
false
};

// Decide whether the filter passes based on its type.
let contains_metadata = metadata
.get(key)
.is_some_and(|values| values.iter().any(|v| v == value));
match filter_type {
FilterType::Include => contains_metadata,
FilterType::Exclude => !contains_metadata,
Expand Down
5 changes: 3 additions & 2 deletions src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ impl DepthFirstScheduler {
.unit_scorer
.get_unit_score(item.unit_id)?
.unwrap_or_default();
let frequency_map = self.data.frequency_map.read();
let candidates = exercises
.into_iter()
.map(|exercise_id| {
Expand All @@ -399,7 +400,7 @@ impl DepthFirstScheduler {
.get_last_seen_days(exercise_id)?
.unwrap_or_default(),
score_velocity: self.unit_scorer.get_exercise_velocity(exercise_id)?,
frequency: self.data.get_exercise_frequency(exercise_id),
frequency: frequency_map.get(&exercise_id).copied().unwrap_or(0),
dead_end: false,
num_dependents: self.data.get_num_dependents(item.unit_id, course_id),
encompasses_weight: 0.0,
Expand Down Expand Up @@ -756,7 +757,7 @@ impl DepthFirstScheduler {
) -> Result<Vec<Candidate>> {
// Initialize the list of candidates.
let max_candidates = self.data.options.batch_size * MAX_CANDIDATE_FACTOR;
let mut all_candidates: Vec<Candidate> = Vec::new();
let mut all_candidates: Vec<Candidate> = Vec::with_capacity(max_candidates);
let mut lessons_in_progress = UstrSet::default();

// The dependency relationships between a course and its lessons are not explicitly encoded
Expand Down
37 changes: 19 additions & 18 deletions src/scheduler/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,17 +322,6 @@ impl SchedulerData {
*frequency += 1;
}

/// Returns the frequency of the given exercise ID.
#[inline]
#[must_use]
pub fn get_exercise_frequency(&self, exercise_id: Ustr) -> usize {
self.frequency_map
.read()
.get(&exercise_id)
.copied()
.unwrap_or(0)
}

/// Returns the unit filter for the saved filter with the given ID. Returns an error if no
/// filter exists with that ID exists.
pub fn get_saved_filter(&self, filter_id: &str) -> Result<Arc<SavedFilter>> {
Expand Down Expand Up @@ -362,21 +351,22 @@ impl SchedulerData {
#[must_use]
pub fn all_valid_exercises_in_lesson(&self, lesson_id: Ustr) -> Vec<Ustr> {
// If the lesson is blacklisted, return no exercises.
if self.blacklisted(lesson_id).unwrap_or(false) {
let blacklist = self.blacklist.read();
if blacklist.blacklisted(lesson_id).unwrap_or(false) {
return vec![];
}

// If the course to which the lesson belongs is blacklisted, return no exercises.
let course_id = self.get_lesson_course(lesson_id).unwrap_or_default();
if self.blacklisted(course_id).unwrap_or(false) {
if blacklist.blacklisted(course_id).unwrap_or(false) {
return vec![];
}

// Get all exercises in the lesson and filter out the blacklisted ones.
let exercises = self.get_lesson_exercises(lesson_id);
exercises
.into_iter()
.filter(|exercise_id| !self.blacklisted(*exercise_id).unwrap_or(false))
.filter(|exercise_id| !blacklist.blacklisted(*exercise_id).unwrap_or(false))
.collect()
}

Expand All @@ -385,11 +375,12 @@ impl SchedulerData {
pub fn all_valid_exercises(&self, unit_id: Ustr) -> Vec<Ustr> {
// First, get the type of the unit. Then get the exercises based on the unit type.
let unit_type = self.get_unit_type(unit_id);
let blacklist = self.blacklist.read();
match unit_type {
None => vec![],
Some(UnitType::Exercise) => {
// Return the exercise if it's not blacklisted.
if self.blacklisted(unit_id).unwrap_or(false) {
if blacklist.blacklisted(unit_id).unwrap_or(false) {
vec![]
} else {
vec![unit_id]
Expand All @@ -398,7 +389,7 @@ impl SchedulerData {
Some(UnitType::Lesson) => self.all_valid_exercises_in_lesson(unit_id),
Some(UnitType::Course) => {
// If the course is blacklisted, return no exercises.
if self.blacklisted(unit_id).unwrap_or(false) {
if blacklist.blacklisted(unit_id).unwrap_or(false) {
return vec![];
}

Expand Down Expand Up @@ -573,13 +564,23 @@ mod test {
let library = init_test_simulation(temp_dir.path(), &TEST_LIBRARY)?;
let scheduler_data = library.get_scheduler_data();

let frequency_map = scheduler_data.frequency_map.read();
assert_eq!(
scheduler_data.get_exercise_frequency(Ustr::from("0::0::0")),
frequency_map
.get(&Ustr::from("0::0::0"))
.copied()
.unwrap_or(0),
0
);
drop(frequency_map);

scheduler_data.increment_exercise_frequency(Ustr::from("0::0::0"));
let frequency_map = scheduler_data.frequency_map.read();
assert_eq!(
scheduler_data.get_exercise_frequency(Ustr::from("0::0::0")),
frequency_map
.get(&Ustr::from("0::0::0"))
.copied()
.unwrap_or(0),
1
);
Ok(())
Expand Down
4 changes: 2 additions & 2 deletions src/scheduler/review_knocker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ impl ReviewKnocker {
let mut unit_weight_map = UstrMap::default();
let unit_set = initial_batch
.iter()
.flat_map(|candidate| vec![candidate.lesson_id, candidate.course_id])
.flat_map(|candidate| [candidate.lesson_id, candidate.course_id])
.collect::<UstrSet>();

// For each, find all their encompassed lessons and courses.
for unit_id in unit_set {
// Initialize the stack and set of visited units.
let mut stack: Vec<WeightQueueItem> = Vec::new();
let mut stack: Vec<WeightQueueItem> = Vec::with_capacity(16);
stack.push(WeightQueueItem {
unit_id,
reward: 1.0,
Expand Down
4 changes: 2 additions & 2 deletions src/scheduler/reward_propagator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl RewardPropagator {
let initial_reward = Self::initial_reward(score);
let next_lessons = Self::get_next_units(unit_graph, lesson_id, initial_reward);
let next_courses = Self::get_next_units(unit_graph, course_id, initial_reward);
let mut stack: Vec<UnitReward> = Vec::new();
let mut stack: Vec<UnitReward> = Vec::with_capacity(16);
next_lessons
.iter()
.chain(next_courses.iter())
Expand Down Expand Up @@ -156,7 +156,7 @@ impl RewardPropagator {
});
}
}
results.values().cloned().collect()
results.into_values().collect()
}

/// Propagates the given score through the graph.
Expand Down
11 changes: 7 additions & 4 deletions src/scheduler/unit_scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ impl UnitScorer {

// Invalidate the caches depending on the type of the unit.
let graph = self.data.unit_graph.read();
let mut exercise_cache = self.exercise_cache.borrow_mut();
match graph.get_unit_type(unit_id) {
// If the unit is an exercise, invalidate the cached score of its lesson and course. If
// the unit is a lesson, invalidate the cached score of its course.
Expand All @@ -129,19 +130,21 @@ impl UnitScorer {
}
if let Some(exercise_ids) = graph.get_lesson_exercises(unit_id) {
for exercise_id in exercise_ids.iter() {
self.exercise_cache.borrow_mut().remove(exercise_id);
exercise_cache.remove(exercise_id);
}
}
}
// For courses, invalidate the scores of all lessons and exercises in the course.
Some(UnitType::Course) => {
if let Some(lesson_ids) = graph.get_course_lessons(unit_id) {
let mut lesson_cache = self.lesson_cache.borrow_mut();
let mut lesson_trials_cache = self.lesson_trials_cache.borrow_mut();
for lesson_id in lesson_ids.iter() {
self.lesson_cache.borrow_mut().remove(lesson_id);
self.lesson_trials_cache.borrow_mut().remove(lesson_id);
lesson_cache.remove(lesson_id);
lesson_trials_cache.remove(lesson_id);
if let Some(exercise_ids) = graph.get_lesson_exercises(*lesson_id) {
for exercise_id in exercise_ids.iter() {
self.exercise_cache.borrow_mut().remove(exercise_id);
exercise_cache.remove(exercise_id);
}
}
}
Expand Down
Loading