Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.gluten.streaming.runtime.partitioner.GlutenKeyGroupStreamPartitioner;
import org.apache.gluten.table.runtime.keyselector.GlutenKeySelector;
import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator;
import org.apache.gluten.table.runtime.operators.WindowAggOperator;
import org.apache.gluten.util.LogicalTypeConverter;
import org.apache.gluten.util.PlanNodeIdGenerator;

Expand Down Expand Up @@ -112,6 +113,18 @@ public StreamExecExchange(
checkArgument(inputProperties.size() == 1);
}

private static boolean isWindowPropertyField(String fieldName) {
return "window_start".equals(fieldName)
|| "window_end".equals(fieldName)
|| "window_time".equals(fieldName);
}

private boolean isWindowAggregateExchange(OneInputTransformation inputTransform) {
return inputTransform.getOperator() instanceof WindowAggOperator
|| ((RowType) getOutputType())
.getFieldNames().stream().anyMatch(StreamExecExchange::isWindowPropertyField);
}

@SuppressWarnings("unchecked")
@Override
protected Transformation<RowData> translateToPlanInternal(
Expand Down Expand Up @@ -139,7 +152,8 @@ protected Transformation<RowData> translateToPlanInternal(
planner.getFlinkContext().getClassLoader(), keys, inputType);
// --- Begin Gluten-specific code changes ---
OneInputTransformation oneInputTransform = (OneInputTransformation) inputTransform;
if (oneInputTransform.getOperator() instanceof GlutenOperator) {
if (oneInputTransform.getOperator() instanceof GlutenOperator
&& !isWindowAggregateExchange(oneInputTransform)) {
// TODO: velox's parallelism need to be set here, as some nodes need it.
// should set it when operator init.
parallelism = inputTransform.getParallelism();
Expand Down Expand Up @@ -181,7 +195,8 @@ protected Transformation<RowData> translateToPlanInternal(
parallelism,
false);
partitioner =
new GlutenKeyGroupStreamPartitioner(keySelector, DEFAULT_LOWER_BOUND_MAX_PARALLELISM);
new GlutenKeyGroupStreamPartitioner(
keySelector, DEFAULT_LOWER_BOUND_MAX_PARALLELISM, parallelism);
} else {
parallelism = ExecutionConfig.PARALLELISM_DEFAULT;
partitioner =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.logical.SliceAttachedWindowingStrategy;
import org.apache.flink.table.planner.plan.logical.WindowingStrategy;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
Expand All @@ -63,21 +64,27 @@
import org.apache.flink.table.runtime.util.TimeWindowUtil;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.RowType.RowField;

import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;

import org.apache.calcite.rel.core.AggregateCall;
import org.apache.commons.math3.util.ArithmeticUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;

import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TimeZone;
import java.util.stream.Collectors;

Expand All @@ -95,6 +102,7 @@
minStateVersion = FlinkVersion.v1_15)
public class StreamExecGlobalWindowAggregate extends StreamExecWindowAggregateBase {

private static final Logger LOG = LoggerFactory.getLogger(StreamExecGlobalWindowAggregate.class);
public static final String GLOBAL_WINDOW_AGGREGATE_TRANSFORMATION = "global-window-aggregate";

public static final String FIELD_NAME_LOCAL_AGG_INPUT_ROW_TYPE = "localAggInputRowType";
Expand Down Expand Up @@ -167,6 +175,13 @@ public StreamExecGlobalWindowAggregate(
this.needRetraction = Optional.ofNullable(needRetraction).orElse(false);
}

private int getSliceEndIndex() {
if (windowing instanceof SliceAttachedWindowingStrategy) {
return ((SliceAttachedWindowingStrategy) windowing).getSliceEnd();
}
return -1;
}

@SuppressWarnings("unchecked")
@Override
protected Transformation<RowData> translateToPlanInternal(
Expand All @@ -179,6 +194,18 @@ protected Transformation<RowData> translateToPlanInternal(
final ZoneId shiftTimeZone =
TimeWindowUtil.getShiftTimeZone(
windowing.getTimeAttributeType(), TableConfigUtils.getLocalTimeZone(config));
Set<Integer> nonAggFieldIndexes = new HashSet<>();
Arrays.stream(grouping).forEach(nonAggFieldIndexes::add);
nonAggFieldIndexes.add(getSliceEndIndex());
List<RowField> intermediateAggInputRowFields = new ArrayList<>();
for (int i = 0; i < inputRowType.getFieldNames().size(); i++) {
RowField rowField =
new RowField(inputRowType.getFieldNames().get(i), inputRowType.getChildren().get(i));
if (!nonAggFieldIndexes.contains(i)) {
intermediateAggInputRowFields.add(rowField);
}
}
final RowType intermediateInputType = new RowType(intermediateAggInputRowFields);
final AggregateInfoList globalAggInfoList =
AggregateUtil.deriveStreamWindowAggregateInfoList(
planner.getTypeFactory(),
Expand All @@ -187,16 +214,21 @@ protected Transformation<RowData> translateToPlanInternal(
needRetraction,
windowing.getWindow(),
true);

// --- Begin Gluten-specific code changes ---
// TODO: velox window not equal to flink window.
io.github.zhztheplayer.velox4j.type.RowType intermediateInputRowType =
(io.github.zhztheplayer.velox4j.type.RowType)
LogicalTypeConverter.toVLType(intermediateInputType);
io.github.zhztheplayer.velox4j.type.RowType inputType =
(io.github.zhztheplayer.velox4j.type.RowType) LogicalTypeConverter.toVLType(inputRowType);
io.github.zhztheplayer.velox4j.type.RowType outputType =
(io.github.zhztheplayer.velox4j.type.RowType)
LogicalTypeConverter.toVLType(getOutputType());
List<FieldAccessTypedExpr> groupingKeys = Utils.generateFieldAccesses(inputType, grouping);
List<Aggregate> aggregates = AggregateCallConverter.toAggregates(aggCalls, inputType);
List<Aggregate> intermediateAggregates =
AggregateCallConverter.toIntermediateAggregates(aggCalls, intermediateInputRowType);
List<Aggregate> finalAggregates =
AggregateCallConverter.toIntermediateAggregates(aggCalls, intermediateInputRowType);
checkArgument(outputType.getNames().size() >= grouping.length + aggCalls.length);
List<String> aggNames =
outputType.getNames().stream()
Expand All @@ -220,35 +252,35 @@ protected Transformation<RowData> translateToPlanInternal(
PartitionFunctionSpec sliceAssignerSpec =
new StreamWindowPartitionFunctionSpec(
inputType, rowtimeIndex, size, slide, offset, windowType);
PlanNode aggregation =
PlanNode finalAgg =
new AggregationNode(
PlanNodeIdGenerator.newId(),
AggregateStep.SINGLE,
AggregateStep.FINAL,
groupingKeys,
groupingKeys,
aggNames,
aggregates,
finalAggregates,
false,
List.of(new EmptyNode(inputType)),
null,
List.of());
PlanNode localAgg =
PlanNode intermediateAgg =
new AggregationNode(
PlanNodeIdGenerator.newId(),
AggregateStep.SINGLE,
AggregateStep.INTERMEDIATE,
groupingKeys,
groupingKeys,
aggNames,
aggregates,
intermediateAggregates,
false,
List.of(new EmptyNode(inputType)),
null,
List.of());
PlanNode windowAgg =
new StreamWindowAggregationNode(
PlanNodeIdGenerator.newId(),
aggregation,
localAgg,
finalAgg,
intermediateAgg,
keySelectorSpec,
sliceAssignerSpec,
ArithmeticUtils.gcd(size, slide),
Expand Down Expand Up @@ -280,7 +312,8 @@ protected Transformation<RowData> translateToPlanInternal(
"StreamExecWindowAggregate",
selector.getProducedType(),
globalAggInfoList.getAggNames(),
accTypes);
accTypes,
windowing.isRowtime());
// --- End Gluten-specific code changes ---

final OneInputTransformation<RowData, RowData> transform =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,10 @@
*/
package org.apache.flink.table.planner.plan.nodes.exec.stream;

import org.apache.gluten.rexnode.AggregateCallConverter;
import org.apache.gluten.rexnode.Utils;
import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator;
import org.apache.gluten.util.LogicalTypeConverter;
import org.apache.gluten.util.PlanNodeIdGenerator;

import io.github.zhztheplayer.velox4j.aggregate.Aggregate;
import io.github.zhztheplayer.velox4j.expression.FieldAccessTypedExpr;
import io.github.zhztheplayer.velox4j.plan.GroupWindowAggregationNode;
import io.github.zhztheplayer.velox4j.plan.GroupWindowAggsHandlerNode;
import io.github.zhztheplayer.velox4j.plan.HashPartitionFunctionSpec;
Expand Down Expand Up @@ -213,14 +209,7 @@ protected Transformation<RowData> translateToPlanInternal(
io.github.zhztheplayer.velox4j.type.RowType outputType =
(io.github.zhztheplayer.velox4j.type.RowType)
LogicalTypeConverter.toVLType(getOutputType());
List<FieldAccessTypedExpr> groupingKeys = Utils.generateFieldAccesses(inputType, grouping);
List<Aggregate> aggregates = AggregateCallConverter.toAggregates(aggCalls, inputType);
checkArgument(outputType.getNames().size() >= grouping.length + aggCalls.length);
List<String> aggNames =
outputType.getNames().stream()
.skip(grouping.length)
.limit(aggCalls.length)
.collect(Collectors.toList());
List<Integer> keyIndexes = Arrays.stream(grouping).boxed().collect(Collectors.toList());
PartitionFunctionSpec keySelectorSpec = new HashPartitionFunctionSpec(inputType, keyIndexes);
// TODO: support more window types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.logical.WindowingStrategy;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
Expand All @@ -53,21 +51,15 @@
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeMetadata;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.planner.utils.TableConfigUtils;
import org.apache.flink.table.runtime.generated.GeneratedNamespaceAggsHandleFunction;
import org.apache.flink.table.runtime.operators.window.tvf.slicing.SliceAssigner;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.runtime.util.TimeWindowUtil;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;

import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;

import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.tools.RelBuilder;
import org.apache.commons.math3.util.ArithmeticUtils;

import javax.annotation.Nullable;
Expand Down Expand Up @@ -181,7 +173,7 @@ protected Transformation<RowData> translateToPlanInternal(
.limit(aggCalls.length)
.collect(Collectors.toList());
List<Integer> keyIndexes = Arrays.stream(grouping).boxed().collect(Collectors.toList());
PartitionFunctionSpec keySelectorSpec = new HashPartitionFunctionSpec(inputType, keyIndexes);
PartitionFunctionSpec keySelectorSpec = new HashPartitionFunctionSpec(outputType, keyIndexes);
// TODO: support more window types.
Tuple5<Long, Long, Long, Integer, Integer> windowSpecParams =
WindowUtils.extractWindowParameters(windowing);
Expand All @@ -196,7 +188,7 @@ protected Transformation<RowData> translateToPlanInternal(
PlanNode aggregation =
new AggregationNode(
PlanNodeIdGenerator.newId(),
AggregateStep.SINGLE,
AggregateStep.PARTIAL,
groupingKeys,
groupingKeys,
aggNames,
Expand Down Expand Up @@ -245,35 +237,4 @@ protected Transformation<RowData> translateToPlanInternal(
WINDOW_AGG_MEMORY_RATIO / 2,
false);
}

private GeneratedNamespaceAggsHandleFunction<Long> createAggsHandler(
SliceAssigner sliceAssigner,
AggregateInfoList aggInfoList,
ExecNodeConfig config,
ClassLoader classLoader,
RelBuilder relBuilder,
List<LogicalType> fieldTypes,
ZoneId shiftTimeZone) {
final AggsHandlerCodeGenerator generator =
new AggsHandlerCodeGenerator(
new CodeGeneratorContext(config, classLoader),
relBuilder,
JavaScalaConversionUtil.toScala(fieldTypes),
true) // copyInputField
.needAccumulate()
.needMerge(0, true, null);

if (needRetraction) {
generator.needRetract();
}

return generator.generateNamespaceAggsHandler(
"LocalWindowAggsHandler",
aggInfoList,
JavaScalaConversionUtil.toScala(Collections.emptyList()),
sliceAssigner,
// we use window end timestamp to indicate a slicing window, see SliceAssigner
Long.class,
shiftTimeZone);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ protected Transformation<RowData> translateToPlanInternal(
"StreamExecWindowAggregate",
selector.getProducedType(),
aggInfoList.getAggNames(),
accTypes);
accTypes,
windowing.isRowtime());
// --- End Gluten-specific code changes ---

final OneInputTransformation<RowData, RowData> transform =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,33 @@ public static List<Aggregate> toAggregates(
return aggregates;
}

public static List<Aggregate> toIntermediateAggregates(
AggregateCall[] aggregateCalls, io.github.zhztheplayer.velox4j.type.RowType inputType) {
List<Aggregate> aggregates = new ArrayList<>();
List<TypedExpr> typeExprs = new ArrayList<>();
for (int i = 0; i < inputType.getNames().size(); i++) {
typeExprs.add(
FieldAccessTypedExpr.create(inputType.getChildren().get(i), inputType.getNames().get(i)));
}
for (int i = 0; i < aggregateCalls.length; i++) {
AggregateCall aggregateCall = aggregateCalls[i];
CallTypedExpr call =
convertAggregation(
aggregateCall.getAggregation().getName(),
typeExprs,
RexNodeConverter.toType(aggregateCall.getType()));
aggregates.add(
new Aggregate(
call,
inputType.getChildren(),
null,
List.of(),
List.of(),
aggregateCall.isDistinct()));
}
return aggregates;
}

public static WindowFunction toFunction(
AggregateCall aggregateCall, io.github.zhztheplayer.velox4j.type.RowType inputType) {
CallTypedExpr call = toCall(aggregateCall, inputType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@
public class GlutenKeyGroupStreamPartitioner extends StreamPartitioner<StatefulRecord>
implements ConfigurableStreamPartitioner {
private static final long serialVersionUID = 1L;

private final KeySelector<StatefulRecord, Integer> keySelector;

private int maxParallelism;
private int parallelism;

public GlutenKeyGroupStreamPartitioner(
KeySelector<StatefulRecord, Integer> keySelector, int maxParallelism) {
KeySelector<StatefulRecord, Integer> keySelector, int maxParallelism, int parallelism) {
Preconditions.checkArgument(maxParallelism > 0, "Number of key-groups must be > 0!");
this.keySelector = Preconditions.checkNotNull(keySelector);
this.maxParallelism = maxParallelism;
this.parallelism = parallelism;
}

public int getMaxParallelism() {
Expand All @@ -58,8 +59,10 @@ public int getMaxParallelism() {
@Override
public int selectChannel(SerializationDelegate<StreamRecord<StatefulRecord>> record) {
try {
int channel = keySelector.getKey(record.getInstance().getValue());
return channel;
int key = keySelector.getKey(record.getInstance().getValue());
int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
return KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
maxParallelism, parallelism, keyGroup);
} catch (Exception e) {
throw new RuntimeException(
"Could not extract key from " + record.getInstance().getValue(), e);
Expand Down
Loading
Loading