2727#include " ttg/util/print.h"
2828#include " ttg/util/trace.h"
2929#include " ttg/util/typelist.h"
30+ #include " ttg/aggregator.h"
3031
3132#include " ttg/serialization/data_descriptor.h"
3233
@@ -744,7 +745,9 @@ namespace ttg_parsec {
744745 " The fourth template for ttg::TT must be a ttg::typelist containing the input types" );
745746 // create a virtual control input if the input list is empty, to be used in invoke()
746747 using actual_input_tuple_type = std::conditional_t <!ttg::meta::typelist_is_empty_v<input_valueTs>,
747- ttg::meta::typelist_to_tuple_t <input_valueTs>, std::tuple<void >>;
748+ ttg::meta::remove_wrapper_tuple_t <
749+ ttg::meta::typelist_to_tuple_t <input_valueTs>>,
750+ std::tuple<void >>;
748751 using input_tuple_type = ttg::meta::typelist_to_tuple_t <input_valueTs>;
749752 static_assert (ttg::meta::is_tuple_v<output_terminalsT>,
750753 " Second template argument for ttg::TT must be std::tuple containing the output terminal types" );
@@ -786,6 +789,8 @@ namespace ttg_parsec {
786789 using input_values_tuple_type = ttg::meta::drop_void_t <ttg::meta::decayed_typelist_t <input_tuple_type>>;
787790 using input_refs_tuple_type = ttg::meta::drop_void_t <ttg::meta::add_glvalue_reference_tuple_t <input_tuple_type>>;
788791
792+ using aggregator_factory_tuple_type = ttg::meta::aggregator_factory_tuple_type_t <input_edges_type>;
793+
789794 static constexpr int numinvals =
790795 std::tuple_size_v<input_refs_tuple_type>; // number of input arguments with values (i.e. omitting the control
791796 // input, if any)
@@ -810,6 +815,7 @@ namespace ttg_parsec {
810815
811816 input_terminals_type input_terminals;
812817 output_terminalsT output_terminals;
818+ aggregator_factory_tuple_type aggregator_factories;
813819
814820 protected:
815821 const auto &get_output_terminals () const { return output_terminals; }
@@ -1295,10 +1301,13 @@ namespace ttg_parsec {
12951301 task_t *task;
12961302 auto &world_impl = world.impl ();
12971303 auto &reducer = std::get<i>(input_reducers);
1304+ constexpr bool is_aggregator = ttg::detail::is_aggregator_v<
1305+ typename std::tuple_element_t <i, input_edges_type>::value_type>;
12981306 bool release = false ;
12991307 bool remove_from_hash = true ;
1308+ bool use_hash_table = is_aggregator || (numins > 1 ) || reducer;
13001309 /* If we have only one input and no reducer on that input we can skip the hash table */
1301- if (numins > 1 || reducer ) {
1310+ if (use_hash_table ) {
13021311 parsec_hash_table_lock_bucket (&tasks_table, hk);
13031312 if (nullptr == (task = (task_t *)parsec_hash_table_nolock_find (&tasks_table, hk))) {
13041313 task = create_new_task (key);
@@ -1310,17 +1319,52 @@ namespace ttg_parsec {
13101319 remove_from_hash = false ;
13111320 release = true ;
13121321 }
1313- parsec_hash_table_unlock_bucket (&tasks_table, hk);
1322+ /* we'll keep the lock for later */
13141323 } else {
13151324 task = create_new_task (key);
13161325 world_impl.increment_created ();
13171326 remove_from_hash = false ;
13181327 }
13191328
1320- if (reducer) { // is this a streaming input? reduce the received value
1329+ if constexpr (is_aggregator) {
1330+
1331+ /* we use the lock to ensure mutual exclusion when inserting into the aggregator */
1332+
1333+ using aggregator_t = typename std::tuple_element_t <i, input_edges_type>::value_type;
1334+ aggregator_t * agg;
1335+ ttg_data_copy_t *agg_copy;
1336+ if (nullptr == (agg_copy = reinterpret_cast <ttg_data_copy_t *>(task->parsec_task .data [i].data_in ))) {
1337+ /* create a new aggregator */
1338+ agg_copy = detail::create_new_datacopy (std::get<i>(aggregator_factories)());
1339+ task->parsec_task .data [i].data_in = agg_copy;
1340+ }
1341+ agg = reinterpret_cast <aggregator_t *>(agg_copy->device_private );
1342+
1343+ ttg_data_copy_t *copy;
1344+ if (nullptr != copy_in) {
1345+ /* register this copy with the task */
1346+ copy = detail::register_data_copy<std::decay_t <Value>>(copy_in, task, std::is_const_v<typename aggregator_t ::value_type>);
1347+ } else {
1348+ copy = detail::create_new_datacopy (std::forward<Value>(value));
1349+ }
1350+ /* put the value into the aggregator */
1351+ agg->add_value (*reinterpret_cast <std::decay_t <Value> *>(copy->device_private ));
1352+ if (agg->has_target ()) {
1353+ /* the target has a fixed target size set */
1354+ release = (agg->size () == agg->target ());
1355+ } else {
1356+ /* fall back to the stream size */
1357+ release = (agg->size () == task->stream [i].goal );
1358+ }
1359+
1360+ /* release the hash table bucket */
1361+ parsec_hash_table_unlock_bucket (&tasks_table, hk);
1362+
1363+ } else if (reducer) { // is this a streaming input? reduce the received value
13211364 // N.B. Right now reductions are done eagerly, without spawning tasks
13221365 // this means we must lock
1323- parsec_hash_table_lock_bucket (&tasks_table, hk);
1366+
1367+ /* we use the lock for mutual exclusion for the reducer */
13241368
13251369 if constexpr (!ttg::meta::is_void_v<valueT>) { // for data values
13261370 // have a value already? if not, set, otherwise reduce
@@ -1334,7 +1378,8 @@ namespace ttg_parsec {
13341378 }
13351379 task->parsec_task .data [i].data_in = copy;
13361380 } else {
1337- reducer (*reinterpret_cast <std::decay_t <valueT> *>(copy->device_private ), value);
1381+ using decay_valueT = std::decay_t <valueT>;
1382+ reducer (*reinterpret_cast <decay_valueT *>(copy->device_private ), value);
13381383 }
13391384 } else {
13401385 reducer (); // even if this was a control input, must execute the reducer for possible side effects
@@ -1347,6 +1392,8 @@ namespace ttg_parsec {
13471392 }
13481393 parsec_hash_table_unlock_bucket (&tasks_table, hk);
13491394 } else {
1395+ /* release the lock, not needed anymore */
1396+ parsec_hash_table_unlock_bucket (&tasks_table, hk);
13501397 /* whether the task needs to be deferred or not */
13511398 bool needs_deferring = false ;
13521399 if constexpr (!valueT_is_Void) {
@@ -2234,7 +2281,8 @@ namespace ttg_parsec {
22342281 public:
22352282 template <typename keymapT = ttg::detail::default_keymap<keyT>,
22362283 typename priomapT = ttg::detail::default_priomap<keyT>>
2237- TT (const std::string &name, const std::vector<std::string> &innames, const std::vector<std::string> &outnames,
2284+ TT (const input_edges_type &inedges,
2285+ const std::string &name, const std::vector<std::string> &innames, const std::vector<std::string> &outnames,
22382286 ttg::World world, keymapT &&keymap_ = keymapT(), priomapT &&priomap_ = priomapT())
22392287 : ttg::TTBase(name, numinedges, numouts)
22402288 , world(world)
@@ -2243,7 +2291,8 @@ namespace ttg_parsec {
22432291 ? decltype (keymap)(ttg::detail::default_keymap<keyT>(world))
22442292 : decltype (keymap)(std::forward<keymapT>(keymap_)))
22452293 , priomap(decltype (keymap)(std::forward<priomapT>(priomap_)))
2246- , static_stream_goal() {
2294+ , static_stream_goal()
2295+ , aggregator_factories(ttg::meta::make_aggregator_factory_tuple(inedges)) {
22472296 // Cannot call these in base constructor since terminals not yet constructed
22482297 if (innames.size () != numinedges) throw std::logic_error (" ttg_parsec::TT: #input names != #input terminals" );
22492298 if (outnames.size () != numouts) throw std::logic_error (" ttg_parsec::TT: #output names != #output terminals" );
@@ -2345,15 +2394,16 @@ namespace ttg_parsec {
23452394 typename priomapT = ttg::detail::default_priomap<keyT>>
23462395 TT (const std::string &name, const std::vector<std::string> &innames, const std::vector<std::string> &outnames,
23472396 keymapT &&keymap = keymapT(ttg::default_execution_context()), priomapT &&priomap = priomapT())
2348- : TT(name, innames, outnames, ttg::default_execution_context(), std::forward<keymapT>(keymap),
2397+ : TT(input_edges_type(), name, innames, outnames, ttg::default_execution_context(), std::forward<keymapT>(keymap),
23492398 std::forward<priomapT>(priomap)) {}
23502399
23512400 template <typename keymapT = ttg::detail::default_keymap<keyT>,
23522401 typename priomapT = ttg::detail::default_priomap<keyT>>
23532402 TT (const input_edges_type &inedges, const output_edges_type &outedges, const std::string &name,
23542403 const std::vector<std::string> &innames, const std::vector<std::string> &outnames, ttg::World world,
23552404 keymapT &&keymap_ = keymapT(), priomapT &&priomap = priomapT())
2356- : TT(name, innames, outnames, world, std::forward<keymapT>(keymap_), std::forward<priomapT>(priomap)) {
2405+ : TT(inedges, name, innames, outnames, world, std::forward<keymapT>(keymap_), std::forward<priomapT>(priomap))
2406+ {
23572407 connect_my_inputs_to_incoming_edge_outputs (std::make_index_sequence<numinedges>{}, inedges);
23582408 connect_my_outputs_to_outgoing_edge_inputs (std::make_index_sequence<numouts>{}, outedges);
23592409 }
0 commit comments