Trial using threading to run events simultaneously #1327

Draft
wants to merge 8 commits into master
24 changes: 18 additions & 6 deletions src/scripts/profiling/scale_run.py
@@ -15,6 +15,7 @@
from tlo import Date, Simulation, logging
from tlo.analysis.utils import parse_log_file as parse_log_file_fn
from tlo.methods.fullmodel import fullmodel
from tlo.threaded_simulation import ThreadedSimulation

_TLO_ROOT: Path = Path(__file__).parents[3].resolve()
_TLO_OUTPUT_DIR: Path = (_TLO_ROOT / "outputs").resolve()
@@ -55,6 +56,7 @@ def scale_run(
ignore_warnings: bool = False,
log_final_population_checksum: bool = True,
profiler: Optional["Profiler"] = None,
n_threads: int = 0,
) -> Simulation:
if ignore_warnings:
warnings.filterwarnings("ignore")
@@ -74,12 +76,16 @@ def scale_run(
"suppress_stdout": disable_log_output_to_stdout,
}

sim = Simulation(
start_date=start_date,
seed=seed,
log_config=log_config,
show_progress_bar=show_progress_bar,
)
sim_args = {
"start_date": start_date,
"seed": seed,
"log_config": log_config,
"show_progress_bar": show_progress_bar,
}
if n_threads:
sim = ThreadedSimulation(n_threads=n_threads, **sim_args)
else:
sim = Simulation(**sim_args)

# Register the appropriate modules with the arguments passed through
sim.register(
@@ -269,6 +275,12 @@ def scale_run(
),
action="store_true",
)
parser.add_argument(
"--n-threads",
help="Run a threaded simulation using the given number of threaded workers",
type=int,
default=0,
)
args = parser.parse_args()
args_dict = vars(args)
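
For context, here is a minimal sketch of the simulation-class selection that the new flag drives, mirroring the sim_args block in the diff above. The start date, seed and log_config shown are illustrative stand-ins, not values taken from this PR.

from tlo import Date, Simulation
from tlo.threaded_simulation import ThreadedSimulation

# Illustrative values only; in scale_run these come from the function arguments.
sim_args = {
    "start_date": Date(2010, 1, 1),
    "seed": 0,
    "log_config": None,  # scale_run builds its own log_config dict
    "show_progress_bar": True,
}

n_threads = 2  # e.g. the value parsed from the new --n-threads argument
if n_threads:
    sim = ThreadedSimulation(n_threads=n_threads, **sim_args)
else:
    sim = Simulation(**sim_args)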

118 changes: 81 additions & 37 deletions src/tlo/simulation.py
@@ -1,12 +1,13 @@
"""The main simulation controller."""
from __future__ import annotations

import datetime
import heapq
import itertools
import time
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Optional, Union
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

import numpy as np

@@ -15,11 +16,14 @@
from tlo.events import Event, IndividualScopeEventMixin
from tlo.progressbar import ProgressBar

if TYPE_CHECKING:
from tlo import Module

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Simulation:
class _BaseSimulation:
"""The main control centre for a simulation.
This class contains the core simulation logic and event queue, and holds
@@ -41,8 +45,15 @@ class Simulation:
The simulation-level random number generator.
Note that individual modules also have their own random number generator
with independent state.
The `step_through_events` method, which controls how the simulation's
events are fired, is implemented by the `Simulation` and
`ThreadedSimulation` classes.
"""

__name__: str = "_BaseSimulation"
modules: OrderedDict[str, Module]

def __init__(self, *, start_date: Date, seed: int = None, log_config: dict = None,
show_progress_bar=False):
"""Create a new simulation.
@@ -63,6 +74,7 @@ def __init__(self, *, start_date: Date, seed: int = None, log_config: dict = Non
self.population: Optional[Population] = None

self.show_progress_bar = show_progress_bar
self.progress_bar = None

# logging
if log_config is None:
@@ -205,37 +217,18 @@ def simulate(self, *, end_date):
for module in self.modules.values():
module.initialise_simulation(self)

progress_bar = None
if self.show_progress_bar:
num_simulated_days = (end_date - self.start_date).days
progress_bar = ProgressBar(
num_simulated_days = (self.end_date - self.start_date).days
self.progress_bar = ProgressBar(
num_simulated_days, "Simulation progress", unit="day")
progress_bar.start()

while self.event_queue:
event, date = self.event_queue.next_event()
self.progress_bar.start()

if self.show_progress_bar:
simulation_day = (date - self.start_date).days
stats_dict = {
"date": str(date.date()),
"dataframe size": str(len(self.population.props)),
"queued events": str(len(self.event_queue)),
}
if "HealthSystem" in self.modules:
stats_dict["queued HSI events"] = str(
len(self.modules["HealthSystem"].HSI_EVENT_QUEUE)
)
progress_bar.update(simulation_day, stats_dict=stats_dict)

if date >= end_date:
self.date = end_date
break
self.fire_single_event(event, date)
# Run the simulation by firing events in the queue
self.step_through_events()

# The simulation has ended.
if self.show_progress_bar:
progress_bar.stop()
self.progress_bar.stop()

for module in self.modules.values():
module.on_simulation_end()
@@ -253,6 +246,17 @@ def simulate(self, *, end_date):
finally:
self.output_file.release()

def step_through_events(self) -> None:
"""
Advance the simulation by executing the scheduled events in the queue.
This method must be overridden by inheriting classes.
"""
raise NotImplementedError(
f"{self.__name__} is not intended to be simulated, "
"use either Simulation or ThreadedSimulation to run a simulation."
)

def schedule_event(self, event, date):
"""Schedule an event to happen on the given future date.
@@ -269,15 +273,6 @@ def schedule_event(self, event, date):

self.event_queue.schedule(event=event, date=date)

def fire_single_event(self, event, date):
"""Fires the event once for the given date
:param event: :py:class:`Event` to fire
:param date: the date of the event
"""
self.date = date
event.run()

def do_birth(self, mother_id):
"""Create a new child person.
@@ -308,6 +303,23 @@ def find_events_for_person(self, person_id: int):

return person_events

def update_progress_bar(self, new_date: Date):
"""
Update the simulation's progress bar, if one is in use.
"""
if self.show_progress_bar:
simulation_day = (new_date - self.start_date).days
stats_dict = {
"date": str(new_date.date()),
"dataframe size": str(len(self.population.props)),
"queued events": str(len(self.event_queue)),
}
if "HealthSystem" in self.modules:
stats_dict["queued HSI events"] = str(
len(self.modules["HealthSystem"].HSI_EVENT_QUEUE)
)
self.progress_bar.update(simulation_day, stats_dict=stats_dict)


class EventQueue:
"""A simple priority queue for events.
@@ -329,7 +341,7 @@ def schedule(self, event, date):
entry = (date, event.priority, next(self.counter), event)
heapq.heappush(self.queue, entry)

def next_event(self):
def next_event(self) -> Tuple[Event, Date]:
"""Get the earliest event in the queue.
:returns: an (event, date) pair
@@ -340,3 +352,35 @@ def next_event(self):
def __len__(self):
""":return: the length of the queue"""
return len(self.queue)


class Simulation(_BaseSimulation):
"""
Default simulation type, which runs a serial simulation: events in the
event_queue are executed one after the other, in the order they appear
in the queue.
See `_BaseSimulation` for more details.
"""

def step_through_events(self) -> None:
"""Serial simulation: events are executed in the
order they occur in the queue."""
while self.event_queue:
event, date = self.event_queue.next_event()

self.update_progress_bar(date)

if date >= self.end_date:
self.date = self.end_date
break
self.fire_single_event(event, date)

def fire_single_event(self, event, date):
"""Fires the event once for the given date
:param event: :py:class:`Event` to fire
:param date: the date of the event
"""
self.date = date
event.run()
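
As an aside on the ordering that the EventQueue.schedule entries above rely on (tuples sort by date first, then priority, then insertion order), here is a small self-contained sketch of the same heapq pattern; plain strings and ints stand in for real Date and Event objects.

import heapq
import itertools

counter = itertools.count()
queue = []

# Entries mirror EventQueue.schedule: (date, priority, insertion counter, event).
heapq.heappush(queue, ("2010-01-03", 1, next(counter), "later event"))
heapq.heappush(queue, ("2010-01-01", 2, next(counter), "priority-2 event"))
heapq.heappush(queue, ("2010-01-01", 1, next(counter), "priority-1 event"))

while queue:
    date, priority, _, label = heapq.heappop(queue)
    print(date, priority, label)
# Prints:
#   2010-01-01 1 priority-1 event
#   2010-01-01 2 priority-2 event
#   2010-01-03 1 later event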
230 changes: 230 additions & 0 deletions src/tlo/threaded_simulation.py
@@ -0,0 +1,230 @@
from queue import Queue
from threading import Thread
from time import sleep
from typing import Callable, List
from warnings import warn

from tlo.events import Event, IndividualScopeEventMixin
from tlo.simulation import _BaseSimulation

MAX_THREADS = 4  # TODO: make this more elegant, probably by examining the OS (e.g. os.cpu_count())

class ThreadController:
"""
Thread controllers serve an organisational role: they keep track of
the threads we create, which helps with debugging.
They also provide convenient wrapper functions to start all of the
threads they control at once, and manage the teardown of their own
threads when these are no longer needed.
Threads spawned by the controller are intended to form a "pool" of
workers that routinely check a Queue object for tasks to perform, and
are otherwise idle. Worker targets should be functions that give the
thread access to the job queue for as long as the thread persists.
(A usage sketch follows this class definition.)
"""
_n_threads: int
_thread_list: List[Thread]

_worker_name: str

@property
def n_threads(self) -> int:
"""
Number of threads that this controller is operating.
"""
return self._n_threads

def __init__(self, n_threads: int = 1, name: str = "Worker") -> None:
"""
Create a new thread controller.
:param n_threads: Number of threads to be spawned, in addition to
the main thread.
:param name: Name to assign to worker threads that this controller
creates, for logging and internal ID purposes.
"""
# Determine how many threads to use given the machine maximum,
# and the user's request. Be sure to save one for the main thread!
self._n_threads = min(n_threads, MAX_THREADS - 1)
if self._n_threads < n_threads:
warn(
f"Requested {n_threads} but this exceeds the maximum possible number of worker threads ({MAX_THREADS - 1}). Restricting to {self._n_threads}."
)
assert (
self._n_threads > 0
), f"Instructed to use {self._n_threads} threads, but at least 1 worker thread is required. Use a serial simulation if you do not want to delegate event execution to threads."

# Prepare the list of threads, but do not initialise threads yet
# since they need access to some of the Simulation properties
self._thread_list = []

self._worker_name = name

def create_all(self, target: Callable[[], None]) -> None:
"""
Creates the threads that will be managed by this controller,
and sets their targets.
Targets are not executed until the start_all method is called.
Targets are functions that take no arguments and return
no values. Workers will execute these functions, preserving the
context and access of the functions that are passed in.
Passing in a bound method such as foo.bar, for example, gives the
worker access to the foo object and runs the bar method on that
object.
"""
for i in range(self._n_threads):
self._thread_list.append(
Thread(target=target, daemon=True, name=f"{self._worker_name}-{i}")
)

def start_all(self) -> None:
"""
Start all threads managed by this controller.
"""
for thread in self._thread_list:
thread.start()
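
Here is a minimal usage sketch of the controller/worker-pool pattern described in the class docstring, self-contained and with a toy job queue rather than simulation events; the names jobs, print_jobs_forever and "DemoWorker" are illustrative only.

from queue import Queue

from tlo.threaded_simulation import ThreadController

jobs: Queue = Queue()

def print_jobs_forever() -> None:
    # Worker target: repeatedly take a job (a callable) off the queue and run it.
    while True:
        job = jobs.get()
        job()
        jobs.task_done()

controller = ThreadController(n_threads=2, name="DemoWorker")
controller.create_all(print_jobs_forever)
controller.start_all()

for i in range(5):
    jobs.put(lambda i=i: print(f"job {i} done"))

jobs.join()  # block until every queued job has been marked done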


class ThreadedSimulation(_BaseSimulation):
"""
Class for running threaded simulations. Events in the queue that can
be executed in parallel are delegated to a worker pool, to be executed
when resources become available.
Certain events cannot be executed in parallel threads safely (notably
population-level events, but also events that attempt to advance time).
When encountering such events, all workers complete the remaining
"thread-safe" events before the unsafe event is triggered.
The progress bar for threaded simulations only advances when time advances,
and its statistics do not update dynamically as each event is fired.
TODO: Route printed output through the logger.
TODO: Include the name of the worker thread each message came from.
"""
# The queue of jobs that will be dispatched to worker threads
_worker_queue: Queue
# Workers must always work on different individuals due to
# their ability to edit the population DataFrame.
_worker_patient_targets: set
# Provides programmatic access to the threads created for the
# simulation
thread_controller: ThreadController

# Safety-catch variables to ensure safe execution of events.
_individuals_currently_being_examined: set

def __init__(self, n_threads: int = 1, **kwargs) -> None:
"""
In addition to the usual simulation instantiation arguments,
threaded simulations must also be passed the number of
worker threads to be used.
:param n_threads: Number of threads to use - in addition to
the main thread - when running simulation events.
"""
# Initialise as you would for any other simulation
super().__init__(**kwargs)

# Setup the thread controller
self.thread_controller = ThreadController(n_threads=n_threads, name="EventWorker")

# Set the target workflow of all workers
self.thread_controller.create_all(self._worker_target)

self._worker_queue = Queue()
# Initialise the set tracking which individuals the event workers
# are currently targeting.
self._worker_patient_targets = set()

def _worker_target(self) -> None:
"""
Workflow that threads will execute.
The workflow assumes that events added to the worker queue
are always safe to execute in any thread, alongside any
other events that might currently be in the queue.
"""
# Loop for as long as the thread/worker is alive.
# Ideally we would not loop forever: we could spawn threads only when they
# are needed and cap how many exist at once, but creating a thread is itself
# an expensive operation. In addition, Queue.get() blocks the thread until a
# job arrives, so an idle worker costs very little.
while True:
# Check for the next job in the queue
event_to_run: Event = self._worker_queue.get()
target = event_to_run.target
# Wait for other events targeting the same individual to complete
while target in self._worker_patient_targets:
# Stall if another thread is currently executing an event
# which targets the same individual.
# Sleep briefly so this worker does not busy-wait while polling.
sleep(0.01)
# Flag that this thread is running an event on this patient
self._worker_patient_targets.add(target)
event_to_run.run()
self._worker_patient_targets.remove(target)
# Report success and await next task
self._worker_queue.task_done()
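
Worth noting for review: the membership test and the add above are not atomic, so two workers could in principle claim the same target between the check and the add. Below is a hedged sketch of one way to close that window with a threading.Lock; the _TargetRegistry class is hypothetical and not part of this PR.

from threading import Lock
from time import sleep


class _TargetRegistry:
    """Hypothetical helper: a thread-safe record of which individuals are
    currently being worked on, as an alternative to the bare set used above."""

    def __init__(self) -> None:
        self._lock = Lock()
        self._targets = set()

    def claim(self, target) -> None:
        # Atomically wait until no other worker holds this target, then claim it.
        while True:
            with self._lock:
                if target not in self._targets:
                    self._targets.add(target)
                    return
            sleep(0.01)  # another worker holds the target; back off briefly

    def release(self, target) -> None:
        with self._lock:
            self._targets.discard(target)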

@staticmethod
def event_must_run_in_main_thread(event: Event) -> bool:
"""
Return True if the event passed in must be run in the main thread, in serial.
Population-level events must always run in the main thread with no worker
events running in parallel, since they need to scan the state of the simulation
at that moment in time and workers have write access to simulation properties.
"""
if not isinstance(event, IndividualScopeEventMixin):
return True
return False

def step_through_events(self) -> None:
# Start the threads
self.thread_controller.start_all()

# Whilst the event queue is not empty
while self.event_queue:
event, date_of_next_event = self.event_queue.next_event()
self.update_progress_bar(self.date)

# If the simulation should end, escape
if date_of_next_event >= self.end_date:
break
# If we want to advance time, we need to ensure that the worker
# queue has been emptied first. Otherwise, a worker might be running
# an event from the previous date but still call sim.date to get the
# "current" time, which would then be out-of-sync.
elif date_of_next_event > self.date:
# This event moves time forward, wait until all jobs
# from the current date have finished before advancing time
self.wait_for_workers()
# All jobs from the previous day have ended.
# Advance time and continue.
self.date = date_of_next_event

# Next, determine if the event to be run can be delegated to the
# worker pool.
if self.event_must_run_in_main_thread(event):
# Event needs all workers to finish, then to run in
# the main thread (this one)
self.wait_for_workers()
event.run()
else:
# This job can be delegated to the worker pool, and run safely
self._worker_queue.put(event)

# We may have exhausted all the events in the queue, but the workers will
# still need time to process them all!
self.wait_for_workers()
self.update_progress_bar(date_of_next_event)

def wait_for_workers(self) -> None:
"""
Pause simulation progression until every job in the worker queue has
been processed and all worker threads are waiting for a new job.
"""
self._worker_queue.join()
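
Finally, a minimal sketch of how a threaded run would be driven once constructed; the dates and seed are illustrative, and module registration is elided because it is unchanged from the serial case.

from tlo import Date
from tlo.threaded_simulation import ThreadedSimulation

sim = ThreadedSimulation(
    n_threads=2,                  # worker threads, in addition to the main thread
    start_date=Date(2010, 1, 1),  # illustrative date and seed
    seed=0,
    show_progress_bar=True,
)
# ... register modules and create the initial population exactly as for a
# serial Simulation (these methods are inherited from _BaseSimulation) ...
sim.simulate(end_date=Date(2011, 1, 1))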