Skip to content

Commit 61a37aa

Browse files
committed
Draft of aggregators
Aggregators are meant to replace reduction terminals where the sole purpose is to collect a number of inputs before passing them into the task. An aggregator is created by wrapping the input type in a ttg::aggregator, which will cause values to be aggregated until either the target count is reached (if provided) or the stream size has been reached. The aggregator is then passed directly into the task and can be iterated over. In combination with make_tt, an aggregator is created using `ttg::make_aggregator(inedge)` or `ttg::make_aggregator(inedge, target)`. These calls will return a specialized Edge that allows to check at compile-time and create aggregators in the backend. Example: ``` ttg::Edge<int, int> edge; auto tt = make_tt( [](const int& key, ttg::aggregator<int>& agg){ for (auto&& v : agg) { ... } }, ttg::edges(ttg::make_aggregator(edge)), ...); ``` Signed-off-by: Joseph Schuchart <schuchart@icl.utk.edu>
1 parent d8a1204 commit 61a37aa

5 files changed

Lines changed: 202 additions & 18 deletions

File tree

tests/unit/tt.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,27 @@ namespace tt_i_iv {
145145
};
146146
} // namespace tt_i_iv
147147

148+
// {task_id,data} = {int, aggregator}
149+
namespace tt_i_i_a {
150+
151+
class tt : public ttg::TT<int, std::tuple<>, tt, ttg::typelist<ttg::Aggregator<int>>> {
152+
using baseT = typename TT::ttT;
153+
154+
public:
155+
tt(const typename baseT::input_edges_type &inedges, const typename baseT::output_edges_type &outedges,
156+
const std::string &name)
157+
: baseT(inedges, outedges, name, {"aggregator<int>"}, {}) {}
158+
159+
static constexpr const bool have_cuda_op = false;
160+
161+
void op(const int &key, const baseT::input_refs_tuple_type &data, baseT::output_terminals_type &outs) {
162+
static_assert(ttg::detail::is_aggregator_v<std::decay_t<std::tuple_element_t<0, baseT::input_refs_tuple_type>>>);
163+
}
164+
165+
~tt() {}
166+
};
167+
} // namespace tt_i_i_a
168+
148169
TEST_CASE("TemplateTask", "[core]") {
149170
SECTION("constructors") {
150171
{ // void task id, void data
@@ -249,5 +270,28 @@ TEST_CASE("TemplateTask", "[core]") {
249270
},
250271
ttg::edges(in), ttg::edges()));
251272
}
273+
{ // nonvoid task id, aggregator input
274+
ttg::Edge<int, int> in;
275+
size_t count = 16;
276+
CHECK_NOTHROW(std::make_unique<tt_i_i_a::tt>(ttg::edges(ttg::make_aggregator(in)), ttg::edges(), ""));
277+
CHECK_NOTHROW(
278+
ttg::make_tt(
279+
[](const int &key, const ttg::Aggregator<int> &datum, std::tuple<> &outs) {
280+
for (auto&& v : datum)
281+
{ }
282+
283+
for (const auto& v : datum)
284+
{ }
285+
}, ttg::edges(ttg::make_aggregator(in)), ttg::edges()));
286+
CHECK_NOTHROW(
287+
ttg::make_tt(
288+
[](const int &key, ttg::Aggregator<int> &&datum, std::tuple<> &outs) {
289+
for (auto&& v : datum)
290+
{ }
291+
292+
for (const auto& v : datum)
293+
{ }
294+
}, ttg::edges(ttg::make_aggregator(in)), ttg::edges()));
295+
}
252296
}
253297
}

ttg/ttg/make_tt.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#ifndef TTG_MAKE_TT_H
44
#define TTG_MAKE_TT_H
55

