diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecExchange.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecExchange.java
index 17e3966d52a..8bedbd89321 100644
--- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecExchange.java
+++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecExchange.java
@@ -20,6 +20,7 @@
 import org.apache.gluten.streaming.runtime.partitioner.GlutenKeyGroupStreamPartitioner;
 import org.apache.gluten.table.runtime.keyselector.GlutenKeySelector;
 import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator;
+import org.apache.gluten.table.runtime.operators.WindowAggOperator;
 import org.apache.gluten.util.LogicalTypeConverter;
 import org.apache.gluten.util.PlanNodeIdGenerator;
 
@@ -112,6 +113,23 @@ public StreamExecExchange(
     checkArgument(inputProperties.size() == 1);
   }
 
+  /** Returns whether {@code fieldName} is window_start, window_end or window_time. */
+  private static boolean isWindowPropertyField(String fieldName) {
+    return "window_start".equals(fieldName)
+        || "window_end".equals(fieldName)
+        || "window_time".equals(fieldName);
+  }
+
+  /**
+   * Returns whether the input operator is a {@link WindowAggOperator} or the output row type of
+   * this exchange contains a window property field.
+   */
+  private boolean isWindowAggregateExchange(OneInputTransformation inputTransform) {
+    return inputTransform.getOperator() instanceof WindowAggOperator
+        || ((RowType) getOutputType())
+            .getFieldNames().stream().anyMatch(StreamExecExchange::isWindowPropertyField);
+  }
+
   @SuppressWarnings("unchecked")
   @Override
   protected Transformation translateToPlanInternal(
@@ -139,7 +157,8 @@ protected Transformation translateToPlanInternal(
                 planner.getFlinkContext().getClassLoader(), keys, inputType);
         // --- Begin Gluten-specific code changes ---
         OneInputTransformation oneInputTransform = (OneInputTransformation) inputTransform;
-        if (oneInputTransform.getOperator() instanceof GlutenOperator) {
+        if (oneInputTransform.getOperator() instanceof GlutenOperator
+            && !isWindowAggregateExchange(oneInputTransform)) {
           // TODO: velox's parallelism need to be set here, as some nodes need it.
           // should set it when operator init.
           parallelism = inputTransform.getParallelism();
@@ -181,7 +200,8 @@ protected Transformation translateToPlanInternal(
                   parallelism,
                   false);
           partitioner =
-              new GlutenKeyGroupStreamPartitioner(keySelector, DEFAULT_LOWER_BOUND_MAX_PARALLELISM);
+              new GlutenKeyGroupStreamPartitioner(
+                  keySelector, DEFAULT_LOWER_BOUND_MAX_PARALLELISM, parallelism);
         } else {
           parallelism = ExecutionConfig.PARALLELISM_DEFAULT;
           partitioner =
diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
index e5dca4ee8de..171cba1cff4 100644
--- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
+++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
@@ -44,6 +44,7 @@
 import org.apache.flink.streaming.api.transformations.OneInputTransformation;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.planner.delegation.PlannerBase;
+import org.apache.flink.table.planner.plan.logical.SliceAttachedWindowingStrategy;
 import org.apache.flink.table.planner.plan.logical.WindowingStrategy;
 import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
 import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
@@ -63,21 +64,27 @@
 import org.apache.flink.table.runtime.util.TimeWindowUtil;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.RowType.RowField;
 
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
 
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.commons.math3.util.ArithmeticUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
 
 import java.time.ZoneId;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.TimeZone;
 import java.util.stream.Collectors;
 
@@ -95,6 +102,7 @@
     minStateVersion = FlinkVersion.v1_15)
 public class StreamExecGlobalWindowAggregate extends StreamExecWindowAggregateBase {
+  private static final Logger LOG = LoggerFactory.getLogger(StreamExecGlobalWindowAggregate.class);
 
   public static final String GLOBAL_WINDOW_AGGREGATE_TRANSFORMATION = "global-window-aggregate";
 
   public static final String FIELD_NAME_LOCAL_AGG_INPUT_ROW_TYPE = "localAggInputRowType";
@@ -167,6 +175,17 @@ public StreamExecGlobalWindowAggregate(
     this.needRetraction = Optional.ofNullable(needRetraction).orElse(false);
   }
 
+  /**
+   * Returns the slice-end field index when {@code windowing} is a
+   * {@link SliceAttachedWindowingStrategy}, or {@code -1} otherwise.
+   */
+  private int getSliceEndIndex() {
+    if (windowing instanceof SliceAttachedWindowingStrategy) {
+      return ((SliceAttachedWindowingStrategy) windowing).getSliceEnd();
+    }
+    return -1;
+  }
+
   @SuppressWarnings("unchecked")
   @Override
   protected Transformation translateToPlanInternal(
@@ -179,6 +198,18 @@ protected Transformation translateToPlanInternal(
     final ZoneId shiftTimeZone =
         TimeWindowUtil.getShiftTimeZone(
             windowing.getTimeAttributeType(), TableConfigUtils.getLocalTimeZone(config));
+    Set<Integer> nonAggFieldIndexes = new HashSet<>();
+    Arrays.stream(grouping).forEach(nonAggFieldIndexes::add);
+    nonAggFieldIndexes.add(getSliceEndIndex());
+    List<RowField> intermediateAggInputRowFields = new ArrayList<>();
+    for (int i = 0; i < inputRowType.getFieldNames().size(); i++) {
+      RowField rowField =
+          new RowField(inputRowType.getFieldNames().get(i), inputRowType.getChildren().get(i));
+      if (!nonAggFieldIndexes.contains(i)) {
+        intermediateAggInputRowFields.add(rowField);
+      }
+    }
+    final RowType intermediateInputType = new RowType(intermediateAggInputRowFields);
     final AggregateInfoList globalAggInfoList =
         AggregateUtil.deriveStreamWindowAggregateInfoList(
             planner.getTypeFactory(),
@@ -187,16 +218,23 @@ protected Transformation translateToPlanInternal(
             needRetraction,
             windowing.getWindow(),
             true);
-
     // --- Begin Gluten-specific code changes ---
     // TODO: velox window not equal to flink window.
+    io.github.zhztheplayer.velox4j.type.RowType intermediateInputRowType =
+        (io.github.zhztheplayer.velox4j.type.RowType)
+            LogicalTypeConverter.toVLType(intermediateInputType);
     io.github.zhztheplayer.velox4j.type.RowType inputType =
         (io.github.zhztheplayer.velox4j.type.RowType) LogicalTypeConverter.toVLType(inputRowType);
     io.github.zhztheplayer.velox4j.type.RowType outputType =
         (io.github.zhztheplayer.velox4j.type.RowType)
             LogicalTypeConverter.toVLType(getOutputType());
     List groupingKeys = Utils.generateFieldAccesses(inputType, grouping);
-    List aggregates = AggregateCallConverter.toAggregates(aggCalls, inputType);
+    List intermediateAggregates =
+        AggregateCallConverter.toIntermediateAggregates(aggCalls, intermediateInputRowType);
+    // NOTE(review): finalAggregates is built by the same call with the same arguments as
+    // intermediateAggregates - confirm the FINAL step is meant to reuse that conversion.
+    List finalAggregates =
+        AggregateCallConverter.toIntermediateAggregates(aggCalls, intermediateInputRowType);
     checkArgument(outputType.getNames().size() >= grouping.length + aggCalls.length);
     List aggNames =
         outputType.getNames().stream()
@@ -220,26 +258,26 @@ protected Transformation translateToPlanInternal(
     PartitionFunctionSpec sliceAssignerSpec =
         new StreamWindowPartitionFunctionSpec(
             inputType, rowtimeIndex, size, slide, offset, windowType);
-    PlanNode aggregation =
+    PlanNode finalAgg =
         new AggregationNode(
             PlanNodeIdGenerator.newId(),
-            AggregateStep.SINGLE,
+            AggregateStep.FINAL,
             groupingKeys,
             groupingKeys,
             aggNames,
-            aggregates,
+            finalAggregates,
             false,
             List.of(new EmptyNode(inputType)),
             null,
             List.of());
-    PlanNode localAgg =
+    PlanNode intermediateAgg =
         new AggregationNode(
             PlanNodeIdGenerator.newId(),
-            AggregateStep.SINGLE,
+            AggregateStep.INTERMEDIATE,
             groupingKeys,
             groupingKeys,
             aggNames,
-            aggregates,
+            intermediateAggregates,
             false,
             List.of(new EmptyNode(inputType)),
             null,
@@ -247,8 +285,8 @@ protected Transformation translateToPlanInternal(
     PlanNode windowAgg =
         new StreamWindowAggregationNode(
             PlanNodeIdGenerator.newId(),
-            aggregation,
-            localAgg,
+            finalAgg,
+            intermediateAgg,
             keySelectorSpec,
             sliceAssignerSpec,
             ArithmeticUtils.gcd(size, slide),
@@ -280,7 +318,8 @@ protected Transformation translateToPlanInternal(
             "StreamExecWindowAggregate",
             selector.getProducedType(),
             globalAggInfoList.getAggNames(),
-            accTypes);
+            accTypes,
+            windowing.isRowtime());
     // --- End Gluten-specific code changes ---
 
     final OneInputTransformation transform =
diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java
index ccd11653160..8843982008f 100644
--- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java
+++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java
@@ -16,14 +16,10 @@
  */
 package org.apache.flink.table.planner.plan.nodes.exec.stream;
 
-import org.apache.gluten.rexnode.AggregateCallConverter;
-import org.apache.gluten.rexnode.Utils;
 import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator;
 import org.apache.gluten.util.LogicalTypeConverter;
 import org.apache.gluten.util.PlanNodeIdGenerator;
 
-import io.github.zhztheplayer.velox4j.aggregate.Aggregate;
-import io.github.zhztheplayer.velox4j.expression.FieldAccessTypedExpr;
 import io.github.zhztheplayer.velox4j.plan.GroupWindowAggregationNode;
 import io.github.zhztheplayer.velox4j.plan.GroupWindowAggsHandlerNode;
 import io.github.zhztheplayer.velox4j.plan.HashPartitionFunctionSpec;
@@ -213,14 +209,7 @@ protected Transformation translateToPlanInternal(
     io.github.zhztheplayer.velox4j.type.RowType outputType =
         (io.github.zhztheplayer.velox4j.type.RowType)
             LogicalTypeConverter.toVLType(getOutputType());
-    List groupingKeys = Utils.generateFieldAccesses(inputType, grouping);
-    List aggregates = AggregateCallConverter.toAggregates(aggCalls, inputType);
     checkArgument(outputType.getNames().size() >= grouping.length + aggCalls.length);
-    List aggNames =
-        outputType.getNames().stream()
-            .skip(grouping.length)
-            .limit(aggCalls.length)
-            .collect(Collectors.toList());
     List keyIndexes = Arrays.stream(grouping).boxed().collect(Collectors.toList());
     PartitionFunctionSpec keySelectorSpec = new HashPartitionFunctionSpec(inputType, keyIndexes);
     // TODO: support more window types.
diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
index d67d61709e3..574362cf5ad 100644
--- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
+++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
@@ -42,8 +42,6 @@
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
 import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
-import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator;
 import org.apache.flink.table.planner.delegation.PlannerBase;
 import org.apache.flink.table.planner.plan.logical.WindowingStrategy;
 import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
@@ -53,21 +51,15 @@
 import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeMetadata;
 import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
 import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
-import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
-import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
 import org.apache.flink.table.planner.utils.TableConfigUtils;
-import org.apache.flink.table.runtime.generated.GeneratedNamespaceAggsHandleFunction;
-import org.apache.flink.table.runtime.operators.window.tvf.slicing.SliceAssigner;
 import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
 import org.apache.flink.table.runtime.util.TimeWindowUtil;
-import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.RowType;
 
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
 
 import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.tools.RelBuilder;
 import org.apache.commons.math3.util.ArithmeticUtils;
 
 import javax.annotation.Nullable;
@@ -181,7 +173,7 @@ protected Transformation translateToPlanInternal(
             .limit(aggCalls.length)
             .collect(Collectors.toList());
     List keyIndexes = Arrays.stream(grouping).boxed().collect(Collectors.toList());
-    PartitionFunctionSpec keySelectorSpec = new HashPartitionFunctionSpec(inputType, keyIndexes);
+    PartitionFunctionSpec keySelectorSpec = new HashPartitionFunctionSpec(outputType, keyIndexes);
     // TODO: support more window types.
     Tuple5 windowSpecParams =
         WindowUtils.extractWindowParameters(windowing);
@@ -196,7 +188,7 @@ protected Transformation translateToPlanInternal(
     PlanNode aggregation =
         new AggregationNode(
             PlanNodeIdGenerator.newId(),
-            AggregateStep.SINGLE,
+            AggregateStep.PARTIAL,
             groupingKeys,
             groupingKeys,
             aggNames,
@@ -245,35 +237,4 @@ protected Transformation translateToPlanInternal(
         WINDOW_AGG_MEMORY_RATIO / 2,
         false);
   }
-
-  private GeneratedNamespaceAggsHandleFunction createAggsHandler(
-      SliceAssigner sliceAssigner,
-      AggregateInfoList aggInfoList,
-      ExecNodeConfig config,
-      ClassLoader classLoader,
-      RelBuilder relBuilder,
-      List fieldTypes,
-      ZoneId shiftTimeZone) {
-    final AggsHandlerCodeGenerator generator =
-        new AggsHandlerCodeGenerator(
-                new CodeGeneratorContext(config, classLoader),
-                relBuilder,
-                JavaScalaConversionUtil.toScala(fieldTypes),
-                true) // copyInputField
-            .needAccumulate()
-            .needMerge(0, true, null);
-
-    if (needRetraction) {
-      generator.needRetract();
-    }
-
-    return generator.generateNamespaceAggsHandler(
-        "LocalWindowAggsHandler",
-        aggInfoList,
-        JavaScalaConversionUtil.toScala(Collections.emptyList()),
-        sliceAssigner,
-        // we use window end timestamp to indicate a slicing window, see SliceAssigner
-        Long.class,
-        shiftTimeZone);
-  }
 }
diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
index 28b78687f66..a3691aced85 100644
--- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
+++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
@@ -293,7 +293,8 @@ protected Transformation translateToPlanInternal(
             "StreamExecWindowAggregate",
             selector.getProducedType(),
             aggInfoList.getAggNames(),
-            accTypes);
+            accTypes,
+            windowing.isRowtime());
     // --- End Gluten-specific code changes ---
 
     final OneInputTransformation transform =
diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/AggregateCallConverter.java b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/AggregateCallConverter.java
index ca79fb01c01..bc28074ef77 100644
--- a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/AggregateCallConverter.java
+++ b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/AggregateCallConverter.java
@@ -56,6 +56,43 @@ public static List toAggregates(
     return aggregates;
   }
 
+  /**
+   * Builds one {@link Aggregate} per aggregate call, handing {@code convertAggregation} a field
+   * access for every column of {@code inputType} together with the call's converted type.
+   *
+   * <p>NOTE(review): every call is given all columns of {@code inputType} and all of its column
+   * types - confirm this is intended when there is more than one aggregate call.
+   *
+   * @param aggregateCalls the aggregate calls to convert
+   * @param inputType the row type whose columns are passed to each aggregate
+   * @return the converted aggregates, in the order of {@code aggregateCalls}
+   */
+  public static List toIntermediateAggregates(
+      AggregateCall[] aggregateCalls, io.github.zhztheplayer.velox4j.type.RowType inputType) {
+    List aggregates = new ArrayList<>();
+    List typeExprs = new ArrayList<>();
+    for (int i = 0; i < inputType.getNames().size(); i++) {
+      typeExprs.add(
+          FieldAccessTypedExpr.create(inputType.getChildren().get(i), inputType.getNames().get(i)));
+    }
+    for (AggregateCall aggregateCall : aggregateCalls) {
+      CallTypedExpr call =
+          convertAggregation(
+              aggregateCall.getAggregation().getName(),
+              typeExprs,
+              RexNodeConverter.toType(aggregateCall.getType()));
+      aggregates.add(
+          new Aggregate(
+              call,
+              inputType.getChildren(),
+              null,
+              List.of(),
+              List.of(),
+              aggregateCall.isDistinct()));
+    }
+    return aggregates;
+  }
+
   public static WindowFunction toFunction(
       AggregateCall aggregateCall, io.github.zhztheplayer.velox4j.type.RowType inputType) {
     CallTypedExpr call = toCall(aggregateCall, inputType);
diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/runtime/partitioner/GlutenKeyGroupStreamPartitioner.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/runtime/partitioner/GlutenKeyGroupStreamPartitioner.java
index 1a30c525f82..90f34f514a3 100644
--- a/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/runtime/partitioner/GlutenKeyGroupStreamPartitioner.java
+++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/runtime/partitioner/GlutenKeyGroupStreamPartitioner.java
@@ -39,16 +39,17 @@
 public class GlutenKeyGroupStreamPartitioner extends StreamPartitioner
     implements ConfigurableStreamPartitioner {
   private static final long serialVersionUID = 1L;
-
   private final KeySelector keySelector;
 
   private int maxParallelism;
+  private final int parallelism;
 
   public GlutenKeyGroupStreamPartitioner(
-      KeySelector keySelector, int maxParallelism) {
+      KeySelector keySelector, int maxParallelism, int parallelism) {
     Preconditions.checkArgument(maxParallelism > 0, "Number of key-groups must be > 0!");
     this.keySelector = Preconditions.checkNotNull(keySelector);
     this.maxParallelism = maxParallelism;
+    this.parallelism = parallelism;
   }
 
   public int getMaxParallelism() {
@@ -58,8 +59,13 @@ public int getMaxParallelism() {
   @Override
   public int selectChannel(SerializationDelegate> record) {
     try {
-      int channel = keySelector.getKey(record.getInstance().getValue());
-      return channel;
+      int key = keySelector.getKey(record.getInstance().getValue());
+      int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+      // A non-positive plan-time parallelism cannot select a valid spread of channels, so fall
+      // back to the channel count this partitioner was set up with.
+      int targetParallelism = parallelism > 0 ? parallelism : numberOfChannels;
+      return KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+          maxParallelism, targetParallelism, keyGroup);
     } catch (Exception e) {
       throw new RuntimeException(
           "Could not extract key from " + record.getInstance().getValue(), e);
diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/config/VeloxQueryConfig.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/config/VeloxQueryConfig.java
index a1026cfe49a..96631eeab7b 100644
--- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/config/VeloxQueryConfig.java
+++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/config/VeloxQueryConfig.java
@@ -33,6 +33,9 @@ public class VeloxQueryConfig {
   private static final String keyVeloxSessionTimezone = "session_timezone";
   private static final String kStreamingAggregationMinOutputBatchRows =
       "streaming_aggregation_min_output_batch_rows";
+  private static final String kMaxOutputBatchRows = "max_output_batch_rows";
+  private static final String kPreferredOutputBatchRows = "preferred_output_batch_rows";
+  private static final String kStatefulTaskParallelism = "stateful_task_parallelism";
 
   public static Config getConfig(RuntimeContext context) {
     if (!(context instanceof StreamingRuntimeContext)) {
@@ -50,6 +53,11 @@ public static Config getConfig(RuntimeContext context) {
       configMap.put(keyVeloxSessionTimezone, localTimeZone);
     }
     configMap.put(kStreamingAggregationMinOutputBatchRows, String.valueOf(1));
+    configMap.put(kMaxOutputBatchRows, String.valueOf(Integer.MAX_VALUE));
+    configMap.put(kPreferredOutputBatchRows, String.valueOf(Integer.MAX_VALUE));
+    configMap.put(
+        kStatefulTaskParallelism,
+        String.valueOf(context.getTaskInfo().getNumberOfParallelSubtasks()));
     return Config.create(configMap);
   }
 }
diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/WindowAggOperator.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/WindowAggOperator.java
index f15d4c8c39f..7eb71ee551c 100644
--- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/WindowAggOperator.java
+++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/WindowAggOperator.java
@@ -51,6 +51,7 @@ public class WindowAggOperator extends GlutenOneInputOperator keyType;
   private String[] accNames;
   private LogicalType[] accTypes;
+  private boolean isRowTime = false;
 
   public WindowAggOperator(
       StatefulPlanNode plan,
@@ -62,11 +63,13 @@ public WindowAggOperator(
       String description,
       InternalTypeInfo keyType,
       String[] accNames,
-      LogicalType[] accTypes) {
+      LogicalType[] accTypes,
+      boolean isRowTime) {
     super(plan, id, inputType, outputTypes, inClass, outClass, description);
     this.keyType = keyType;
     this.accNames = accNames;
     this.accTypes = accTypes;
+    this.isRowTime = isRowTime;
   }
 
   public InternalTypeInfo getKeyTye() {
@@ -147,7 +150,8 @@ public WindowAggOperator cloneWithInputOutputClasses(
         getDescription(),
         keyType,
         accNames,
-        accTypes);
+        accTypes,
+        isRowTime);
   }
 
   @Override