2020#pragma once
2121
2222#include < forward_list>
23+ #include < memory>
2324#include < mutex>
2425#include < sstream>
2526#include < thread>
2627#include < unordered_map>
2728#include < unordered_set>
29+ #include < utility>
2830
31+ #include " arrow/acero/accumulation_queue.h"
2932#include " arrow/acero/aggregate_node.h"
33+ #include " arrow/acero/backpressure_handler.h"
3034#include " arrow/acero/exec_plan.h"
3135#include " arrow/acero/options.h"
3236#include " arrow/acero/query_context.h"
3337#include " arrow/acero/util.h"
3438#include " arrow/compute/exec.h"
3539#include " arrow/compute/exec_internal.h"
40+ #include " arrow/compute/ordering.h"
3641#include " arrow/compute/registry.h"
3742#include " arrow/compute/row/grouper.h"
3843#include " arrow/datum.h"
3944#include " arrow/result.h"
45+ #include " arrow/type_fwd.h"
4046#include " arrow/util/checked_cast.h"
4147#include " arrow/util/logging.h"
4248#include " arrow/util/thread_pool.h"
4349#include " arrow/util/tracing_internal.h"
4450
4551// This file implements both regular and segmented group-by aggregation, which is a
46- // generalization of ordered aggregation in which the key columns are not required to be
47- // ordered.
52+ // generalization of ordered aggregation in which the key columns are not required to
53+ // be ordered.
4854//
49- // In (regular) group-by aggregation, the input rows are partitioned into groups using a
50- // set of columns called keys, where in a given group each row has the same values for
51- // these columns. In segmented group-by aggregation, a second set of columns called
52- // segment-keys is used to refine the partitioning. However, segment-keys are different in
53- // that they partition only consecutive rows into a single group. Such a partition of
54- // consecutive rows is called a segment group. For example, consider a column X with
55- // values [A, A, B, A] at row-indices [0, 1, 2, 3]. A regular group-by aggregation with
56- // keys [X] yields a row-index partitioning [[0, 1, 3], [2]] whereas a segmented-group-by
57- // aggregation with segment-keys [X] yields [[0, 1], [2], [3]].
55+ // In (regular) group-by aggregation, the input rows are partitioned into groups using
56+ // a set of columns called keys, where in a given group each row has the same values
57+ // for these columns. In segmented group-by aggregation, a second set of columns
58+ // called segment-keys is used to refine the partitioning. However, segment-keys are
59+ // different in that they partition only consecutive rows into a single group. Such a
60+ // partition of consecutive rows is called a segment group. For example, consider a
61+ // column X with values [A, A, B, A] at row-indices [0, 1, 2, 3]. A regular group-by
62+ // aggregation with keys [X] yields a row-index partitioning [[0, 1, 3], [2]] whereas
63+ // a segmented-group-by aggregation with segment-keys [X] yields [[0, 1], [2], [3]].
5864//
59- // The implementation first segments the input using the segment-keys, then groups by the
60- // keys. When a segment group end is reached while scanning the input, output is pushed
61- // and the accumulating state is cleared. If no segment-keys are given, then the entire
62- // input is taken as one segment group. One batch per segment group is sent to output.
65+ // The implementation first segments the input using the segment-keys, then groups by
66+ // the keys. When a segment group end is reached while scanning the input, output is
67+ // pushed and the accumulating state is cleared. If no segment-keys are given, then
68+ // the entire input is taken as one segment group. One batch per segment group is sent
69+ // to output.
6370
6471namespace arrow {
65-
6672using internal::checked_cast;
6773
6874using compute::ExecSpan;
@@ -82,6 +88,20 @@ using compute::Segment;
8288namespace acero {
8389namespace aggregate {
8490
91+ class BackpressureController : public BackpressureControl {
92+ public:
93+ BackpressureController (ExecNode* node, ExecNode* output)
94+ : node_(node), output_(output) {}
95+
96+ void Pause () override { node_->PauseProducing (output_, ++backpressure_counter_); }
97+ void Resume () override { node_->ResumeProducing (output_, ++backpressure_counter_); }
98+
99+ private:
100+ ExecNode* node_;
101+ ExecNode* output_;
102+ std::atomic<int32_t > backpressure_counter_;
103+ };
104+
85105template <typename KernelType>
86106struct AggregateNodeArgs {
87107 std::shared_ptr<Schema> output_schema;
@@ -93,6 +113,7 @@ struct AggregateNodeArgs {
93113 std::vector<const KernelType*> kernels;
94114 std::vector<std::vector<TypeHolder>> kernel_intypes;
95115 std::vector<std::vector<std::unique_ptr<KernelState>>> states;
116+ bool requires_ordering;
96117};
97118
98119std::vector<TypeHolder> ExtendWithGroupIdType (const std::vector<TypeHolder>& in_types);
@@ -155,17 +176,17 @@ Result<std::vector<Datum>> ExtractValues(const ExecBatch& input_batch,
155176
156177void PlaceFields (ExecBatch& batch, size_t base, std::vector<Datum>& values);
157178
158- class ScalarAggregateNode : public ExecNode , public TracedNode {
179+ class ScalarAggregateNode : public ExecNode ,
180+ public TracedNode,
181+ public util::SerialSequencingQueue::Processor {
159182 public:
160- ScalarAggregateNode (ExecPlan* plan, std::vector<ExecNode*> inputs,
161- std::shared_ptr<Schema> output_schema,
162- std::unique_ptr<RowSegmenter> segmenter,
163- std::vector<int > segment_field_ids,
164- std::vector<std::vector<int >> target_fieldsets,
165- std::vector<Aggregate> aggs,
166- std::vector<const ScalarAggregateKernel*> kernels,
167- std::vector<std::vector<TypeHolder>> kernel_intypes,
168- std::vector<std::vector<std::unique_ptr<KernelState>>> states)
183+ ScalarAggregateNode (
184+ ExecPlan* plan, std::vector<ExecNode*> inputs,
185+ std::shared_ptr<Schema> output_schema, std::unique_ptr<RowSegmenter> segmenter,
186+ std::vector<int > segment_field_ids, std::vector<std::vector<int >> target_fieldsets,
187+ std::vector<Aggregate> aggs, std::vector<const ScalarAggregateKernel*> kernels,
188+ std::vector<std::vector<TypeHolder>> kernel_intypes,
189+ std::vector<std::vector<std::unique_ptr<KernelState>>> states, Ordering ordering)
169190 : ExecNode(plan, std::move(inputs), {" target" },
170191 /* output_schema=*/ std::move(output_schema)),
171192 TracedNode (this ),
@@ -176,22 +197,31 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
176197 aggs_(std::move(aggs)),
177198 kernels_(std::move(kernels)),
178199 kernel_intypes_(std::move(kernel_intypes)),
179- states_(std::move(states)) {}
200+ states_(std::move(states)),
201+ total_output_batches_(0 ),
202+ sequencer_(nullptr ),
203+ ordering_(std::move(ordering)) {}
180204
181205 static Result<AggregateNodeArgs<ScalarAggregateKernel>> MakeAggregateNodeArgs (
182206 const std::shared_ptr<Schema>& input_schema, const std::vector<FieldRef>& keys,
183207 const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>& aggs,
184- ExecContext* exec_ctx, size_t concurrency, bool is_cpu_parallel );
208+ ExecContext* exec_ctx, size_t concurrency);
185209
186210 static Result<ExecNode*> Make (ExecPlan* plan, std::vector<ExecNode*> inputs,
187211 const ExecNodeOptions& options);
188212
189213 const char * kind_name () const override { return " ScalarAggregateNode" ; }
190214
215+ Status Init () override ;
216+
191217 Status DoConsume (const ExecSpan& batch, size_t thread_index);
192218
193219 Status InputReceived (ExecNode* input, ExecBatch batch) override ;
194220
221+ const Ordering& ordering () const override { return ordering_; }
222+
223+ Status Process (ExecBatch batch) override ;
224+
195225 Status InputFinished (ExecNode* input, int total_batches) override ;
196226
197227 Status StartProducing () override {
@@ -235,18 +265,23 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
235265
236266 AtomicCounter input_counter_;
237267 // / \brief Total number of output batches produced
238- int total_output_batches_ = 0 ;
268+ int64_t total_output_batches_;
269+ std::unique_ptr<acero::util::SerialSequencingQueue> sequencer_;
270+ std::unique_ptr<Processor> processor_;
271+ Ordering ordering_;
239272};
240273
241- class GroupByNode : public ExecNode , public TracedNode {
274+ class GroupByNode : public ExecNode ,
275+ public TracedNode,
276+ public util::SerialSequencingQueue::Processor {
242277 public:
243278 GroupByNode (ExecNode* input, std::shared_ptr<Schema> output_schema,
244279 std::vector<int > key_field_ids, std::vector<int > segment_key_field_ids,
245280 std::unique_ptr<RowSegmenter> segmenter,
246281 std::vector<std::vector<TypeHolder>> agg_src_types,
247282 std::vector<std::vector<int >> agg_src_fieldsets,
248283 std::vector<Aggregate> aggs,
249- std::vector<const HashAggregateKernel*> agg_kernels)
284+ std::vector<const HashAggregateKernel*> agg_kernels, Ordering ordering )
250285 : ExecNode(input->plan (), {input}, {" groupby" }, std::move(output_schema)),
251286 TracedNode (this ),
252287 segmenter_(std::move(segmenter)),
@@ -256,14 +291,17 @@ class GroupByNode : public ExecNode, public TracedNode {
256291 agg_src_types_(std::move(agg_src_types)),
257292 agg_src_fieldsets_(std::move(agg_src_fieldsets)),
258293 aggs_(std::move(aggs)),
259- agg_kernels_(std::move(agg_kernels)) {}
294+ agg_kernels_(std::move(agg_kernels)),
295+ total_output_batches_(0 ),
296+ sequencer_(nullptr ),
297+ ordering_(std::move(ordering)) {}
260298
261299 Status Init () override ;
262300
263301 static Result<AggregateNodeArgs<HashAggregateKernel>> MakeAggregateNodeArgs (
264302 const std::shared_ptr<Schema>& input_schema, const std::vector<FieldRef>& keys,
265303 const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>& aggs,
266- ExecContext* ctx, const bool is_cpu_parallel );
304+ ExecContext* ctx);
267305
268306 static Result<ExecNode*> Make (ExecPlan* plan, std::vector<ExecNode*> inputs,
269307 const ExecNodeOptions& options);
@@ -284,6 +322,8 @@ class GroupByNode : public ExecNode, public TracedNode {
284322
285323 Status InputReceived (ExecNode* input, ExecBatch batch) override ;
286324
325+ Status Process (ExecBatch batch) override ;
326+
287327 Status InputFinished (ExecNode* input, int total_batches) override ;
288328
289329 Status StartProducing () override {
@@ -292,6 +332,8 @@ class GroupByNode : public ExecNode, public TracedNode {
292332 return Status::OK ();
293333 }
294334
335+ const Ordering& ordering () const override { return ordering_; }
336+
295337 void PauseProducing (ExecNode* output, int32_t counter) override {
296338 // TODO(ARROW-16260)
297339 // Without spillover there is no way to handle backpressure in this node
@@ -346,12 +388,15 @@ class GroupByNode : public ExecNode, public TracedNode {
346388
347389 AtomicCounter input_counter_;
348390 // / \brief Total number of output batches produced
349- int total_output_batches_ = 0 ;
391+ int64_t total_output_batches_;
350392
351393 std::vector<ThreadLocalState> local_states_;
352394 ExecBatch out_data_;
395+ std::unique_ptr<acero::util::SerialSequencingQueue> sequencer_;
396+ std::unique_ptr<Processor> processor_;
397+ Ordering ordering_;
353398};
354399
355400} // namespace aggregate
356401} // namespace acero
357- } // namespace arrow
402+ } // namespace arrow
0 commit comments