Skip to content

Commit 46948fc

Browse files
committed
[FLINK-38430][table] Support runtime config for VECTOR_SEARCH
1 parent b3405e9 commit 46948fc

File tree

19 files changed

+741
-113
lines changed

19 files changed

+741
-113
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.api.config;
20+
21+
import org.apache.flink.annotation.PublicEvolving;
22+
import org.apache.flink.configuration.ConfigOption;
23+
24+
import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableSet;
25+
26+
import java.time.Duration;
27+
import java.util.HashSet;
28+
import java.util.Set;
29+
30+
import static org.apache.flink.configuration.ConfigOptions.key;
31+
32+
/**
33+
* This class holds option name definitions for VECTOR_SEARCH runtime config based on {@link
34+
* ConfigOption}.
35+
*/
36+
@PublicEvolving
37+
public class VectorSearchRuntimeConfigOptions {
38+
39+
public static final ConfigOption<Boolean> ASYNC =
40+
key("async")
41+
.booleanType()
42+
.noDefaultValue()
43+
.withDescription(
44+
"Value can be 'true' or 'false' to suggest the planner choose the corresponding"
45+
+ " predict function. If the backend search function provider does not support the"
46+
+ " suggested mode, it will throw exception to notify users.");
47+
48+
public static final ConfigOption<ExecutionConfigOptions.AsyncOutputMode> ASYNC_OUTPUT_MODE =
49+
key("output-mode")
50+
.enumType(ExecutionConfigOptions.AsyncOutputMode.class)
51+
.noDefaultValue()
52+
.withDescription(
53+
"Output mode for asynchronous operations which will convert to {@see AsyncDataStream.OutputMode}, ORDERED by default. "
54+
+ "If set to ALLOW_UNORDERED, will attempt to use {@see AsyncDataStream.OutputMode.UNORDERED} when it does not "
55+
+ "affect the correctness of the result, otherwise ORDERED will be still used.");
56+
57+
public static final ConfigOption<Integer> ASYNC_MAX_CONCURRENT_OPERATIONS =
58+
key("max-concurrent-operations")
59+
.intType()
60+
.noDefaultValue()
61+
.withDescription(
62+
"The max number of async i/o operation that the async ml predict can trigger.");
63+
64+
public static final ConfigOption<Duration> ASYNC_TIMEOUT =
65+
key("timeout")
66+
.durationType()
67+
.noDefaultValue()
68+
.withDescription(
69+
"Timeout from first invoke to final completion of asynchronous operation, may include multiple"
70+
+ " retries, and will be reset in case of failover.");
71+
72+
private static final Set<ConfigOption<?>> supportedKeys = new HashSet<>();
73+
74+
static {
75+
supportedKeys.add(ASYNC);
76+
supportedKeys.add(ASYNC_OUTPUT_MODE);
77+
supportedKeys.add(ASYNC_MAX_CONCURRENT_OPERATIONS);
78+
supportedKeys.add(ASYNC_TIMEOUT);
79+
}
80+
81+
public static ImmutableSet<ConfigOption> getSupportedOptions() {
82+
return ImmutableSet.copyOf(supportedKeys);
83+
}
84+
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlMLTableFunction.java

Lines changed: 5 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.apache.flink.configuration.Configuration;
2222
import org.apache.flink.table.api.ValidationException;
2323
import org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions;
24+
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
2425
import org.apache.flink.table.types.logical.LogicalType;
2526
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
2627
import org.apache.flink.types.Either;
@@ -56,6 +57,7 @@
5657
import static org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions.ASYNC;
5758
import static org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions.ASYNC_MAX_CONCURRENT_OPERATIONS;
5859
import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
60+
import static org.apache.flink.table.planner.functions.utils.SqlValidatorUtils.reduceLiteralToString;
5961
import static org.apache.flink.table.types.logical.LogicalTypeFamily.CHARACTER_STRING;
6062

6163
/**
@@ -219,9 +221,9 @@ protected static Optional<RuntimeException> checkConfig(
219221
Map<String, String> runtimeConfig = new HashMap<>();
220222
for (int i = 0; i < operands.size(); i += 2) {
221223
Either<String, RuntimeException> key =
222-
reduceLiteral(operands.get(i), callBinding.getValidator());
224+
reduceLiteralToString(operands.get(i), callBinding.getValidator());
223225
Either<String, RuntimeException> value =
224-
reduceLiteral(operands.get(i + 1), callBinding.getValidator());
226+
reduceLiteralToString(operands.get(i + 1), callBinding.getValidator());
225227

226228
if (key.isRight()) {
227229
return Optional.of(key.right());
@@ -232,56 +234,6 @@ protected static Optional<RuntimeException> checkConfig(
232234
}
233235
}
234236

235-
return checkConfigValue(runtimeConfig);
236-
}
237-
238-
private static Optional<RuntimeException> checkConfigValue(Map<String, String> runtimeConfig) {
239-
Configuration config = Configuration.fromMap(runtimeConfig);
240-
try {
241-
MLPredictRuntimeConfigOptions.getSupportedOptions().forEach(config::get);
242-
} catch (Throwable t) {
243-
return Optional.of(new ValidationException("Failed to parse the config.", t));
244-
}
245-
246-
// option value check
247-
// async options are all optional
248-
Boolean async = config.get(ASYNC);
249-
if (Boolean.TRUE.equals(async)) {
250-
Integer maxConcurrentOperations = config.get(ASYNC_MAX_CONCURRENT_OPERATIONS);
251-
if (maxConcurrentOperations != null && maxConcurrentOperations <= 0) {
252-
return Optional.of(
253-
new ValidationException(
254-
String.format(
255-
"Invalid runtime config option '%s'. Its value should be positive integer but was %s.",
256-
ASYNC_MAX_CONCURRENT_OPERATIONS.key(),
257-
maxConcurrentOperations)));
258-
}
259-
}
260-
261-
return Optional.empty();
262-
}
263-
264-
private static Either<String, RuntimeException> reduceLiteral(
265-
SqlNode operand, SqlValidator validator) {
266-
if (operand instanceof SqlCharStringLiteral) {
267-
return Either.Left(
268-
((SqlCharStringLiteral) operand).getValueAs(NlsString.class).getValue());
269-
} else if (operand.getKind() == SqlKind.CAST) {
270-
// CAST(CAST('v' AS STRING) AS STRING)
271-
SqlCall call = (SqlCall) operand;
272-
SqlDataTypeSpec dataType = call.operand(1);
273-
if (!toLogicalType(dataType.deriveType(validator)).is(CHARACTER_STRING)) {
274-
return Either.Right(
275-
new ValidationException("Don't support to cast value to non-string type."));
276-
}
277-
return reduceLiteral((call.operand(0)), validator);
278-
} else {
279-
return Either.Right(
280-
new ValidationException(
281-
String.format(
282-
"Unsupported expression %s is in runtime config at position %s. Currently, "
283-
+ "runtime config should be be a MAP of string literals.",
284-
operand, operand.getParserPosition())));
285-
}
237+
return SqlValidatorUtils.checkConfigValue(runtimeConfig);
286238
}
287239
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,23 @@
1818

1919
package org.apache.flink.table.planner.functions.sql.ml;
2020

21+
import org.apache.flink.configuration.Configuration;
2122
import org.apache.flink.table.api.ValidationException;
23+
import org.apache.flink.table.api.config.VectorSearchRuntimeConfigOptions;
2224
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;
2325
import org.apache.flink.table.types.logical.ArrayType;
2426
import org.apache.flink.table.types.logical.LogicalType;
2527
import org.apache.flink.table.types.logical.LogicalTypeRoot;
2628
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
29+
import org.apache.flink.types.Either;
2730

2831
import org.apache.calcite.rel.type.RelDataType;
2932
import org.apache.calcite.rel.type.RelDataTypeFactory;
3033
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
3134
import org.apache.calcite.sql.SqlCall;
3235
import org.apache.calcite.sql.SqlCallBinding;
36+
import org.apache.calcite.sql.SqlCharStringLiteral;
37+
import org.apache.calcite.sql.SqlDataTypeSpec;
3338
import org.apache.calcite.sql.SqlFunction;
3439
import org.apache.calcite.sql.SqlFunctionCategory;
3540
import org.apache.calcite.sql.SqlIdentifier;
@@ -39,21 +44,29 @@
3944
import org.apache.calcite.sql.SqlOperator;
4045
import org.apache.calcite.sql.SqlOperatorBinding;
4146
import org.apache.calcite.sql.SqlTableFunction;
47+
import org.apache.calcite.sql.type.MapSqlType;
4248
import org.apache.calcite.sql.type.ReturnTypes;
4349
import org.apache.calcite.sql.type.SqlOperandCountRanges;
4450
import org.apache.calcite.sql.type.SqlOperandMetadata;
4551
import org.apache.calcite.sql.type.SqlReturnTypeInference;
4652
import org.apache.calcite.sql.type.SqlTypeName;
4753
import org.apache.calcite.sql.validate.SqlNameMatcher;
54+
import org.apache.calcite.sql.validate.SqlValidator;
55+
import org.apache.calcite.util.NlsString;
4856
import org.apache.calcite.util.Util;
4957
import org.checkerframework.checker.nullness.qual.Nullable;
5058

5159
import java.util.Arrays;
5260
import java.util.Collections;
61+
import java.util.HashMap;
5362
import java.util.List;
63+
import java.util.Map;
5464
import java.util.Optional;
65+
import java.util.function.Function;
5566

5667
import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
68+
import static org.apache.flink.table.planner.functions.utils.SqlValidatorUtils.reduceLiteralToString;
69+
import static org.apache.flink.table.types.logical.LogicalTypeFamily.CHARACTER_STRING;
5770

5871
/**
5972
* {@link SqlVectorSearchTableFunction} implements an operator for search.
@@ -65,6 +78,7 @@
6578
* <li>a descriptor to provide a column name from the input table
6679
* <li>a query column from the left table
6780
* <li>a literal value for top k
81+
* <li>an optional config map
6882
* </ol>
6983
*/
7084
public class SqlVectorSearchTableFunction extends SqlFunction implements SqlTableFunction {
@@ -73,6 +87,7 @@ public class SqlVectorSearchTableFunction extends SqlFunction implements SqlTabl
7387
private static final String PARAM_COLUMN_TO_SEARCH = "COLUMN_TO_SEARCH";
7488
private static final String PARAM_COLUMN_TO_QUERY = "COLUMN_TO_QUERY";
7589
private static final String PARAM_TOP_K = "TOP_K";
90+
private static final String PARAM_CONFIG = "CONFIG";
7691

7792
private static final String OUTPUT_SCORE = "score";
7893

@@ -92,7 +107,10 @@ public SqlReturnTypeInference getRowTypeInference() {
92107
@Override
93108
public @Nullable RelDataType inferReturnType(SqlOperatorBinding opBinding) {
94109
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
95-
final RelDataType inputRowType = opBinding.getOperandType(0);
110+
SqlCallBinding callBinding = (SqlCallBinding) opBinding;
111+
List<SqlNode> operands = callBinding.operands();
112+
final RelDataType inputRowType =
113+
callBinding.getValidator().getValidatedNodeType(operands.get(0));
96114

97115
return typeFactory
98116
.builder()
@@ -125,7 +143,10 @@ private static class OperandMetadataImpl implements SqlOperandMetadata {
125143
PARAM_SEARCH_TABLE,
126144
PARAM_COLUMN_TO_SEARCH,
127145
PARAM_COLUMN_TO_QUERY,
128-
PARAM_TOP_K));
146+
PARAM_TOP_K,
147+
PARAM_CONFIG));
148+
149+
private static final int OPTIONAL_ARG_IDX = 4;
129150

130151
@Override
131152
public List<RelDataType> paramTypes(RelDataTypeFactory relDataTypeFactory) {
@@ -217,18 +238,23 @@ public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFail
217238
topK))),
218239
throwOnFailure);
219240
}
220-
return true;
241+
242+
// check config type
243+
return SqlValidatorUtils.throwExceptionOrReturnFalse(
244+
checkOptionalConfigOperands(
245+
callBinding, 4, SqlValidatorUtils::checkConfigValue),
246+
throwOnFailure);
221247
}
222248

