1818
1919package org .apache .flink .table .planner .functions .sql .ml ;
2020
21+ import org .apache .flink .configuration .Configuration ;
2122import org .apache .flink .table .api .ValidationException ;
23+ import org .apache .flink .table .api .config .VectorSearchRuntimeConfigOptions ;
2224import org .apache .flink .table .planner .functions .utils .SqlValidatorUtils ;
2325import org .apache .flink .table .types .logical .ArrayType ;
2426import org .apache .flink .table .types .logical .LogicalType ;
2527import org .apache .flink .table .types .logical .LogicalTypeRoot ;
2628import org .apache .flink .table .types .logical .utils .LogicalTypeCasts ;
29+ import org .apache .flink .types .Either ;
2730
2831import org .apache .calcite .rel .type .RelDataType ;
2932import org .apache .calcite .rel .type .RelDataTypeFactory ;
3033import org .apache .calcite .rel .type .RelDataTypeFieldImpl ;
3134import org .apache .calcite .sql .SqlCall ;
3235import org .apache .calcite .sql .SqlCallBinding ;
36+ import org .apache .calcite .sql .SqlCharStringLiteral ;
37+ import org .apache .calcite .sql .SqlDataTypeSpec ;
3338import org .apache .calcite .sql .SqlFunction ;
3439import org .apache .calcite .sql .SqlFunctionCategory ;
3540import org .apache .calcite .sql .SqlIdentifier ;
3944import org .apache .calcite .sql .SqlOperator ;
4045import org .apache .calcite .sql .SqlOperatorBinding ;
4146import org .apache .calcite .sql .SqlTableFunction ;
47+ import org .apache .calcite .sql .type .MapSqlType ;
4248import org .apache .calcite .sql .type .ReturnTypes ;
4349import org .apache .calcite .sql .type .SqlOperandCountRanges ;
4450import org .apache .calcite .sql .type .SqlOperandMetadata ;
4551import org .apache .calcite .sql .type .SqlReturnTypeInference ;
4652import org .apache .calcite .sql .type .SqlTypeName ;
4753import org .apache .calcite .sql .validate .SqlNameMatcher ;
54+ import org .apache .calcite .sql .validate .SqlValidator ;
55+ import org .apache .calcite .util .NlsString ;
4856import org .apache .calcite .util .Util ;
4957import org .checkerframework .checker .nullness .qual .Nullable ;
5058
5159import java .util .Arrays ;
5260import java .util .Collections ;
61+ import java .util .HashMap ;
5362import java .util .List ;
63+ import java .util .Map ;
5464import java .util .Optional ;
65+ import java .util .function .Function ;
5566
5667import static org .apache .flink .table .planner .calcite .FlinkTypeFactory .toLogicalType ;
68+ import static org .apache .flink .table .planner .functions .utils .SqlValidatorUtils .reduceLiteralToString ;
69+ import static org .apache .flink .table .types .logical .LogicalTypeFamily .CHARACTER_STRING ;
5770
5871/**
5972 * {@link SqlVectorSearchTableFunction} implements an operator for search.
6578 * <li>a descriptor to provide a column name from the input table
6679 * <li>a query column from the left table
6780 * <li>a literal value for top k
81+ * <li>an optional config map
6882 * </ol>
6983 */
7084public class SqlVectorSearchTableFunction extends SqlFunction implements SqlTableFunction {
@@ -73,6 +87,7 @@ public class SqlVectorSearchTableFunction extends SqlFunction implements SqlTabl
7387 private static final String PARAM_COLUMN_TO_SEARCH = "COLUMN_TO_SEARCH" ;
7488 private static final String PARAM_COLUMN_TO_QUERY = "COLUMN_TO_QUERY" ;
7589 private static final String PARAM_TOP_K = "TOP_K" ;
90+ private static final String PARAM_CONFIG = "CONFIG" ;
7691
7792 private static final String OUTPUT_SCORE = "score" ;
7893
@@ -92,7 +107,10 @@ public SqlReturnTypeInference getRowTypeInference() {
92107 @ Override
93108 public @ Nullable RelDataType inferReturnType (SqlOperatorBinding opBinding ) {
94109 final RelDataTypeFactory typeFactory = opBinding .getTypeFactory ();
95- final RelDataType inputRowType = opBinding .getOperandType (0 );
110+ SqlCallBinding callBinding = (SqlCallBinding ) opBinding ;
111+ List <SqlNode > operands = callBinding .operands ();
112+ final RelDataType inputRowType =
113+ callBinding .getValidator ().getValidatedNodeType (operands .get (0 ));
96114
97115 return typeFactory
98116 .builder ()
@@ -125,7 +143,10 @@ private static class OperandMetadataImpl implements SqlOperandMetadata {
125143 PARAM_SEARCH_TABLE ,
126144 PARAM_COLUMN_TO_SEARCH ,
127145 PARAM_COLUMN_TO_QUERY ,
128- PARAM_TOP_K ));
146+ PARAM_TOP_K ,
147+ PARAM_CONFIG ));
148+
149+ private static final int OPTIONAL_ARG_IDX = 4 ;
129150
130151 @ Override
131152 public List <RelDataType > paramTypes (RelDataTypeFactory relDataTypeFactory ) {
@@ -217,18 +238,23 @@ public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFail
217238 topK ))),
218239 throwOnFailure );
219240 }
220- return true ;
241+
242+ // check config type
243+ return SqlValidatorUtils .throwExceptionOrReturnFalse (
244+ checkOptionalConfigOperands (
245+ callBinding , 4 , SqlValidatorUtils ::checkConfigValue ),
246+ throwOnFailure );
221247 }
222248
223249 @ Override
224250 public SqlOperandCountRange getOperandCountRange () {
225- return SqlOperandCountRanges .between (4 , 4 );
251+ return SqlOperandCountRanges .between (4 , 5 );
226252 }
227253
228254 @ Override
229255 public String getAllowedSignatures (SqlOperator op , String opName ) {
230256 return opName
231- + "(TABLE search_table, DESCRIPTOR(column_to_search), column_to_query, top_k)" ;
257+ + "(TABLE search_table, DESCRIPTOR(column_to_search), column_to_query, top_k, [MAP['key1', 'value1']...] )" ;
232258 }
233259
234260 @ Override
@@ -238,12 +264,64 @@ public Consistency getConsistency() {
238264
239265 @ Override
240266 public boolean isOptional (int i ) {
241- return false ;
267+ return i == OPTIONAL_ARG_IDX ;
242268 }
269+ }
243270
244- @ Override
245- public boolean isFixedParameters () {
246- return true ;
271+ /**
272+ * Check optional config parameter. Config parameter is a map that define some parameters and
273+ * values.
274+ *
275+ * @param callBinding The call binding
276+ * @param configLocation The location of the config parameter
277+ * @param checkConfigValue Check value in the config map.
278+ */
279+ public static Optional <RuntimeException > checkOptionalConfigOperands (
280+ SqlCallBinding callBinding ,
281+ int configLocation ,
282+ Function <Map <String , String >, Optional <RuntimeException >> checkConfigValue ) {
283+ if (callBinding .getOperandCount () <= configLocation ) {
284+ return Optional .empty ();
285+ }
286+
287+ SqlNode configNode = callBinding .operand (configLocation );
288+ if (!configNode .getKind ().equals (SqlKind .MAP_VALUE_CONSTRUCTOR )) {
289+ return Optional .of (new ValidationException ("Config param should be a MAP." ));
290+ }
291+
292+ RelDataType mapType =
293+ callBinding
294+ .getValidator ()
295+ .getValidatedNodeType (callBinding .operand (configLocation ));
296+
297+ assert mapType instanceof MapSqlType ;
298+
299+ LogicalType keyType = toLogicalType (mapType .getKeyType ());
300+ LogicalType valueType = toLogicalType (mapType .getValueType ());
301+ if (!keyType .is (CHARACTER_STRING ) || !valueType .is (CHARACTER_STRING )) {
302+ return Optional .of (
303+ new ValidationException (
304+ String .format (
305+ "Config param can only be a MAP of string literals but node's type is %s at position %s." ,
306+ mapType , callBinding .operand (3 ).getParserPosition ())));
307+ }
308+
309+ List <SqlNode > operands = ((SqlCall ) configNode ).getOperandList ();
310+ Map <String , String > runtimeConfig = new HashMap <>();
311+ for (int i = 0 ; i < operands .size (); i += 2 ) {
312+ Either <String , RuntimeException > key =
313+ reduceLiteralToString (operands .get (i ), callBinding .getValidator ());
314+ Either <String , RuntimeException > value =
315+ reduceLiteralToString (operands .get (i + 1 ), callBinding .getValidator ());
316+
317+ if (key .isRight ()) {
318+ return Optional .of (key .right ());
319+ } else if (value .isRight ()) {
320+ return Optional .of (value .right ());
321+ } else {
322+ runtimeConfig .put (key .left (), value .left ());
323+ }
247324 }
325+ return checkConfigValue .apply (runtimeConfig );
248326 }
249327}
0 commit comments