Skip to content

Commit f14fcd3

Browse files
authored
[FLINK-38104][table] Add table api support for ML_PREDICT (#27108)
1 parent af86700 commit f14fcd3

File tree

35 files changed

+1340
-34
lines changed

35 files changed

+1340
-34
lines changed

flink-python/pyflink/table/tests/test_environment_completeness.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
# limitations under the License.
1717
################################################################################
1818

19-
from pyflink.testing.test_case_utils import PythonAPICompletenessTestCase, PyFlinkTestCase
2019
from pyflink.table import TableEnvironment
20+
from pyflink.testing.test_case_utils import PythonAPICompletenessTestCase, \
21+
PyFlinkTestCase
2122

2223

2324
class EnvironmentAPICompletenessTests(PythonAPICompletenessTestCase, PyFlinkTestCase):
@@ -40,6 +41,7 @@ def excluded_methods(cls):
4041
'getCompletionHints',
4142
'fromValues',
4243
'fromCall',
44+
'fromModel',
4345
# See FLINK-25986
4446
'loadPlan',
4547
'compilePlanSql',

flink-python/pyflink/table/tests/test_table_environment_completeness.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def excluded_methods(cls):
4444
"from",
4545
"registerFunction",
4646
"fromCall",
47+
"fromModel",
4748
}
4849

4950

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.api;
20+
21+
import org.apache.flink.annotation.PublicEvolving;
22+
import org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions;
23+
import org.apache.flink.table.catalog.ResolvedSchema;
24+
import org.apache.flink.types.ColumnList;
25+
26+
import java.util.Map;
27+
28+
/**
29+
* The {@link Model} object is the core abstraction for ML model resources in the Table API.
30+
*
31+
* <p>A {@link Model} object describes a machine learning model resource that can be used for
32+
* inference operations. It provides methods to perform prediction on data tables.
33+
*
34+
* <p>The {@link Model} interface offers main operations:
35+
*
36+
* <ul>
37+
* <li>{@link #predict(Table, ColumnList)} - Applies the model to make predictions on input data
38+
* </ul>
39+
*
40+
* <p>{@code ml_predict} operation supports runtime options for configuring execution parameters
41+
* such as asynchronous execution mode.
42+
*
43+
* <p>Every {@link Model} object has input and output schemas that describe the expected data
44+
* structure for model operations, available through {@link #getResolvedInputSchema()} and {@link
45+
* #getResolvedOutputSchema()}.
46+
*
47+
* <p>Example usage:
48+
*
49+
* <pre>{@code
50+
* Model model = tableEnv.fromModel("my_model");
51+
*
52+
* // Simple prediction
53+
* Table predictions = model.predict(inputTable, ColumnList.of("feature1", "feature2"));
54+
*
55+
* // Prediction with options
56+
* Map<String, String> options = Map.of("max-concurrent-operations", "100", "timeout", "30s", "async", "true");
57+
* Table predictions = model.predict(inputTable, ColumnList.of("feature1", "feature2"), options);
58+
* }</pre>
59+
*/
60+
@PublicEvolving
61+
public interface Model {
62+
63+
/**
64+
* Returns the resolved input schema of this model.
65+
*
66+
* <p>The input schema describes the structure and data types of the input columns that the
67+
* model expects for inference operations.
68+
*
69+
* @return the resolved input schema.
70+
*/
71+
ResolvedSchema getResolvedInputSchema();
72+
73+
/**
74+
* Returns the resolved output schema of this model.
75+
*
76+
* <p>The output schema describes the structure and data types of the output columns that the
77+
* model produces during inference operations.
78+
*
79+
* @return the resolved output schema.
80+
*/
81+
ResolvedSchema getResolvedOutputSchema();
82+
83+
/**
84+
* Performs prediction on the given table using specified input columns.
85+
*
86+
* <p>This method applies the model to the input data to generate predictions. The input columns
87+
* must match the model's expected input schema.
88+
*
89+
* <p>Example:
90+
*
91+
* <pre>{@code
92+
* Table predictions = model.predict(inputTable, ColumnList.of("feature1", "feature2"));
93+
* }</pre>
94+
*
95+
* @param table the input table containing data for prediction
96+
* @param inputColumns the columns from the input table to use as model input
97+
* @return a table containing the input data along with prediction results
98+
*/
99+
Table predict(Table table, ColumnList inputColumns);
100+
101+
/**
102+
* Performs prediction on the given table using specified input columns with runtime options.
103+
*
104+
* <p>This method applies the model to the input data to generate predictions with additional
105+
* runtime configuration options such as max-concurrent-operations, timeout, and execution mode
106+
* settings.
107+
*
108+
* <p>For Common runtime options, see {@link MLPredictRuntimeConfigOptions}.
109+
*
110+
* <p>Example:
111+
*
112+
* <pre>{@code
113+
* Map<String, String> options = Map.of("max-concurrent-operations", "100", "timeout", "30s", "async", "true");
114+
* Table predictions = model.predict(inputTable,
115+
* ColumnList.of("feature1", "feature2"), options);
116+
* }</pre>
117+
*
118+
* @param table the input table containing data for prediction
119+
* @param inputColumns the columns from the input table to use as model input
120+
* @param options runtime options for configuring the prediction operation
121+
* @return a table containing the input data along with prediction results
122+
*/
123+
Table predict(Table table, ColumnList inputColumns, Map<String, String> options);
124+
125+
/**
126+
* Converts this model object into a named argument.
127+
*
128+
* <p>This method is intended for use in function calls that accept model arguments,
129+
* particularly in process table functions (PTFs) or other operations that work with models.
130+
*
131+
* <p>Example:
132+
*
133+
* <pre>{@code
134+
* env.fromCall(
135+
* "ML_PREDICT",
136+
* inputTable.asArgument("INPUT"),
137+
* model.asArgument("MODEL"),
138+
* Expressions.descriptor(ColumnList.of("feature1", "feature2")).asArgument("ARGS")
139+
* )
140+
* }</pre>
141+
*
142+
* @param name the name to assign to this model argument
143+
* @return an expression that can be passed to functions expecting model arguments
144+
*/
145+
ApiExpression asArgument(String name);
146+
}

flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/TableEnvironment.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,6 +1175,49 @@ void createTemporarySystemFunction(
11751175
*/
11761176
Table fromCall(Class<? extends UserDefinedFunction> function, Object... arguments);
11771177

1178+
/**
1179+
* Returns a {@link Model} object that is backed by the specified model path.
1180+
*
1181+
* <p>This method creates a {@link Model} object from a given model path in the catalog. The
1182+
* model path can be fully or partially qualified (e.g., "catalog.db.model" or just "model"),
1183+
* depending on the current catalog and database context.
1184+
*
1185+
* <p>The returned {@link Model} object can be used for further transformations or as input to
1186+
* other operations in the Table API.
1187+
*
1188+
* <p>Example:
1189+
*
1190+
* <pre>{@code
1191+
* Model model = tableEnv.fromModel("my_model");
1192+
* }</pre>
1193+
*
1194+
* @param modelPath The path of the model in the catalog.
1195+
* @return The {@link Model} object describing the model resource.
1196+
*/
1197+
Model fromModel(String modelPath);
1198+
1199+
/**
1200+
* Returns a {@link Model} object that is backed by the specified {@link ModelDescriptor}.
1201+
*
1202+
* <p>This method creates a {@link Model} object using the provided {@link ModelDescriptor},
1203+
* which contains the necessary information to identify and configure the model resource in the
1204+
* catalog.
1205+
*
1206+
* <p>The returned {@link Model} object can be used for further transformations or as input to
1207+
* other operations in the Table API.
1208+
*
1209+
* <p>Example:
1210+
*
1211+
* <pre>{@code
1212+
* ModelDescriptor descriptor = ...;
1213+
* Model model = tableEnv.fromModel(descriptor);
1214+
* }</pre>
1215+
*
1216+
* @param descriptor The {@link ModelDescriptor} describing the model resource.
1217+
* @return The {@link Model} object representing the model resource.
1218+
*/
1219+
Model fromModel(ModelDescriptor descriptor);
1220+
11781221
/**
11791222
* Gets the names of all catalogs registered in this environment.
11801223
*
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.api.internal;
20+
21+
import org.apache.flink.annotation.Internal;
22+
import org.apache.flink.table.api.ApiExpression;
23+
import org.apache.flink.table.api.Expressions;
24+
import org.apache.flink.table.api.Model;
25+
import org.apache.flink.table.api.Table;
26+
import org.apache.flink.table.api.TableEnvironment;
27+
import org.apache.flink.table.catalog.ContextResolvedModel;
28+
import org.apache.flink.table.catalog.ResolvedSchema;
29+
import org.apache.flink.table.expressions.ApiExpressionUtils;
30+
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
31+
import org.apache.flink.types.ColumnList;
32+
33+
import java.util.ArrayList;
34+
import java.util.Map;
35+
36+
import static org.apache.flink.table.api.Expressions.lit;
37+
import static org.apache.flink.table.expressions.ApiExpressionUtils.valueLiteral;
38+
39+
/** Implementation of {@link Model} that works with the Table API. */
40+
@Internal
41+
public class ModelImpl implements Model {
42+
43+
private final TableEnvironmentInternal tableEnvironment;
44+
private final ContextResolvedModel model;
45+
46+
private ModelImpl(TableEnvironmentInternal tableEnvironment, ContextResolvedModel model) {
47+
this.tableEnvironment = tableEnvironment;
48+
this.model = model;
49+
}
50+
51+
public static ModelImpl createModel(
52+
TableEnvironmentInternal tableEnvironment, ContextResolvedModel model) {
53+
return new ModelImpl(tableEnvironment, model);
54+
}
55+
56+
public ContextResolvedModel getModel() {
57+
return model;
58+
}
59+
60+
@Override
61+
public ResolvedSchema getResolvedInputSchema() {
62+
return model.getResolvedModel().getResolvedInputSchema();
63+
}
64+
65+
@Override
66+
public ResolvedSchema getResolvedOutputSchema() {
67+
return model.getResolvedModel().getResolvedOutputSchema();
68+
}
69+
70+
public TableEnvironment getTableEnv() {
71+
return tableEnvironment;
72+
}
73+
74+
@Override
75+
public Table predict(Table table, ColumnList inputColumns) {
76+
return predict(table, inputColumns, Map.of());
77+
}
78+
79+
@Override
80+
public Table predict(Table table, ColumnList inputColumns, Map<String, String> options) {
81+
// Use Expressions.map() instead of Expressions.lit() to create a MAP literal since
82+
// lit() is not serializable to sql.
83+
if (options.isEmpty()) {
84+
return tableEnvironment.fromCall(
85+
BuiltInFunctionDefinitions.ML_PREDICT.getName(),
86+
table.asArgument("INPUT"),
87+
this.asArgument("MODEL"),
88+
new ApiExpression(valueLiteral(inputColumns)).asArgument("ARGS"));
89+
}
90+
ArrayList<String> configKVs = new ArrayList<>();
91+
options.forEach(
92+
(k, v) -> {
93+
configKVs.add(k);
94+
configKVs.add(v);
95+
});
96+
return tableEnvironment.fromCall(
97+
BuiltInFunctionDefinitions.ML_PREDICT.getName(),
98+
table.asArgument("INPUT"),
99+
this.asArgument("MODEL"),
100+
new ApiExpression(valueLiteral(inputColumns)).asArgument("ARGS"),
101+
Expressions.map(
102+
configKVs.get(0),
103+
configKVs.get(1),
104+
configKVs.subList(2, configKVs.size()).toArray())
105+
.asArgument("CONFIG"));
106+
}
107+
108+
@Override
109+
public ApiExpression asArgument(String name) {
110+
return new ApiExpression(
111+
ApiExpressionUtils.unresolvedCall(
112+
BuiltInFunctionDefinitions.ASSIGNMENT,
113+
lit(name),
114+
ApiExpressionUtils.modelRef(name, this)));
115+
}
116+
117+
public TableEnvironment getTableEnvironment() {
118+
return tableEnvironment;
119+
}
120+
}

0 commit comments

Comments
 (0)