223249
@Override
224250
public SqlOperandCountRange getOperandCountRange() {
225-
return SqlOperandCountRanges.between(4, 4);
251+
return SqlOperandCountRanges.between(4, 5);
226252
}
227253

228254
@Override
229255
public String getAllowedSignatures(SqlOperator op, String opName) {
230256
return opName
231-
+ "(TABLE search_table, DESCRIPTOR(column_to_search), column_to_query, top_k)";
257+
+ "(TABLE search_table, DESCRIPTOR(column_to_search), column_to_query, top_k, [MAP['key1', 'value1']...])";
232258
}
233259

234260
@Override
@@ -238,12 +264,64 @@ public Consistency getConsistency() {
238264

239265
@Override
240266
public boolean isOptional(int i) {
241-
return false;
267+
return i == OPTIONAL_ARG_IDX;
242268
}
269+
}
243270

244-
@Override
245-
public boolean isFixedParameters() {
246-
return true;
271+
/**
272+
* Check optional config parameter. Config parameter is a map that define some parameters and
273+
* values.
274+
*
275+
* @param callBinding The call binding
276+
* @param configLocation The location of the config parameter
277+
* @param checkConfigValue Check value in the config map.
278+
*/
279+
public static Optional<RuntimeException> checkOptionalConfigOperands(
280+
SqlCallBinding callBinding,
281+
int configLocation,
282+
Function<Map<String, String>, Optional<RuntimeException>> checkConfigValue) {
283+
if (callBinding.getOperandCount() <= configLocation) {
284+
return Optional.empty();
285+
}
286+
287+
SqlNode configNode = callBinding.operand(configLocation);
288+
if (!configNode.getKind().equals(SqlKind.MAP_VALUE_CONSTRUCTOR)) {
289+
return Optional.of(new ValidationException("Config param should be a MAP."));
290+
}
291+
292+
RelDataType mapType =
293+
callBinding
294+
.getValidator()
295+
.getValidatedNodeType(callBinding.operand(configLocation));
296+
297+
assert mapType instanceof MapSqlType;
298+
299+
LogicalType keyType = toLogicalType(mapType.getKeyType());
300+
LogicalType valueType = toLogicalType(mapType.getValueType());
301+
if (!keyType.is(CHARACTER_STRING) || !valueType.is(CHARACTER_STRING)) {
302+
return Optional.of(
303+
new ValidationException(
304+
String.format(
305+
"Config param can only be a MAP of string literals but node's type is %s at position %s.",
306+
mapType, callBinding.operand(3).getParserPosition())));
307+
}
308+
309+
List<SqlNode> operands = ((SqlCall) configNode).getOperandList();
310+
Map<String, String> runtimeConfig = new HashMap<>();
311+
for (int i = 0; i < operands.size(); i += 2) {
312+
Either<String, RuntimeException> key =
313+
reduceLiteralToString(operands.get(i), callBinding.getValidator());
314+
Either<String, RuntimeException> value =
315+
reduceLiteralToString(operands.get(i + 1), callBinding.getValidator());
316+
317+
if (key.isRight()) {
318+
return Optional.of(key.right());
319+
} else if (value.isRight()) {
320+
return Optional.of(value.right());
321+
} else {
322+
runtimeConfig.put(key.left(), value.left());
323+
}
247324
}
325+
return checkConfigValue.apply(runtimeConfig);
248326
}
249327
}

0 commit comments

Comments
 (0)