6+
#if 0
67
namespace detail {
78

89
template <typename... FromEdgeTypesT, std::size_t... I>
@@ -22,6 +23,7 @@ namespace detail {
2223

2324
inline auto edge_base_tuple(const std::tuple<> &empty) { return empty; }
2425
} // namespace detail
26+
#endif // 0
2527

2628
// Class to wrap a callable with signature
2729
//
@@ -346,9 +348,9 @@ auto make_tt_tpl(funcT &&func, const std::tuple<ttg::Edge<keyT, input_edge_value
346348
static_assert(std::is_same_v<decayed_input_args_t, std::tuple<input_edge_valuesT...>>,
347349
"ttg::make_tt_tpl(func, inedges, outedges): inedges value types do not match argument types of func");
348350

349-
auto input_edges = detail::edge_base_tuple(inedges);
351+
//auto input_edges = detail::edge_base_tuple(inedges);
350352

351-
return std::make_unique<wrapT>(std::forward<funcT>(func), input_edges, outedges, name, innames, outnames);
353+
return std::make_unique<wrapT>(std::forward<funcT>(func), inedges, outedges, name, innames, outnames);
352354
}
353355

354356
/// @brief Factory function to assist in wrapping a callable with signature
@@ -423,9 +425,9 @@ auto make_tt(funcT &&func, const std::tuple<ttg::Edge<keyT, input_edge_valuesT>.
423425
using wrapT = typename CallableWrapTTArgsAsTypelist<funcT, have_outterm_tuple, keyT, output_terminals_type,
424426
full_input_args_t>::type;
425427

426-
auto input_edges = detail::edge_base_tuple(inedges);
428+
//auto input_edges = detail::edge_base_tuple(inedges);
427429

428-
return std::make_unique<wrapT>(std::forward<funcT>(func), input_edges, outedges, name, innames, outnames);
430+
return std::make_unique<wrapT>(std::forward<funcT>(func), inedges, outedges, name, innames, outnames);
429431
}
430432

431433
template <typename keyT, typename funcT, typename... input_valuesT, typename... output_edgesT>

ttg/ttg/parsec/ttg.h

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "ttg/util/print.h"
2828
#include "ttg/util/trace.h"
2929
#include "ttg/util/typelist.h"
30+
#include "ttg/aggregator.h"
3031

3132
#include "ttg/serialization/data_descriptor.h"
3233

@@ -744,7 +745,9 @@ namespace ttg_parsec {
744745
"The fourth template for ttg::TT must be a ttg::typelist containing the input types");
745746
// create a virtual control input if the input list is empty, to be used in invoke()
746747
using actual_input_tuple_type = std::conditional_t<!ttg::meta::typelist_is_empty_v<input_valueTs>,
747-
ttg::meta::typelist_to_tuple_t<input_valueTs>, std::tuple<void>>;
748+
ttg::meta::remove_wrapper_tuple_t<
749+
ttg::meta::typelist_to_tuple_t<input_valueTs>>,
750+
std::tuple<void>>;
748751
using input_tuple_type = ttg::meta::typelist_to_tuple_t<input_valueTs>;
749752
static_assert(ttg::meta::is_tuple_v<output_terminalsT>,
750753
"Second template argument for ttg::TT must be std::tuple containing the output terminal types");
@@ -786,6 +789,8 @@ namespace ttg_parsec {
786789
using input_values_tuple_type = ttg::meta::drop_void_t<ttg::meta::decayed_typelist_t<input_tuple_type>>;
787790
using input_refs_tuple_type = ttg::meta::drop_void_t<ttg::meta::add_glvalue_reference_tuple_t<input_tuple_type>>;
788791

792+
using aggregator_factory_tuple_type = ttg::meta::aggregator_factory_tuple_type_t<input_edges_type>;
793+
789794
static constexpr int numinvals =
790795
std::tuple_size_v<input_refs_tuple_type>; // number of input arguments with values (i.e. omitting the control
791796
// input, if any)
@@ -810,6 +815,7 @@ namespace ttg_parsec {
810815

811816
input_terminals_type input_terminals;
812817
output_terminalsT output_terminals;
818+
aggregator_factory_tuple_type aggregator_factories;
813819

814820
protected:
815821
const auto &get_output_terminals() const { return output_terminals; }
@@ -1295,10 +1301,13 @@ namespace ttg_parsec {
12951301
task_t *task;
12961302
auto &world_impl = world.impl();
12971303
auto &reducer = std::get<i>(input_reducers);
1304+
constexpr bool is_aggregator = ttg::detail::is_aggregator_v<
1305+
typename std::tuple_element_t<i, input_edges_type>::value_type>;
12981306
bool release = false;
12991307
bool remove_from_hash = true;
1308+
bool use_hash_table = is_aggregator || (numins > 1) || reducer;
13001309
/* If we have only one input and no reducer on that input we can skip the hash table */
1301-
if (numins > 1 || reducer) {
1310+
if (use_hash_table) {
13021311
parsec_hash_table_lock_bucket(&tasks_table, hk);
13031312
if (nullptr == (task = (task_t *)parsec_hash_table_nolock_find(&tasks_table, hk))) {
13041313
task = create_new_task(key);
@@ -1310,17 +1319,52 @@ namespace ttg_parsec {
13101319
remove_from_hash = false;
13111320
release = true;
13121321
}
1313-
parsec_hash_table_unlock_bucket(&tasks_table, hk);
1322+
/* we'll keep the lock for later */
13141323
} else {
13151324
task = create_new_task(key);
13161325
world_impl.increment_created();
13171326
remove_from_hash = false;
13181327
}
13191328

1320-
if (reducer) { // is this a streaming input? reduce the received value
1329+
if constexpr (is_aggregator) {
1330+
1331+
/* we use the lock to ensure mutual exclusion when inserting into the aggregator */
1332+
1333+
using aggregator_t = typename std::tuple_element_t<i, input_edges_type>::value_type;
1334+
aggregator_t* agg;
1335+
ttg_data_copy_t *agg_copy;
1336+
if (nullptr == (agg_copy = reinterpret_cast<ttg_data_copy_t *>(task->parsec_task.data[i].data_in))) {
1337+
/* create a new aggregator */
1338+
agg_copy = detail::create_new_datacopy(std::get<i>(aggregator_factories)());
1339+
task->parsec_task.data[i].data_in = agg_copy;
1340+
}
1341+
agg = reinterpret_cast<aggregator_t *>(agg_copy->device_private);
1342+
1343+
ttg_data_copy_t *copy;
1344+
if (nullptr != copy_in) {
1345+
/* register this copy with the task */
1346+
copy = detail::register_data_copy<std::decay_t<Value>>(copy_in, task, std::is_const_v<typename aggregator_t::value_type>);
1347+
} else {
1348+
copy = detail::create_new_datacopy(std::forward<Value>(value));
1349+
}
1350+
/* put the value into the aggregator */
1351+
agg->add_value(*reinterpret_cast<std::decay_t<Value> *>(copy->device_private));
1352+
if (agg->has_target()) {
1353+
/* the target has a fixed target size set */
1354+
release = (agg->size() == agg->target());
1355+
} else {
1356+
/* fall back to the stream size */
1357+
release = (agg->size() == task->stream[i].goal);
1358+
}
1359+
1360+
/* release the hash table bucket */
1361+
parsec_hash_table_unlock_bucket(&tasks_table, hk);
1362+
1363+
} else if (reducer) { // is this a streaming input? reduce the received value
13211364
// N.B. Right now reductions are done eagerly, without spawning tasks
13221365
// this means we must lock
1323-
parsec_hash_table_lock_bucket(&tasks_table, hk);
1366+
1367+
/* we use the lock for mutual exclusion for the reducer */
13241368

13251369
if constexpr (!ttg::meta::is_void_v<valueT>) { // for data values
13261370
// have a value already? if not, set, otherwise reduce
@@ -1334,7 +1378,8 @@ namespace ttg_parsec {
13341378
}
13351379
task->parsec_task.data[i].data_in = copy;
13361380
} else {
1337-
reducer(*reinterpret_cast<std::decay_t<valueT> *>(copy->device_private), value);
1381+
using decay_valueT = std::decay_t<valueT>;
1382+
reducer(*reinterpret_cast<decay_valueT *>(copy->device_private), value);
13381383
}
13391384
} else {
13401385
reducer(); // even if this was a control input, must execute the reducer for possible side effects
@@ -1347,6 +1392,8 @@ namespace ttg_parsec {
13471392
}
13481393
parsec_hash_table_unlock_bucket(&tasks_table, hk);
13491394
} else {
1395+
/* release the lock, not needed anymore */
1396+
parsec_hash_table_unlock_bucket(&tasks_table, hk);
13501397
/* whether the task needs to be deferred or not */
13511398
bool needs_deferring = false;
13521399
if constexpr (!valueT_is_Void) {
@@ -2234,7 +2281,8 @@ namespace ttg_parsec {
22342281
public:
22352282
template <typename keymapT = ttg::detail::default_keymap<keyT>,
22362283
typename priomapT = ttg::detail::default_priomap<keyT>>
2237-
TT(const std::string &name, const std::vector<std::string> &innames, const std::vector<std::string> &outnames,
2284+
TT(const input_edges_type &inedges,
2285+
const std::string &name, const std::vector<std::string> &innames, const std::vector<std::string> &outnames,
22382286
ttg::World world, keymapT &&keymap_ = keymapT(), priomapT &&priomap_ = priomapT())
22392287
: ttg::TTBase(name, numinedges, numouts)
22402288
, world(world)
@@ -2243,7 +2291,8 @@ namespace ttg_parsec {
22432291
? decltype(keymap)(ttg::detail::default_keymap<keyT>(world))
22442292
: decltype(keymap)(std::forward<keymapT>(keymap_)))
22452293
, priomap(decltype(keymap)(std::forward<priomapT>(priomap_)))
2246-
, static_stream_goal() {
2294+
, static_stream_goal()
2295+
, aggregator_factories(ttg::meta::make_aggregator_factory_tuple(inedges)) {
22472296
// Cannot call these in base constructor since terminals not yet constructed
22482297
if (innames.size() != numinedges) throw std::logic_error("ttg_parsec::TT: #input names != #input terminals");
22492298
if (outnames.size() != numouts) throw std::logic_error("ttg_parsec::TT: #output names != #output terminals");
@@ -2345,15 +2394,16 @@ namespace ttg_parsec {
23452394
typename priomapT = ttg::detail::default_priomap<keyT>>
23462395
TT(const std::string &name, const std::vector<std::string> &innames, const std::vector<std::string> &outnames,
23472396
keymapT &&keymap = keymapT(ttg::default_execution_context()), priomapT &&priomap = priomapT())
2348-
: TT(name, innames, outnames, ttg::default_execution_context(), std::forward<keymapT>(keymap),
2397+
: TT(input_edges_type(), name, innames, outnames, ttg::default_execution_context(), std::forward<keymapT>(keymap),
23492398
std::forward<priomapT>(priomap)) {}
23502399

23512400
template <typename keymapT = ttg::detail::default_keymap<keyT>,
23522401
typename priomapT = ttg::detail::default_priomap<keyT>>
23532402
TT(const input_edges_type &inedges, const output_edges_type &outedges, const std::string &name,
23542403
const std::vector<std::string> &innames, const std::vector<std::string> &outnames, ttg::World world,
23552404
keymapT &&keymap_ = keymapT(), priomapT &&priomap = priomapT())
2356-
: TT(name, innames, outnames, world, std::forward<keymapT>(keymap_), std::forward<priomapT>(priomap)) {
2405+
: TT(inedges, name, innames, outnames, world, std::forward<keymapT>(keymap_), std::forward<priomapT>(priomap))
2406+
{
23572407
connect_my_inputs_to_incoming_edge_outputs(std::make_index_sequence<numinedges>{}, inedges);
23582408
connect_my_outputs_to_outgoing_edge_inputs(std::make_index_sequence<numouts>{}, outedges);
23592409
}

ttg/ttg/terminal.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,13 @@ namespace ttg {
230230
namespace detail {
231231
template <typename keyT, typename... valuesT>
232232
struct input_terminals_tuple {
233-
using type = std::tuple<ttg::In<keyT, valuesT>...>;
233+
using type = std::tuple<ttg::In<keyT, ttg::meta::remove_wrapper_t<valuesT>>...>;
234234
};
235235

236236
template <typename keyT, typename... valuesT>
237-
struct input_terminals_tuple<keyT, std::tuple<valuesT...>> {
238-
using type = std::tuple<ttg::In<keyT, valuesT>...>;
239-
};
237+
struct input_terminals_tuple<keyT, std::tuple<valuesT...>>
238+
: input_terminals_tuple<keyT, valuesT...>
239+
{ };
240240

241241
template <typename keyT, typename... valuesT>
242242
using input_terminals_tuple_t = typename input_terminals_tuple<keyT, valuesT...>::type;

ttg/ttg/util/meta.h

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,94 @@ namespace ttg {
755755
constexpr bool is_invocable_typelist_r_v<ReturnType, Callable, ttg::typelist<Args...>> =
756756
std::is_invocable_r_v<ReturnType, Callable, Args...>;
757757

758+
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
759+
// remove any wrapper from a type, specializations provided where the wrapper is implemented (e.g., aggregator, ...)
760+
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
761+
template<typename T>
762+
struct remove_wrapper {
763+
using type = T;
764+
};
765+
766+
template<typename T>
767+
using remove_wrapper_t = typename remove_wrapper<T>::type;
768+
769+
template<typename T>
770+
struct remove_wrapper_tuple;
771+
772+
template<typename...Ts>
773+
struct remove_wrapper_tuple<std::tuple<Ts...>> {
774+
using type = std::tuple<remove_wrapper_t<Ts>...>;
775+
};
776+
777+
template<typename T>
778+
using remove_wrapper_tuple_t = typename remove_wrapper_tuple<T>::type;
779+
780+
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
781+
// type of a aggregator factory (returned by Edge::aggregator_factory()), with empty default
782+
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
783+
784+
template<typename EdgeT, typename Enabler = void>
785+
struct edge_has_aggregator_factory : std::false_type
786+
{ };
787+
788+
template<typename EdgeT>
789+
struct edge_has_aggregator_factory<EdgeT, std::enable_if_t<std::is_invocable_v<decltype(std::declval<EdgeT>().aggregator_factory())>>>
790+
: std::true_type
791+
{ };
792+
793+
template<typename T>
794+
constexpr bool edge_has_aggregator_factory_v = edge_has_aggregator_factory<T>::value;
795+
796+
template<typename EdgeT, bool HasAggregatorFactory>
797+
struct aggregator_factory {
798+
using type = std::byte;
799+
};
800+
801+
template<typename EdgeT>
802+
struct aggregator_factory<EdgeT, true> {
803+
using type = decltype(std::declval<EdgeT>().aggregator_factory());
804+
};
805+
806+
template<typename EdgeT>
807+
using aggregator_factory_t = typename aggregator_factory<EdgeT, edge_has_aggregator_factory_v<EdgeT>>::type;
808+
809+
template<typename T>
810+
struct aggregator_factory_tuple_type;
811+
812+
template<typename... ValueTs>
813+
struct aggregator_factory_tuple_type<ttg::typelist<ValueTs...>> {
814+
using type = std::tuple<aggregator_factory_t<ValueTs>...>;
815+
};
816+
817+
template<typename... ValueTs>
818+
struct aggregator_factory_tuple_type<std::tuple<ValueTs...>> {
819+
using type = std::tuple<aggregator_factory_t<ValueTs>...>;
820+
};
821+
822+
template<typename T>
823+
using aggregator_factory_tuple_type_t = typename aggregator_factory_tuple_type<T>::type;
824+
825+
namespace detail {
826+
827+
template<typename EdgeT>
828+
auto make_aggregator(const EdgeT& edge) {
829+
if constexpr (edge_has_aggregator_factory_v<EdgeT>) {
830+
return edge.aggregator_factory();
831+
} else {
832+
return std::byte();
833+
}
834+
}
835+
836+
template<typename... EdgesT, std::size_t... Is>
837+
auto make_aggregator_factory_tuple(const std::tuple<EdgesT...>& edges, std::index_sequence<Is...>) {
838+
return std::make_tuple(make_aggregator(std::get<Is>(edges))...);
839+
}
840+
} // namespace detail
841+
template<typename... EdgesT>
842+
auto make_aggregator_factory_tuple(const std::tuple<EdgesT...>& edges) {
843+
return detail::make_aggregator_factory_tuple(edges, std::make_index_sequence<sizeof...(EdgesT)>());
844+
}
845+
758846
} // namespace meta
759847
} // namespace ttg
760848

0 commit comments

Comments
 (0)