Skip to content

Commit

Permalink
Remove introsort, parallelise merging across mech_ids.
Browse files Browse the repository at this point in the history
  • Loading branch information
thorstenhater committed Nov 21, 2024
1 parent c48fff5 commit 218d0b3
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 37 deletions.
94 changes: 69 additions & 25 deletions arbor/backends/event_stream_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "backends/event.hpp"
#include "backends/event_stream_state.hpp"
#include "event_lane.hpp"
#include "threading/threading.hpp"
#include "timestep_range.hpp"
#include "util/partition.hpp"

Expand All @@ -22,9 +23,11 @@ struct event_stream_base {
protected: // members
std::vector<event_data_type> ev_data_;
std::vector<std::size_t> ev_spans_ = {0};
std::vector<std::size_t> lane_spans_;
std::size_t index_ = 0;
event_data_type* base_ptr_ = nullptr;


public:
event_stream_base() = default;

Expand Down Expand Up @@ -63,31 +66,61 @@ struct event_stream_base {
virtual void init() = 0;
};

struct spike_event_stream_base : event_stream_base<deliverable_event> {
struct spike_event_stream_base: event_stream_base<deliverable_event> {
// Take in one event lane per cell `gid` and reorganise into one stream per
// synapse `mech_id`.
//
// - Due to the cell group coalescing multiple cells and their synapses into
// one object, one `mech_id` can touch multiple lanes / `gid`s.
// - Inversely, two `mech_id`s can cover different, but overlapping sets of `gid`s
// - Multiple `mech_id`s can receive events from the same source
//
// Pre:
// - Events in `lanes[ix]` forall ix
// * are sorted by time
// * `ix` maps to exactly one cell in the local cell group
// - `divs` partitions `handles` such that the target handles for cell `ix`
// are located in `handles[divs[ix]..divs[ix + 1]]`
// - `handles` records `(mech_id, index)` of a target s.t. `index` is the instance
// with the set identified by `mech_id`, e.g. a single synapse placed on a multi-
// location locset (plus the merging across cells by groups)
// Post:
// - streams[mech_id] contains a list of all events for synapse `mech_id` s.t.
// * the list is sorted by (time_step, lid, time)
// * the list is partitioned by `time_step` via `ev_spans`
template<typename EventStream>
friend void initialize(const event_lane_subrange& lanes,
const std::vector<target_handle>& handles,
const std::vector<std::size_t>& divs,
const timestep_range& steps,
std::unordered_map<unsigned, EventStream>& streams) {
std::unordered_map<unsigned, EventStream>& streams,
task_system_handle ts) {
arb_assert(lanes.size() < divs.size());

// reset streams and allocate sufficient space for temporaries
auto n_steps = steps.size();
for (auto& [k, v]: streams) {
v.clear();
v.spike_counter_.clear();
v.spike_counter_.resize(steps.size(), 0);
v.spikes_.clear();
for (auto& [id, stream]: streams) {
stream.clear();
stream.spike_counter_.clear();
stream.spike_counter_.resize(steps.size(), 0);
stream.spikes_.clear();
// ev_data_ has been cleared during v.clear(), so we use its capacity
v.spikes_.reserve(v.ev_data_.capacity());
stream.spikes_.reserve(stream.ev_data_.capacity());
// record sizes of streams for later merging
//
// The idea here is that this records the division points `pd` where
// `stream` was updated by the lane `lid`. As events within one lane are
// sorted, we known that events between two division points are sorted.
// Then, we can use `merge_inplace` over `sort` for a small but noticeable
// speed-up.
stream.lane_spans_.resize(lanes.size() + 1);
for (auto& ix: stream.lane_spans_) ix = stream.spikes_.size();
}

// loop over lanes: group events by mechanism and sort them by time
auto cell = 0;
for (const auto& lane: lanes) {
auto div = divs[cell];
++cell;
arb_size_type step = 0;
for (const auto& evt: lane) {
auto time = evt.time;
Expand All @@ -100,28 +133,39 @@ struct spike_event_stream_base : event_stream_base<deliverable_event> {
const auto& handle = handles[div + target];
auto& stream = streams[handle.mech_id];
stream.spikes_.push_back(spike_data{step, handle.mech_index, time, weight});
// insertion sort with last element as pivot
// ordering: first w.r.t. step, within a step: mech_index, within a mech_index: time
auto first = stream.spikes_.begin();
auto last = stream.spikes_.end();
auto pivot = std::prev(last, 1);
std::rotate(std::upper_bound(first, pivot, *pivot), pivot, last);
// increment count in current time interval
stream.spike_counter_[step]++;
}
// record current sizes here. putting this into the above loop is slower. significantly
for (auto& [id, stream]: streams) stream.lane_spans_[cell + 1] = stream.spikes_.size();
++cell;
}

// parallelise over streams
auto tg = threading::task_group(ts.get());
for (auto& [id, stream]: streams) {
// copy temporary deliverable_events into stream's ev_data_
stream.ev_data_.reserve(stream.spikes_.size());
std::transform(stream.spikes_.begin(), stream.spikes_.end(), std::back_inserter(stream.ev_data_),
[](auto const& e) noexcept -> arb_deliverable_event_data {
return {e.mech_index, e.weight}; });
// scan over spike_counter_ and written to ev_spans_
util::make_partition(stream.ev_spans_, stream.spike_counter_);
// delegate to derived class init: static cast necessary to access protected init()
static_cast<spike_event_stream_base&>(stream).init();
tg.run([&stream]() {
// scan over spike_counter_
util::make_partition(stream.ev_spans_, stream.spike_counter_);
// leverage our earlier partitioning to merge the partitions
// theoretically, this could be parallelised, too, practically it didn't pay off
auto& part = stream.lane_spans_;
for (size_t ix = 0; ix < part.size() - 1; ++ix) {
std::inplace_merge(stream.spikes_.begin(),
stream.spikes_.begin() + part[ix],
stream.spikes_.begin() + part[ix + 1]);
}
// Further optimisation: merge(!) merging, transforming, and appending into one
// call.
// copy temporary deliverable_events into stream's ev_data_
stream.ev_data_.reserve(stream.spikes_.size());
std::transform(stream.spikes_.begin(), stream.spikes_.end(),
std::back_inserter(stream.ev_data_),
[](auto const& e) noexcept -> arb_deliverable_event_data { return {e.mech_index, e.weight}; });
// delegate to derived class init: static cast necessary to access protected init()
static_cast<spike_event_stream_base&>(stream).init();
});
}
tg.wait();
}

protected: // members
Expand Down
5 changes: 3 additions & 2 deletions arbor/backends/shared_state_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ struct shared_state_base {
const std::vector<std::vector<sample_event>>& samples,
const timestep_range& dts,
const std::vector<target_handle>& handles,
const std::vector<size_t>& divs) {
const std::vector<size_t>& divs,
task_system_handle ts) {
auto d = static_cast<D*>(this);
// events
initialize(lanes, handles, divs, dts, d->streams);
initialize(lanes, handles, divs, dts, d->streams, ts);
// samples
auto n_samples = util::sum_by(samples, [] (const auto& s) {return s.size();});
if (d->sample_time.size() < n_samples) {
Expand Down
7 changes: 2 additions & 5 deletions arbor/fvm_lowered_cell.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,7 @@ struct fvm_initialization_data {
struct fvm_lowered_cell {
virtual void reset() = 0;

virtual fvm_initialization_data initialize(
const std::vector<cell_gid_type>& gids,
const recipe& rec) = 0;
virtual fvm_initialization_data initialize(const std::vector<cell_gid_type>& gids, const recipe& rec) = 0;

virtual fvm_integration_result integrate(const timestep_range& dts,
const event_lane_subrange& event_lanes,
Expand All @@ -249,8 +247,7 @@ struct fvm_lowered_cell {

using fvm_lowered_cell_ptr = std::unique_ptr<fvm_lowered_cell>;

ARB_ARBOR_API fvm_lowered_cell_ptr make_fvm_lowered_cell(backend_kind p, const execution_context& ctx,
std::uint64_t seed = 0);
ARB_ARBOR_API fvm_lowered_cell_ptr make_fvm_lowered_cell(backend_kind p, const execution_context& ctx, std::uint64_t seed = 0);

inline
void serialize(serializer& s, const std::string& k, const fvm_lowered_cell& v) { v.t_serialize(s, k); }
Expand Down
8 changes: 3 additions & 5 deletions arbor/fvm_lowered_cell_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,11 @@ struct fvm_lowered_cell_impl: public fvm_lowered_cell {
fvm_lowered_cell_impl(execution_context ctx, arb_seed_type seed = 0):
context_(ctx),
seed_{seed}
{};
{}

void reset() override;

fvm_initialization_data initialize(
const std::vector<cell_gid_type>& gids,
const recipe& rec) override;
fvm_initialization_data initialize(const std::vector<cell_gid_type>& gids, const recipe& rec) override;

fvm_integration_result integrate(const timestep_range& dts,
const event_lane_subrange& event_lanes,
Expand Down Expand Up @@ -176,7 +174,7 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate(const timestep_
// Integration setup
PE(advance:integrate:setup);
// Push samples and events down to the state and reset the spike thresholds.
state_->begin_epoch(event_lanes, staged_samples, dts, target_handles_, target_handle_divisions_);
state_->begin_epoch(event_lanes, staged_samples, dts, target_handles_, target_handle_divisions_, context_.thread_pool);
PL();

// loop over timesteps
Expand Down

0 comments on commit 218d0b3

Please sign in to comment.