diff --git a/flink-python/pyflink/table/tests/test_environment_completeness.py b/flink-python/pyflink/table/tests/test_environment_completeness.py index d1305c7ef6a2b..6a1b00a05671c 100644 --- a/flink-python/pyflink/table/tests/test_environment_completeness.py +++ b/flink-python/pyflink/table/tests/test_environment_completeness.py @@ -16,8 +16,9 @@ # limitations under the License. ################################################################################ -from pyflink.testing.test_case_utils import PythonAPICompletenessTestCase, PyFlinkTestCase from pyflink.table import TableEnvironment +from pyflink.testing.test_case_utils import PythonAPICompletenessTestCase, \ + PyFlinkTestCase class EnvironmentAPICompletenessTests(PythonAPICompletenessTestCase, PyFlinkTestCase): @@ -40,6 +41,7 @@ def excluded_methods(cls): 'getCompletionHints', 'fromValues', 'fromCall', + 'fromModelPath', # See FLINK-25986 'loadPlan', 'compilePlanSql', diff --git a/flink-python/pyflink/table/tests/test_table_environment_completeness.py b/flink-python/pyflink/table/tests/test_table_environment_completeness.py index 48c6369bae6d4..306eadb4e1cee 100644 --- a/flink-python/pyflink/table/tests/test_table_environment_completeness.py +++ b/flink-python/pyflink/table/tests/test_table_environment_completeness.py @@ -44,6 +44,7 @@ def excluded_methods(cls): "from", "registerFunction", "fromCall", + "fromModelPath", } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Expressions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Expressions.java index a4552b25f5e73..260be55a0cf17 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Expressions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Expressions.java @@ -156,6 +156,19 @@ public static ApiExpression descriptor(String... columnNames) { return new ApiExpression(valueLiteral(ColumnList.of(Arrays.asList(columnNames)))); } + /** + * Creates a literal describing an arbitrary, unvalidated list of column names. + * + *

Passing a column list can be useful for parameterizing a function. In particular, it + * enables declaring the {@code on_time} argument for {@link ProcessTableFunction} or the {@code + * inputColumns} for {@link Model#predict}. + * + *
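+ * <p>For example, a model inference call could be parameterized as follows (a sketch; the
+ * table, model, and column names are placeholders):
+ *
+ * <pre>{@code
+ * env.fromCall(
+ *     "ML_PREDICT",
+ *     inputTable.asArgument("INPUT"),
+ *     model.asArgument("MODEL"),
+ *     descriptor(ColumnList.of("feature1", "feature2")).asArgument("ARGS"))
+ * }</pre>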

The data type will be {@link DataTypes#DESCRIPTOR()}. + */ + public static ApiExpression descriptor(ColumnList columnList) { + return new ApiExpression(valueLiteral(columnList)); + } + /** * Indicates a range from 'start' to 'end', which can be used in columns selection. * diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Model.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Model.java new file mode 100644 index 0000000000000..ca38aff4bef91 --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Model.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.ColumnList; + +import java.util.Map; + +/** + * The {@link Model} object is the core abstraction for ML model resources in the Table API. + * + *

A {@link Model} object describes a machine learning model resource that can be used for + * inference operations. It provides methods to perform prediction on data tables. + * + *

The {@link Model} interface offers main operations:
+ *
+ * <ul>
+ *   <li>{@link #predict(Table, ColumnList)} and {@link #predict(Table, ColumnList, Map)} for
+ *       applying the model to an input table
+ *   <li>{@link #asArgument(String)} for passing the model to function calls such as {@code
+ *       ML_PREDICT}
+ * </ul>
+ *
The {@code ml_predict} operation supports runtime options for configuring execution parameters
+ * such as asynchronous execution mode.
+ *
+ *

Every {@link Model} object has input and output schemas that describe the expected data + * structure for model operations, available through {@link #getResolvedInputSchema()} and {@link + * #getResolvedOutputSchema()}. + * + *

Example usage: + * + *

{@code
+ * Model model = tableEnv.fromModelPath("my_model");
+ *
+ * // Simple prediction
+ * Table predictions = model.predict(inputTable, ColumnList.of("feature1", "feature2"));
+ *
+ * // Prediction with options
+ * Map<String, String> options = Map.of("max-concurrent-operations", "100", "timeout", "30s", "async", "true");
+ * Table predictionsWithOptions = model.predict(inputTable, ColumnList.of("feature1", "feature2"), options);
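+ *
+ * // Equivalent call through the ML_PREDICT function directly
+ * // (this is what predict(Table, ColumnList) expands to internally):
+ * Table viaFromCall = tableEnv.fromCall(
+ *     "ML_PREDICT",
+ *     inputTable.asArgument("INPUT"),
+ *     model.asArgument("MODEL"),
+ *     Expressions.descriptor(ColumnList.of("feature1", "feature2")).asArgument("ARGS"));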
+ * }
+ */ +@PublicEvolving +public interface Model { + + /** + * Returns the resolved input schema of this model. + * + *

The input schema describes the structure and data types of the input columns that the + * model expects for inference operations. + * + * @return the resolved input schema. + */ + ResolvedSchema getResolvedInputSchema(); + + /** + * Returns the resolved output schema of this model. + * + *

The output schema describes the structure and data types of the output columns that the + * model produces during inference operations. + * + * @return the resolved output schema. + */ + ResolvedSchema getResolvedOutputSchema(); + + /** + * Performs prediction on the given table using specified input columns. + * + *

This method applies the model to the input data to generate predictions. The input columns + * must match the model's expected input schema. + * + *

Example: + * + *

{@code
+     * Table predictions = model.predict(inputTable, ColumnList.of("feature1", "feature2"));
+     * }
+ * + * @param table the input table containing data for prediction + * @param inputColumns the columns from the input table to use as model input + * @return a table containing the input data along with prediction results + */ + Table predict(Table table, ColumnList inputColumns); + + /** + * Performs prediction on the given table using specified input columns with runtime options. + * + *

This method applies the model to the input data to generate predictions with additional + * runtime configuration options such as max-concurrent-operations, timeout, and execution mode + * settings. + * + *

For common runtime options, see {@link MLPredictRuntimeConfigOptions}.
+     *
+     *

Example: + * + *

{@code
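+     * // Keys correspond to options in MLPredictRuntimeConfigOptions;
+     * // the values here are illustrative.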
+     * Map<String, String> options = Map.of("max-concurrent-operations", "100", "timeout", "30s", "async", "true");
+     * Table predictions = model.predict(inputTable,
+     *     ColumnList.of("feature1", "feature2"), options);
+     * }
+     *
+     * @param table the input table containing data for prediction
+     * @param inputColumns the columns from the input table to use as model input
+     * @param options runtime options for configuring the prediction operation
+     * @return a table containing the input data along with prediction results
+     */
+    Table predict(Table table, ColumnList inputColumns, Map<String, String> options);
+
+    /**
+     * Converts this model object into a named argument.
+     *
+     *

This method is intended for use in function calls that accept model arguments, + * particularly in process table functions (PTFs) or other operations that work with models. + * + *

Example: + * + *

{@code
+     * env.fromCall(
+     *   "ML_PREDICT",
+     *   inputTable.asArgument("INPUT"),
+     *   model.asArgument("MODEL"),
+     *   Expressions.descriptor(ColumnList.of("feature1", "feature2")).asArgument("ARGS")
+     * )
+     * }
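+     *
+     * <p>The argument name must match the name declared by the called function; for {@code
+     * ML_PREDICT}, the model argument is named {@code MODEL}.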
+ * + * @param name the name to assign to this model argument + * @return an expression that can be passed to functions expecting model arguments + */ + ApiExpression asArgument(String name); +} diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/TableEnvironment.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/TableEnvironment.java index dada96ae619f2..3f2a348b993e5 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/TableEnvironment.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/TableEnvironment.java @@ -1175,6 +1175,49 @@ void createTemporarySystemFunction( */ Table fromCall(Class function, Object... arguments); + /** + * Returns a {@link Model} object that is backed by the specified model path. + * + *

This method creates a {@link Model} object from a given model path in the catalog. The + * model path can be fully or partially qualified (e.g., "catalog.db.model" or just "model"), + * depending on the current catalog and database context. + * + *

The returned {@link Model} object can be used for further transformations or as input to + * other operations in the Table API. + * + *

Example: + * + *

{@code
+     * Model model = tableEnv.fromModelPath("my_model");
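+     *
+     * // The returned model can then be applied to a table
+     * // (table and column names are illustrative):
+     * Table predictions = model.predict(inputTable, ColumnList.of("feature1", "feature2"));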
+     * }
+ * + * @param modelPath The path of the model in the catalog. + * @return The {@link Model} object describing the model resource. + */ + Model fromModelPath(String modelPath); + + /** + * Returns a {@link Model} object that is backed by the specified {@link ModelDescriptor}. + * + *

This method creates a {@link Model} object using the provided {@link ModelDescriptor}, + * which contains the necessary information to identify and configure the model resource in the + * catalog. + * + *

The returned {@link Model} object can be used for further transformations or as input to + * other operations in the Table API. + * + *

Example: + * + *

{@code
+     * ModelDescriptor descriptor = ...;
+     * Model model = tableEnv.from(descriptor);
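+     *
+     * // Note: descriptor-backed models are anonymous. They can be used for
+     * // prediction, but references to them cannot be serialized into SQL.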
+     * }
+ * + * @param descriptor The {@link ModelDescriptor} describing the model resource. + * @return The {@link Model} object representing the model resource. + */ + Model from(ModelDescriptor descriptor); + /** * Gets the names of all catalogs registered in this environment. * diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/ModelImpl.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/ModelImpl.java new file mode 100644 index 0000000000000..d6bf666a4f6ae --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/ModelImpl.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.internal; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.ApiExpression; +import org.apache.flink.table.api.Expressions; +import org.apache.flink.table.api.Model; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.catalog.ContextResolvedModel; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.expressions.ApiExpressionUtils; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.types.ColumnList; + +import java.util.ArrayList; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.lit; + +/** Implementation of {@link Model} that works with the Table API. */ +@Internal +public class ModelImpl implements Model { + + private final TableEnvironmentInternal tableEnvironment; + private final ContextResolvedModel model; + + private ModelImpl(TableEnvironmentInternal tableEnvironment, ContextResolvedModel model) { + this.tableEnvironment = tableEnvironment; + this.model = model; + } + + public static ModelImpl createModel( + TableEnvironmentInternal tableEnvironment, ContextResolvedModel model) { + return new ModelImpl(tableEnvironment, model); + } + + public ContextResolvedModel getModel() { + return model; + } + + @Override + public ResolvedSchema getResolvedInputSchema() { + return model.getResolvedModel().getResolvedInputSchema(); + } + + @Override + public ResolvedSchema getResolvedOutputSchema() { + return model.getResolvedModel().getResolvedOutputSchema(); + } + + public TableEnvironment getTableEnv() { + return tableEnvironment; + } + + @Override + public Table predict(Table table, ColumnList inputColumns) { + return predict(table, inputColumns, Map.of()); + } + + @Override + public Table predict(Table table, ColumnList inputColumns, Map options) { + // Use Expressions.map() instead of Expressions.lit() to create a MAP literal since + // lit() is not serializable to sql. 
+ if (options.isEmpty()) { + return tableEnvironment.fromCall( + BuiltInFunctionDefinitions.ML_PREDICT.getName(), + table.asArgument("INPUT"), + this.asArgument("MODEL"), + Expressions.descriptor(inputColumns).asArgument("ARGS")); + } + ArrayList configKVs = new ArrayList<>(); + options.forEach( + (k, v) -> { + configKVs.add(k); + configKVs.add(v); + }); + return tableEnvironment.fromCall( + BuiltInFunctionDefinitions.ML_PREDICT.getName(), + table.asArgument("INPUT"), + this.asArgument("MODEL"), + Expressions.descriptor(inputColumns).asArgument("ARGS"), + Expressions.map( + configKVs.get(0), + configKVs.get(1), + configKVs.subList(2, configKVs.size()).toArray()) + .asArgument("CONFIG")); + } + + @Override + public ApiExpression asArgument(String name) { + return new ApiExpression( + ApiExpressionUtils.unresolvedCall( + BuiltInFunctionDefinitions.ASSIGNMENT, + lit(name), + ApiExpressionUtils.modelRef(name, this))); + } + + public TableEnvironment getTableEnvironment() { + return tableEnvironment; + } +} diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java index 058f7a92be967..de009f7075613 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java @@ -33,6 +33,7 @@ import org.apache.flink.table.api.ExplainFormat; import org.apache.flink.table.api.Expressions; import org.apache.flink.table.api.FunctionDescriptor; +import org.apache.flink.table.api.Model; import org.apache.flink.table.api.ModelDescriptor; import org.apache.flink.table.api.PlanReference; import org.apache.flink.table.api.ResultKind; @@ -55,12 +56,14 @@ import org.apache.flink.table.catalog.CatalogStore; import org.apache.flink.table.catalog.CatalogStoreHolder; import org.apache.flink.table.catalog.Column; +import org.apache.flink.table.catalog.ContextResolvedModel; import org.apache.flink.table.catalog.ContextResolvedTable; import org.apache.flink.table.catalog.FunctionCatalog; import org.apache.flink.table.catalog.FunctionLanguage; import org.apache.flink.table.catalog.GenericInMemoryCatalog; import org.apache.flink.table.catalog.ObjectIdentifier; import org.apache.flink.table.catalog.QueryOperationCatalogView; +import org.apache.flink.table.catalog.ResolvedCatalogModel; import org.apache.flink.table.catalog.ResolvedCatalogTable; import org.apache.flink.table.catalog.ResolvedSchema; import org.apache.flink.table.catalog.StagedTable; @@ -77,6 +80,7 @@ import org.apache.flink.table.expressions.ApiExpressionUtils; import org.apache.flink.table.expressions.DefaultSqlFactory; import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.ModelReferenceExpression; import org.apache.flink.table.expressions.TableReferenceExpression; import org.apache.flink.table.expressions.utils.ApiExpressionDefaultVisitor; import org.apache.flink.table.factories.CatalogStoreFactory; @@ -682,6 +686,29 @@ public Table fromCall(Class function, Object... 
a operationTreeBuilder.tableFunction(Expressions.call(function, arguments))); } + @Override + public Model fromModelPath(String modelPath) { + UnresolvedIdentifier unresolvedIdentifier = getParser().parseIdentifier(modelPath); + ObjectIdentifier modelIdentifier = catalogManager.qualifyIdentifier(unresolvedIdentifier); + return catalogManager + .getModel(modelIdentifier) + .map(this::createModel) + .orElseThrow( + () -> + new ValidationException( + String.format( + "Model %s was not found.", unresolvedIdentifier))); + } + + @Override + public Model from(ModelDescriptor descriptor) { + Preconditions.checkNotNull(descriptor, "Model descriptor must not be null."); + + final ResolvedCatalogModel resolvedCatalogModel = + catalogManager.resolveCatalogModel(descriptor.toCatalogModel()); + return createModel(ContextResolvedModel.anonymous(resolvedCatalogModel)); + } + private Optional scanInternal(UnresolvedIdentifier identifier) { ObjectIdentifier tableIdentifier = catalogManager.qualifyIdentifier(identifier); @@ -1487,6 +1514,10 @@ public TableImpl createTable(QueryOperation tableOperation) { functionCatalog.asLookup(getParser()::parseIdentifier)); } + public ModelImpl createModel(ContextResolvedModel model) { + return ModelImpl.createModel(this, model); + } + @Override public String explainPlan(InternalPlan compiledPlan, ExplainDetail... extraDetails) { return planner.explainPlan(compiledPlan, extraDetails); @@ -1531,5 +1562,16 @@ public Void visit(TableReferenceExpression tableRef) { } return null; } + + @Override + public Void visit(ModelReferenceExpression modelRef) { + super.visit(modelRef); + if (modelRef.getTableEnvironment() != null + && modelRef.getTableEnvironment() != TableEnvironmentImpl.this) { + throw new ValidationException( + "All model references must use the same TableEnvironment."); + } + return null; + } } } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java index dbfba0b7a584e..74854135de757 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java @@ -19,12 +19,14 @@ package org.apache.flink.table.catalog; import org.apache.flink.annotation.Internal; +import org.apache.flink.table.factories.FactoryUtil; import org.apache.flink.util.Preconditions; import javax.annotation.Nullable; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; /** * This class contains information about a model and its relationship with a {@link Catalog}, if @@ -48,28 +50,46 @@ @Internal public final class ContextResolvedModel { + private static final AtomicInteger uniqueId = new AtomicInteger(0); private final ObjectIdentifier objectIdentifier; private final @Nullable Catalog catalog; private final ResolvedCatalogModel resolvedModel; + private final boolean anonymous; public static ContextResolvedModel permanent( ObjectIdentifier identifier, Catalog catalog, ResolvedCatalogModel resolvedModel) { return new ContextResolvedModel( - identifier, Preconditions.checkNotNull(catalog), resolvedModel); + identifier, Preconditions.checkNotNull(catalog), resolvedModel, false); } public static ContextResolvedModel temporary( ObjectIdentifier identifier, ResolvedCatalogModel resolvedModel) { - return new 
ContextResolvedModel(identifier, null, resolvedModel); + return new ContextResolvedModel(identifier, null, resolvedModel, false); + } + + public static ContextResolvedModel anonymous(ResolvedCatalogModel resolvedModel) { + return anonymous(null, resolvedModel); + } + + public static ContextResolvedModel anonymous( + @Nullable String hint, ResolvedCatalogModel resolvedModel) { + return new ContextResolvedModel( + ObjectIdentifier.ofAnonymous( + generateAnonymousStringIdentifier(hint, resolvedModel)), + null, + resolvedModel, + true); } private ContextResolvedModel( ObjectIdentifier objectIdentifier, @Nullable Catalog catalog, - ResolvedCatalogModel resolvedModel) { + ResolvedCatalogModel resolvedModel, + boolean anonymous) { this.objectIdentifier = Preconditions.checkNotNull(objectIdentifier); this.catalog = catalog; this.resolvedModel = Preconditions.checkNotNull(resolvedModel); + this.anonymous = anonymous; } /** @@ -83,6 +103,10 @@ public boolean isPermanent() { return !isTemporary(); } + public boolean isAnonymous() { + return anonymous; + } + public ObjectIdentifier getIdentifier() { return objectIdentifier; } @@ -116,13 +140,39 @@ public boolean equals(Object o) { return false; } ContextResolvedModel that = (ContextResolvedModel) o; - return Objects.equals(objectIdentifier, that.objectIdentifier) + return anonymous == that.anonymous + && Objects.equals(objectIdentifier, that.objectIdentifier) && Objects.equals(catalog, that.catalog) && Objects.equals(resolvedModel, that.resolvedModel); } @Override public int hashCode() { - return Objects.hash(objectIdentifier, catalog, resolvedModel); + return Objects.hash(objectIdentifier, catalog, resolvedModel, anonymous); + } + + /** + * This method tries to return the provider name of the model, trying to provide a bit more + * helpful toString for anonymous models. It's only to help users to debug, and its return value + * should not be relied on. + */ + private static String generateAnonymousStringIdentifier( + @Nullable String hint, ResolvedCatalogModel resolvedModel) { + // Planner can do some fancy optimizations' logic squashing two sources together in the same + // operator. Because this logic is string based, anonymous models still need some kind of + // unique string based identifier that can be used later by the planner. 
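+        // The generated identifiers look like "*anonymous$1*", or "*anonymous_openai$1*"
+        // when a provider hint such as "openai" is available (examples are illustrative).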
+ if (hint == null) { + try { + hint = resolvedModel.getOptions().get(FactoryUtil.PROVIDER.key()); + } catch (Exception ignored) { + } + } + + int id = uniqueId.incrementAndGet(); + if (hint == null) { + return "*anonymous$" + id + "*"; + } + + return "*anonymous_" + hint + "$" + id + "*"; } } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java index 8748fafceb799..bcad8cccfce79 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java @@ -21,9 +21,11 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.table.api.ApiExpression; import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Model; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.TableEnvironment; import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.api.internal.ModelImpl; import org.apache.flink.table.api.internal.TableImpl; import org.apache.flink.table.catalog.ContextResolvedFunction; import org.apache.flink.table.functions.BuiltInFunctionDefinition; @@ -301,6 +303,11 @@ public static TableReferenceExpression tableRef( return new TableReferenceExpression(name, queryOperation, env); } + public static ModelReferenceExpression modelRef(String name, Model model) { + return new ModelReferenceExpression( + name, ((ModelImpl) model).getModel(), ((ModelImpl) model).getTableEnvironment()); + } + public static LookupCallExpression lookupCall(String name, Expression... args) { return new LookupCallExpression( name, diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java index f0b34bbddfaba..b64b77d24fba0 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java @@ -29,6 +29,8 @@ public final R visit(Expression other) { return visit((UnresolvedReferenceExpression) other); } else if (other instanceof TableReferenceExpression) { return visit((TableReferenceExpression) other); + } else if (other instanceof ModelReferenceExpression) { + return visit((ModelReferenceExpression) other); } else if (other instanceof LocalReferenceExpression) { return visit((LocalReferenceExpression) other); } else if (other instanceof LookupCallExpression) { @@ -49,6 +51,8 @@ public final R visit(Expression other) { public abstract R visit(TableReferenceExpression tableReference); + public abstract R visit(ModelReferenceExpression modelReferenceExpression); + public abstract R visit(LocalReferenceExpression localReference); /** For resolved expressions created by the planner. 
*/ diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ModelReferenceExpression.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ModelReferenceExpression.java new file mode 100644 index 0000000000000..c3c225528b204 --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ModelReferenceExpression.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.Model; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.catalog.ContextResolvedModel; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.utils.DataTypeUtils; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * A reference to a {@link Model} in an expression context. + * + *

This expression is used when a model needs to be passed as an argument to functions or + * operations that accept model references. It wraps a model object and provides the necessary + * expression interface for use in the Table API expression system. + * + *
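+ * <p>For example, a reference to a catalog model serializes to SQL (via {@code
+ * asSerializableString}) as {@code MODEL `catalog`.`database`.`model`}, while anonymous models
+ * cannot be serialized and trigger a {@link ValidationException} instead.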

The expression carries a string representation of the model and uses a special data type to + * indicate that this is a model reference rather than a regular data value. + */ +@Internal +public final class ModelReferenceExpression implements ResolvedExpression { + + private final String name; + private final ContextResolvedModel model; + // The environment is optional but serves validation purposes + // to ensure that all referenced tables belong to the same + // environment. + private final TableEnvironment env; + + public ModelReferenceExpression(String name, ContextResolvedModel model, TableEnvironment env) { + this.name = Preconditions.checkNotNull(name); + this.model = Preconditions.checkNotNull(model); + this.env = Preconditions.checkNotNull(env); + } + + /** + * Returns the name of this model reference. + * + * @return the model reference name + */ + public String getName() { + return name; + } + + /** + * Returns the ContextResolvedModel associated with this model reference. + * + * @return the query context resolved model + */ + public ContextResolvedModel getModel() { + return model; + } + + public @Nullable TableEnvironment getTableEnvironment() { + return env; + } + + /** + * Returns the input data type expected by this model reference. + * + *

This method extracts the input data type from the model's input schema, which describes + * the structure and data types that the model expects for inference operations. + * + * @return the input data type expected by the model + */ + public DataType getInputDataType() { + return DataTypeUtils.fromResolvedSchemaPreservingTimeAttributes( + model.getResolvedModel().getResolvedInputSchema()); + } + + @Override + public DataType getOutputDataType() { + return DataTypeUtils.fromResolvedSchemaPreservingTimeAttributes( + model.getResolvedModel().getResolvedOutputSchema()); + } + + @Override + public List getResolvedChildren() { + return Collections.emptyList(); + } + + @Override + public String asSerializableString(SqlFactory sqlFactory) { + if (model.isAnonymous()) { + throw new ValidationException("Anonymous models cannot be serialized."); + } + + return "MODEL " + model.getIdentifier().asSerializableString(); + } + + @Override + public String asSummaryString() { + return name; + } + + @Override + public List getChildren() { + return Collections.emptyList(); + } + + @Override + public R accept(ExpressionVisitor visitor) { + return visitor.visit(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ModelReferenceExpression that = (ModelReferenceExpression) o; + return Objects.equals(name, that.name) + && Objects.equals(model, that.model) + // Effectively means reference equality + && Objects.equals(env, that.env); + } + + @Override + public int hashCode() { + return Objects.hash(name, model, env); + } + + @Override + public String toString() { + return asSummaryString(); + } +} diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java index 42ee52d98d51c..88a15e3929d8a 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java @@ -32,6 +32,8 @@ public abstract class ResolvedExpressionVisitor implements ExpressionVisitor< public final R visit(Expression other) { if (other instanceof TableReferenceExpression) { return visit((TableReferenceExpression) other); + } else if (other instanceof ModelReferenceExpression) { + return visit((ModelReferenceExpression) other); } else if (other instanceof LocalReferenceExpression) { return visit((LocalReferenceExpression) other); } else if (other instanceof ResolvedExpression) { @@ -42,6 +44,8 @@ public final R visit(Expression other) { public abstract R visit(TableReferenceExpression tableReference); + public abstract R visit(ModelReferenceExpression modelReferenceExpression); + public abstract R visit(LocalReferenceExpression localReference); /** For resolved expressions created by the planner. 
*/ diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java index f5c585087dd01..e298953f69e3e 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java @@ -28,6 +28,7 @@ import org.apache.flink.table.expressions.CallExpression; import org.apache.flink.table.expressions.Expression; import org.apache.flink.table.expressions.ExpressionUtils; +import org.apache.flink.table.expressions.ModelReferenceExpression; import org.apache.flink.table.expressions.ResolvedExpression; import org.apache.flink.table.expressions.TableReferenceExpression; import org.apache.flink.table.expressions.TypeLiteralExpression; @@ -38,6 +39,7 @@ import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionIdentifier; import org.apache.flink.table.functions.FunctionKind; +import org.apache.flink.table.functions.ModelSemantics; import org.apache.flink.table.functions.ScalarFunctionDefinition; import org.apache.flink.table.functions.TableAggregateFunctionDefinition; import org.apache.flink.table.functions.TableFunctionDefinition; @@ -651,6 +653,22 @@ public Optional getTableSemantics(int pos) { return Optional.of(semantics); } + @Override + public Optional getModelSemantics(int pos) { + final StaticArgument staticArg = + Optional.ofNullable(staticArguments).map(args -> args.get(pos)).orElse(null); + if (staticArg == null || !staticArg.is(StaticArgumentTrait.MODEL)) { + return Optional.empty(); + } + final ResolvedExpression arg = getArgument(pos); + if (!(arg instanceof ModelReferenceExpression)) { + return Optional.empty(); + } + final ModelReferenceExpression modelRef = (ModelReferenceExpression) arg; + final ModelSemantics semantics = new TableApiModelSemantics(modelRef); + return Optional.of(semantics); + } + @Override public String getName() { return functionName; @@ -732,4 +750,23 @@ public Optional changelogMode() { return Optional.empty(); } } + + private static class TableApiModelSemantics implements ModelSemantics { + + private final ModelReferenceExpression modelRef; + + private TableApiModelSemantics(ModelReferenceExpression modelRef) { + this.modelRef = modelRef; + } + + @Override + public DataType inputDataType() { + return modelRef.getInputDataType(); + } + + @Override + public DataType outputDataType() { + return modelRef.getOutputDataType(); + } + } } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java index 9797f53a331c2..6e0303454094d 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java @@ -25,6 +25,7 @@ import org.apache.flink.table.expressions.FieldReferenceExpression; import org.apache.flink.table.expressions.LocalReferenceExpression; import org.apache.flink.table.expressions.LookupCallExpression; +import 
org.apache.flink.table.expressions.ModelReferenceExpression; import org.apache.flink.table.expressions.NestedFieldReferenceExpression; import org.apache.flink.table.expressions.ResolvedExpression; import org.apache.flink.table.expressions.SqlCallExpression; @@ -76,6 +77,11 @@ public T visit(TableReferenceExpression tableReference) { return defaultMethod(tableReference); } + @Override + public T visit(ModelReferenceExpression modelReference) { + return defaultMethod(modelReference); + } + @Override public T visit(LocalReferenceExpression localReference) { return defaultMethod(localReference); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java index 3bf93880d7d56..841ff5a03393f 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java @@ -22,6 +22,7 @@ import org.apache.flink.table.expressions.CallExpression; import org.apache.flink.table.expressions.FieldReferenceExpression; import org.apache.flink.table.expressions.LocalReferenceExpression; +import org.apache.flink.table.expressions.ModelReferenceExpression; import org.apache.flink.table.expressions.NestedFieldReferenceExpression; import org.apache.flink.table.expressions.ResolvedExpression; import org.apache.flink.table.expressions.ResolvedExpressionVisitor; @@ -41,6 +42,10 @@ public T visit(TableReferenceExpression tableReference) { return defaultMethod(tableReference); } + public T visit(ModelReferenceExpression modelReferenceExpression) { + return defaultMethod(modelReferenceExpression); + } + @Override public T visit(LocalReferenceExpression localReference) { return defaultMethod(localReference); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java index 4d7af03afa851..4f39f3ecb8410 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java @@ -24,6 +24,7 @@ import org.apache.flink.table.expressions.FieldReferenceExpression; import org.apache.flink.table.expressions.LocalReferenceExpression; import org.apache.flink.table.expressions.LookupCallExpression; +import org.apache.flink.table.expressions.ModelReferenceExpression; import org.apache.flink.table.expressions.ResolvedExpression; import org.apache.flink.table.expressions.TableReferenceExpression; import org.apache.flink.table.expressions.UnresolvedCallExpression; @@ -278,6 +279,11 @@ public Optional visit(TableReferenceExpression tableReference) { return Optional.of(tableReference.getName()); } + @Override + public Optional visit(ModelReferenceExpression modelReference) { + return Optional.of(modelReference.getName()); + } + @Override public Optional visit(FieldReferenceExpression fieldReference) { return Optional.of(fieldReference.getName()); diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/FailingTableApiTestStep.java 
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/FailingTableApiTestStep.java new file mode 100644 index 0000000000000..83b90f8946b0c --- /dev/null +++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/FailingTableApiTestStep.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.test.program; + +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.api.TableRuntimeException; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.test.program.TableApiTestStep.TableEnvAccessor; +import org.apache.flink.util.Preconditions; + +import java.util.function.Function; + +import static org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test step for executing Table API query that will fail eventually with either {@link + * ValidationException} (during planning time) or {@link TableRuntimeException} (during execution + * time). + * + *

Similar to {@link FailingSqlTestStep} but uses Table API instead of SQL. + */ +public final class FailingTableApiTestStep implements TestStep { + + private final Function tableQuery; + private final String sinkName; + public final Class expectedException; + public final String expectedErrorMessage; + + FailingTableApiTestStep( + Function tableQuery, + String sinkName, + Class expectedException, + String expectedErrorMessage) { + Preconditions.checkArgument( + expectedException == ValidationException.class + || expectedException == TableRuntimeException.class, + "Usually a Table API query should fail with either validation or runtime exception. " + + "Otherwise this might require an update to the exception design."); + this.tableQuery = tableQuery; + this.sinkName = sinkName; + this.expectedException = expectedException; + this.expectedErrorMessage = expectedErrorMessage; + } + + @Override + public TestKind getKind() { + return TestKind.FAILING_TABLE_API; + } + + public Table toTable(TableEnvironment env) { + return tableQuery.apply( + new TableEnvAccessor() { + @Override + public Table from(String path) { + return env.from(path); + } + + @Override + public Table fromCall(String path, Object... arguments) { + return env.fromCall(path, arguments); + } + + @Override + public Table fromCall( + Class + function, + Object... arguments) { + return env.fromCall(function, arguments); + } + + @Override + public Table fromValues(Object... values) { + return env.fromValues(values); + } + + @Override + public Table fromValues( + org.apache.flink.table.types.AbstractDataType dataType, + Object... values) { + return env.fromValues(dataType, values); + } + + @Override + public Table sqlQuery(String query) { + return env.sqlQuery(query); + } + + @Override + public org.apache.flink.table.api.Model fromModel(String modelPath) { + return env.fromModelPath(modelPath); + } + + @Override + public org.apache.flink.table.api.Model from( + org.apache.flink.table.api.ModelDescriptor modelDescriptor) { + return env.from(modelDescriptor); + } + }); + } + + public void apply(TableEnvironment env) { + assertThatThrownBy( + () -> { + final Table table = toTable(env); + table.executeInsert(sinkName).await(); + }) + .satisfies(anyCauseMatches(expectedException, expectedErrorMessage)); + } + + public void applyAsSql(TableEnvironment env) { + assertThatThrownBy( + () -> { + final Table table = toTable(env); + final String query = + table.getQueryOperation() + .asSerializableString( + org.apache.flink.table.expressions + .DefaultSqlFactory.INSTANCE); + env.executeSql(String.format("INSERT INTO %s %s", sinkName, query)) + .await(); + }) + .satisfies(anyCauseMatches(expectedException, expectedErrorMessage)); + } +} diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java index 4a375ce4f5932..4cfb19128e12e 100644 --- a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java +++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java @@ -18,6 +18,8 @@ package org.apache.flink.table.test.program; +import org.apache.flink.table.api.Model; +import org.apache.flink.table.api.ModelDescriptor; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.TableEnvironment; import org.apache.flink.table.api.TableResult; @@ -75,6 +77,16 @@ public Table 
fromValues(AbstractDataType dataType, Object... values) { public Table sqlQuery(String query) { return env.sqlQuery(query); } + + @Override + public Model fromModel(String modelPath) { + return env.fromModelPath(modelPath); + } + + @Override + public Model from(ModelDescriptor modelDescriptor) { + return env.from(modelDescriptor); + } }); } @@ -111,5 +123,11 @@ public interface TableEnvAccessor { /** See {@link TableEnvironment#sqlQuery(String)}. */ Table sqlQuery(String query); + + /** See {@link TableEnvironment#fromModelPath(String)}. */ + Model fromModel(String modelPath); + + /** See {@link TableEnvironment#from(ModelDescriptor)}. */ + Model from(ModelDescriptor modelDescriptor); } } diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java index bb8ca296cae5a..d37c4c8dc957e 100644 --- a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java +++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java @@ -21,6 +21,7 @@ import org.apache.flink.configuration.ConfigOption; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.TableRuntimeException; +import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.expressions.Expression; import org.apache.flink.table.functions.UserDefinedFunction; import org.apache.flink.table.test.program.FunctionTestStep.FunctionBehavior; @@ -355,6 +356,22 @@ public Builder runFailingSql( return this; } + /** + * Run step for executing a Table API query that will fail eventually with either {@link + * ValidationException} (during planning time) or {@link TableRuntimeException} (during + * execution time). + */ + public Builder runFailingTableApi( + Function toTable, + String sinkName, + Class expectedException, + String expectedErrorMessage) { + this.runSteps.add( + new FailingTableApiTestStep( + toTable, sinkName, expectedException, expectedErrorMessage)); + return this; + } + public Builder runTableApi(Function toTable, String sinkName) { this.runSteps.add(new TableApiTestStep(toTable, sinkName)); return this; diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java index fc4245df79ff9..db2fe754f360d 100644 --- a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java +++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java @@ -51,7 +51,8 @@ enum TestKind { SINK_WITHOUT_DATA, SINK_WITH_DATA, SINK_WITH_RESTORE_DATA, - FAILING_SQL + FAILING_SQL, + FAILING_TABLE_API } TestKind getKind(); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java index d2d02bc498f70..bc70a660662f3 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java @@ -61,6 +61,10 @@ public static ColumnList of(List names) { return of(names, List.of()); } + public static ColumnList of(String... names) { + return of(List.of(names)); + } + /** * Returns a list of column names. 
* diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java index b4c0cf396266a..7b44d97cbf901 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java @@ -19,26 +19,42 @@ package org.apache.flink.table.planner.expressions.converter; import org.apache.flink.table.api.TableException; +import org.apache.flink.table.catalog.Catalog; +import org.apache.flink.table.catalog.ContextResolvedModel; import org.apache.flink.table.catalog.DataTypeFactory; import org.apache.flink.table.data.DecimalData; import org.apache.flink.table.expressions.CallExpression; import org.apache.flink.table.expressions.Expression; -import org.apache.flink.table.expressions.ExpressionVisitor; import org.apache.flink.table.expressions.FieldReferenceExpression; import org.apache.flink.table.expressions.LocalReferenceExpression; +import org.apache.flink.table.expressions.ModelReferenceExpression; import org.apache.flink.table.expressions.NestedFieldReferenceExpression; +import org.apache.flink.table.expressions.ResolvedExpression; +import org.apache.flink.table.expressions.ResolvedExpressionVisitor; +import org.apache.flink.table.expressions.TableReferenceExpression; import org.apache.flink.table.expressions.TimeIntervalUnit; import org.apache.flink.table.expressions.TimePointUnit; import org.apache.flink.table.expressions.TypeLiteralExpression; import org.apache.flink.table.expressions.ValueLiteralExpression; +import org.apache.flink.table.factories.FactoryUtil; +import org.apache.flink.table.factories.ModelProviderFactory; +import org.apache.flink.table.ml.ModelProvider; +import org.apache.flink.table.module.Module; +import org.apache.flink.table.operations.PartitionQueryOperation; +import org.apache.flink.table.planner.calcite.FlinkContext; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.calcite.RexFieldVariable; +import org.apache.flink.table.planner.calcite.RexModelCall; +import org.apache.flink.table.planner.calcite.RexTableArgCall; import org.apache.flink.table.planner.expressions.RexNodeExpression; import org.apache.flink.table.planner.expressions.converter.CallExpressionConvertRule.ConvertContext; import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable; +import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.logical.TimeType; +import org.apache.flink.table.types.utils.DataTypeUtils; import org.apache.flink.types.ColumnList; import org.apache.calcite.avatica.util.ByteString; @@ -60,6 +76,7 @@ import java.time.Period; import java.time.ZoneOffset; import java.time.temporal.ChronoField; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -69,19 +86,22 @@ import static org.apache.flink.table.planner.utils.ShortcutUtils.unwrapContext; import static org.apache.flink.table.planner.utils.TimestampStringUtils.fromLocalDateTime; import static 
org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType; +import static org.apache.flink.util.OptionalUtils.firstPresent; /** Visit expression to generator {@link RexNode}. */ -public class ExpressionConverter implements ExpressionVisitor { +public class ExpressionConverter extends ResolvedExpressionVisitor { private final RelBuilder relBuilder; private final FlinkTypeFactory typeFactory; private final DataTypeFactory dataTypeFactory; + private final List inputStack; public ExpressionConverter(RelBuilder relBuilder) { this.relBuilder = relBuilder; this.typeFactory = (FlinkTypeFactory) relBuilder.getRexBuilder().getTypeFactory(); this.dataTypeFactory = unwrapContext(relBuilder.getCluster()).getCatalogManager().getDataTypeFactory(); + this.inputStack = new ArrayList<>(); } private List getFunctionConvertChain(boolean isBatchMode) { @@ -235,27 +255,96 @@ public RexNode visit(TypeLiteralExpression typeLiteral) { } @Override - public RexNode visit(Expression other) { + public RexNode visit(TableReferenceExpression tableRef) { + final LogicalType tableArgType = tableRef.getOutputDataType().getLogicalType(); + final RelDataType rowType = typeFactory.buildRelNodeRowType((RowType) tableArgType); + final int[] partitionKeys; + if (tableRef.getQueryOperation() instanceof PartitionQueryOperation) { + final PartitionQueryOperation partitionOperation = + (PartitionQueryOperation) tableRef.getQueryOperation(); + partitionKeys = partitionOperation.getPartitionKeys(); + } else { + partitionKeys = new int[0]; + } + final RexTableArgCall tableArgCall = + new RexTableArgCall(rowType, inputStack.size(), partitionKeys, new int[0]); + inputStack.add(relBuilder.build()); + return tableArgCall; + } + + public List copyInputStack() { + return new ArrayList<>(inputStack); + } + + public void clearInputStack() { + inputStack.clear(); + } + + @Override + public RexNode visit(ModelReferenceExpression modelRef) { + final ContextResolvedModel contextResolvedModel = modelRef.getModel(); + final FlinkContext flinkContext = ShortcutUtils.unwrapContext(relBuilder); + + final Optional factoryFromCatalog = + contextResolvedModel + .getCatalog() + .flatMap(Catalog::getFactory) + .map( + f -> + f instanceof ModelProviderFactory + ? (ModelProviderFactory) f + : null); + + final Optional factoryFromModule = + flinkContext.getModuleManager().getFactory(Module::getModelProviderFactory); + + // Since the catalog is more specific, we give it + // precedence over a factory provided by any + // modules. 
+ final ModelProviderFactory factory = + firstPresent(factoryFromCatalog, factoryFromModule).orElse(null); + + final ModelProvider modelProvider = + FactoryUtil.createModelProvider( + factory, + contextResolvedModel.getIdentifier(), + contextResolvedModel.getResolvedModel(), + flinkContext.getTableConfig(), + flinkContext.getClassLoader(), + contextResolvedModel.isTemporary()); + final LogicalType modelOutputType = + DataTypeUtils.fromResolvedSchemaPreservingTimeAttributes( + contextResolvedModel.getResolvedModel().getResolvedOutputSchema()) + .getLogicalType(); + final RelDataType modelOutputRelDataType = + typeFactory.buildRelNodeRowType((RowType) modelOutputType); + + return new RexModelCall(modelOutputRelDataType, contextResolvedModel, modelProvider); + } + + @Override + public RexNode visit(LocalReferenceExpression local) { + // check whether the local field reference can actually be resolved to an existing + // field otherwise preserve the locality attribute + RelNode inputNode; + try { + inputNode = relBuilder.peek(); + } catch (Throwable t) { + inputNode = null; + } + if (inputNode != null && inputNode.getRowType().getFieldNames().contains(local.getName())) { + return relBuilder.field(local.getName()); + } + return new RexFieldVariable( + local.getName(), + typeFactory.createFieldTypeFromLogicalType( + fromDataTypeToLogicalType(local.getOutputDataType()))); + } + + @Override + public RexNode visit(ResolvedExpression other) { if (other instanceof RexNodeExpression) { return ((RexNodeExpression) other).getRexNode(); - } else if (other instanceof LocalReferenceExpression) { - final LocalReferenceExpression local = (LocalReferenceExpression) other; - // check whether the local field reference can actually be resolved to an existing - // field otherwise preserve the locality attribute - RelNode inputNode; - try { - inputNode = relBuilder.peek(); - } catch (Throwable t) { - inputNode = null; - } - if (inputNode != null - && inputNode.getRowType().getFieldNames().contains(local.getName())) { - return relBuilder.field(local.getName()); - } - return new RexFieldVariable( - local.getName(), - typeFactory.createFieldTypeFromLogicalType( - fromDataTypeToLogicalType(local.getOutputDataType()))); } else { throw new UnsupportedOperationException( other.getClass().getSimpleName() + ":" + other.toString()); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java index 08ebe72f9e425..09c3cf159dd97 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java @@ -36,7 +36,6 @@ import org.apache.flink.table.expressions.ExpressionDefaultVisitor; import org.apache.flink.table.expressions.FieldReferenceExpression; import org.apache.flink.table.expressions.ResolvedExpression; -import org.apache.flink.table.expressions.TableReferenceExpression; import org.apache.flink.table.expressions.ValueLiteralExpression; import org.apache.flink.table.functions.BuiltInFunctionDefinitions; import org.apache.flink.table.functions.FunctionDefinition; @@ -68,7 +67,6 @@ import org.apache.flink.table.planner.calcite.FlinkContext; import org.apache.flink.table.planner.calcite.FlinkRelBuilder; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; -import 
 import org.apache.flink.table.planner.connectors.DynamicSourceUtils;
 import org.apache.flink.table.planner.expressions.RexNodeExpression;
 import org.apache.flink.table.planner.expressions.SqlAggFunctionVisitor;
@@ -302,43 +300,14 @@ public RelNode visit(FunctionQueryOperation functionTable) {
         final RelDataType outputRelDataType =
                 typeFactory.buildRelNodeRowType((RowType) outputType);

-        final List<RelNode> inputStack = new ArrayList<>();
         final List<RexNode> rexNodeArgs =
                 resolvedArgs.stream()
-                        .map(
-                                resolvedArg -> {
-                                    if (resolvedArg instanceof TableReferenceExpression) {
-                                        final TableReferenceExpression tableRef =
-                                                (TableReferenceExpression) resolvedArg;
-                                        final LogicalType tableArgType =
-                                                tableRef.getOutputDataType().getLogicalType();
-                                        final RelDataType rowType =
-                                                typeFactory.buildRelNodeRowType(
-                                                        (RowType) tableArgType);
-                                        final int[] partitionKeys;
-                                        if (tableRef.getQueryOperation()
-                                                instanceof PartitionQueryOperation) {
-                                            final PartitionQueryOperation partitionOperation =
-                                                    (PartitionQueryOperation)
-                                                            tableRef.getQueryOperation();
-                                            partitionKeys =
-                                                    partitionOperation.getPartitionKeys();
-                                        } else {
-                                            partitionKeys = new int[0];
-                                        }
-                                        final RexTableArgCall tableArgCall =
-                                                new RexTableArgCall(
-                                                        rowType,
-                                                        inputStack.size(),
-                                                        partitionKeys,
-                                                        new int[0]);
-                                        inputStack.add(relBuilder.build());
-                                        return tableArgCall;
-                                    }
-                                    return convertExprToRexNode(resolvedArg);
-                                })
+                        .map(QueryOperationConverter.this::convertExprToRexNode)
                         .collect(Collectors.toList());
+        final List<RelNode> inputStack = expressionConverter.copyInputStack();
+        expressionConverter.clearInputStack();
+
+        // relBuilder.build() pops inputs in LIFO fashion; reversing below restores
+        // the original input order
         Collections.reverse(inputStack);
@@ -547,7 +516,6 @@ else if (other instanceof DataStreamQueryOperation) {
                     dataStreamQueryOperation.getResolvedSchema(),
                     dataStreamQueryOperation.getIdentifier());
         }
-
         throw new TableException("Unknown table operation: " + other);
     }
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
index d0b294ed5ea70..0016ec9893d4a 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
@@ -114,6 +114,10 @@ public RelWriter explainTerms(RelWriter pw) {
                 .item("rowType", getRowType());
     }

+    public RexNode getMLPredictCall() {
+        return scan.getCall();
+    }
+
     private MLPredictSpec buildMLPredictSpec(Map<String, String> runtimeConfig) {
         RexTableArgCall tableCall = extractOperand(operand -> operand instanceof RexTableArgCall);
         RexCall descriptorCall =
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java
index 875a088ec591f..c34dda9ea6b79 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java
@@ -43,6 +43,7 @@ import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalLegacyTableSourceScan;
 import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalLimit;
 import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalLookupJoin;
+import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMLPredictTableFunction;
 import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMatch;
 import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMiniBatchAssigner;
 import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMultiJoin;
@@ -198,6 +199,9 @@ public StreamPhysicalRel visit(
                     (StreamPhysicalWindowTableFunction) rel, requireDeterminism);
         } else if (rel instanceof StreamPhysicalDeltaJoin) {
             return visitDeltaJoin((StreamPhysicalDeltaJoin) rel, requireDeterminism);
+        } else if (rel instanceof StreamPhysicalMLPredictTableFunction) {
+            return visitMLPredictTableFunction(
+                    (StreamPhysicalMLPredictTableFunction) rel, requireDeterminism);
         } else if (rel instanceof StreamPhysicalChangelogNormalize
                 || rel instanceof StreamPhysicalDropUpdateBefore
                 || rel instanceof StreamPhysicalMiniBatchAssigner
@@ -328,6 +332,16 @@ private StreamPhysicalRel visitCalc(
         }
     }

+    private StreamPhysicalRel visitMLPredictTableFunction(
+            final StreamPhysicalMLPredictTableFunction predictTableFunction,
+            final ImmutableBitSet requireDeterminism) {
+        if (!inputInsertOnly(predictTableFunction) && !requireDeterminism.isEmpty()) {
+            throwNonDeterministicConditionError(
+                    "ML_PREDICT", predictTableFunction.getMLPredictCall(), predictTableFunction);
+        }
+        return transmitDeterminismRequirement(predictTableFunction, NO_REQUIRED_DETERMINISM);
+    }
+
     private StreamPhysicalRel visitCorrelate(
             final StreamPhysicalCorrelateBase correlate, final ImmutableBitSet requireDeterminism) {
         if (inputInsertOnly(correlate) || requireDeterminism.isEmpty()) {
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java
index 18091c4ae1341..5c9f26032d8c3 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java
@@ -35,9 +35,11 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeInfo;
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeName;

+import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexCall;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;

 import java.util.HashMap;
 import java.util.List;
@@ -45,10 +47,15 @@ import java.util.Objects;

 import static org.apache.calcite.sql.SqlKind.MAP_VALUE_CONSTRUCTOR;
+import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
+import static org.apache.flink.table.types.logical.LogicalTypeFamily.CHARACTER_STRING;

 /** Common utils for function calls, e.g. ML_PREDICT and Lookup Join.
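  *
  * <p>For example, a runtime config argument written as {@code MAP['async', 'true']} arrives
  * here as a {@code MAP_VALUE_CONSTRUCTOR} call and is reduced by {@code convert(RexCall)} to a
  * plain {@code Map<String, String>}. Illustrative sketch only (the {@code mapConstructorCall}
  * variable is hypothetical):
  *
  * <pre>{@code
  * // keys and values must be string literals, possibly wrapped in a
  * // string-to-string CAST when the map is built via the Table API
  * Map<String, String> config = FunctionCallUtil.convert(mapConstructorCall);
  * // config.get("async") -> "true"
  * }</pre>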
  */
 public abstract class FunctionCallUtil {

+    private static final String CONFIG_ERROR_MESSAGE =
+            "Config parameter should be a MAP data type consisting of String literals.";
+
     /** A field used as an equal condition when querying content from a dimension table. */
     @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type")
     @JsonSubTypes({
@@ -225,18 +232,41 @@ public static Map<String, String> convert(RexCall mapConstructor) {
         for (int i = 0; i < mapConstructor.getOperands().size(); i += 2) {
             RexNode keyNode = mapConstructor.getOperands().get(i);
             RexNode valueNode = mapConstructor.getOperands().get(i + 1);
-            // Both key and value should be string literals
-            if (!(keyNode instanceof RexLiteral) || !(valueNode instanceof RexLiteral)) {
-                throw new ValidationException(
-                        "Config parameter should be a MAP data type consisting String literals.");
-            }
-            String key = RexLiteral.stringValue(keyNode);
-            String value = RexLiteral.stringValue(valueNode);
+            String key = getStringLiteral(keyNode);
+            String value = getStringLiteral(valueNode);
             reducedConfig.put(key, value);
         }
         return reducedConfig;
     }

+    private static String getStringLiteral(RexNode node) {
+        // A string-to-string CAST is present when Expressions.lit(Map(...)) is used as the
+        // config map from the Table API
+        if (node instanceof RexCall && node.getKind() == SqlKind.CAST) {
+            final RexCall castCall = (RexCall) node;
+            // Unwrap the CAST if present
+            final RexNode castOperand = castCall.getOperands().get(0);
+            if (!(castOperand instanceof RexLiteral)) {
+                throw new ValidationException(CONFIG_ERROR_MESSAGE);
+            }
+            final RelDataType operandType = castOperand.getType();
+            if (!toLogicalType(operandType).is(CHARACTER_STRING)) {
+                throw new ValidationException(CONFIG_ERROR_MESSAGE);
+            }
+            final RelDataType castType = castCall.getType();
+            if (!toLogicalType(castType).is(CHARACTER_STRING)) {
+                throw new ValidationException(CONFIG_ERROR_MESSAGE);
+            }
+            return RexLiteral.stringValue(castOperand);
+        }
+        // Both key and value should be string literals
+        if (!(node instanceof RexLiteral)) {
+            throw new ValidationException(CONFIG_ERROR_MESSAGE);
+        }
+
+        return RexLiteral.stringValue(node);
+    }
+
     public static String explainFunctionParam(FunctionParam param, List<String> fieldNames) {
         if (param instanceof Constant) {
             return RelExplainUtil.literalToString(((Constant) param).literal);
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java
index 73ef260d56bd0..aeb8c7028a87c 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java
@@ -20,6 +20,7 @@ import org.apache.flink.table.operations.QueryOperation;
 import org.apache.flink.table.planner.plan.nodes.exec.testutils.SemanticTestBase;
+import org.apache.flink.table.test.program.FailingTableApiTestStep;
 import org.apache.flink.table.test.program.TableApiTestStep;
 import org.apache.flink.table.test.program.TableTestProgram;
 import org.apache.flink.table.test.program.TestStep;
@@ -57,7 +58,11 @@ public List<TableTestProgram> programs() {
                 QueryOperationTestPrograms.OVER_WINDOW_LAG,
                 QueryOperationTestPrograms.ACCESSING_NESTED_COLUMN,
                 QueryOperationTestPrograms.ROW_SEMANTIC_TABLE_PTF,
-                QueryOperationTestPrograms.SET_SEMANTIC_TABLE_PTF);
+                QueryOperationTestPrograms.SET_SEMANTIC_TABLE_PTF,
+                QueryOperationTestPrograms.ML_PREDICT_MODEL_API,
+                QueryOperationTestPrograms.ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG,
+                QueryOperationTestPrograms.ASYNC_ML_PREDICT_MODEL_API,
+                QueryOperationTestPrograms.ML_PREDICT_ANON_MODEL_API);
     }

     @Override
@@ -65,6 +70,9 @@ protected void runStep(TestStep testStep, TableEnvironment env) throws Exception {
         if (testStep instanceof TableApiTestStep) {
             final TableApiTestStep tableApiStep = (TableApiTestStep) testStep;
             tableApiStep.applyAsSql(env).await();
+        } else if (testStep instanceof FailingTableApiTestStep) {
+            final FailingTableApiTestStep failingTableApiStep = (FailingTableApiTestStep) testStep;
+            failingTableApiStep.applyAsSql(env);
         } else {
             super.runStep(testStep, env);
         }
@@ -72,6 +80,6 @@ protected void runStep(TestStep testStep, TableEnvironment env) throws Exception

     @Override
     public EnumSet<TestKind> supportedRunSteps() {
-        return EnumSet.of(TestKind.TABLE_API, TestKind.SQL);
+        return EnumSet.of(TestKind.TABLE_API, TestKind.SQL, TestKind.FAILING_TABLE_API);
     }
 }
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
index 53315269c2f08..7f8bb251f931b 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
@@ -19,9 +19,11 @@ package org.apache.flink.table.api;

 import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
 import org.apache.flink.table.api.config.OptimizerConfigOptions;
 import org.apache.flink.table.functions.ScalarFunction;
 import org.apache.flink.table.operations.QueryOperation;
+import org.apache.flink.table.planner.factories.TestValuesModelFactory;
 import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.ChainedReceivingFunction;
 import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.ChainedSendingFunction;
 import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.RowSemanticTableFunction;
@@ -30,6 +32,7 @@ import org.apache.flink.table.test.program.SinkTestStep;
 import org.apache.flink.table.test.program.SourceTestStep;
 import org.apache.flink.table.test.program.TableTestProgram;
+import org.apache.flink.types.ColumnList;
 import org.apache.flink.types.Row;
 import org.apache.flink.types.RowKind;
@@ -39,6 +42,7 @@ import java.time.LocalTime;
 import java.time.ZoneId;
 import java.util.Collections;
+import java.util.Map;

 import static org.apache.flink.table.api.Expressions.$;
 import static org.apache.flink.table.api.Expressions.UNBOUNDED_ROW;
@@ -49,6 +53,10 @@ import static org.apache.flink.table.api.Expressions.lit;
 import static org.apache.flink.table.api.Expressions.nullOf;
 import static org.apache.flink.table.api.Expressions.row;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_MODEL;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SIMPLE_FEATURES_SOURCE;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SIMPLE_SINK;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SYNC_MODEL;
 import static
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.BASE_SINK_SCHEMA; import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.BASIC_VALUES; import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.KEYED_TIMED_BASE_SINK_SCHEMA; @@ -1117,6 +1125,97 @@ private static Instant dayOfSeconds(int second) { "sink") .build(); + public static final TableTestProgram ML_PREDICT_MODEL_API = + TableTestProgram.of("ml-predict-model-api", "ml-predict using model API") + .setupTableSource(SIMPLE_FEATURES_SOURCE) + .setupModel(SYNC_MODEL) + .setupTableSink(SIMPLE_SINK) + .runTableApi( + env -> + env.fromModel("chatgpt") + .predict( + env.from("features"), ColumnList.of("feature")), + "sink") + .build(); + + public static final TableTestProgram ASYNC_ML_PREDICT_MODEL_API = + TableTestProgram.of("async-ml-predict-model-api", "async ml-predict using model API") + .setupTableSource(SIMPLE_FEATURES_SOURCE) + .setupModel(ASYNC_MODEL) + .setupTableSink(SIMPLE_SINK) + .setupConfig( + ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE, + ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED) + .runTableApi( + env -> + env.fromModel("chatgpt") + .predict( + env.from("features"), + ColumnList.of("feature"), + Map.of( + "async", + "true", + "max-concurrent-operations", + "10")), + "sink") + .build(); + + public static final TableTestProgram ML_PREDICT_ANON_MODEL_API = + TableTestProgram.of( + "ml-predict-anonymous-model-api", + "ml-predict using anonymous model API") + .setupTableSource(SIMPLE_FEATURES_SOURCE) + .runFailingTableApi( + env -> + env.from( + ModelDescriptor.forProvider("values") + .inputSchema( + Schema.newBuilder() + .column( + "feature", + "STRING") + .build()) + .outputSchema( + Schema.newBuilder() + .column( + "category", + "STRING") + .build()) + .option( + "data-id", + TestValuesModelFactory + .registerData( + SYNC_MODEL + .data)) + .build()) + .predict( + env.from("features"), ColumnList.of("feature")), + "sink", + ValidationException.class, + "Anonymous models cannot be serialized.") + .build(); + + public static final TableTestProgram ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG = + TableTestProgram.of( + "async-ml-predict-table-api-map-expression-config", + "ml-predict in async mode using Table API and map expression.") + .setupTableSource(SIMPLE_FEATURES_SOURCE) + .setupModel(ASYNC_MODEL) + .setupTableSink(SIMPLE_SINK) + .setupConfig( + ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE, + ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED) + .runTableApi( + env -> + env.fromCall( + "ML_PREDICT", + env.from("features").asArgument("INPUT"), + env.fromModel("chatgpt").asArgument("MODEL"), + descriptor("feature").asArgument("ARGS"), + Expressions.map("async", "true").asArgument("CONFIG")), + "sink") + .build(); + /** * A function that will be used as an inline function in {@link #INLINE_FUNCTION_SERIALIZATION}. 
*/ diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictSemanticTests.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictSemanticTests.java new file mode 100644 index 0000000000000..cb15ff4533d8e --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictSemanticTests.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.exec.stream; + +import org.apache.flink.table.planner.plan.nodes.exec.testutils.SemanticTestBase; +import org.apache.flink.table.test.program.TableTestProgram; + +import java.util.List; + +import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_ML_PREDICT_MODEL_API; +import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_ML_PREDICT_TABLE_API; +import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG; +import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ML_PREDICT_ANON_MODEL_API; +import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ML_PREDICT_MODEL_API; +import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SYNC_ML_PREDICT_TABLE_API; + +/** Semantic tests for {@link StreamExecMLPredictTableFunction} using Table API. 
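+ *
+ * <p>The programs below exercise the model API form of the call, e.g. (mirroring
+ * {@code ML_PREDICT_MODEL_API}):
+ *
+ * <pre>{@code
+ * env.fromModel("chatgpt")
+ *     .predict(env.from("features"), ColumnList.of("feature"));
+ * }</pre>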
+ */
+public class MLPredictSemanticTests extends SemanticTestBase {
+
+    @Override
+    public List<TableTestProgram> programs() {
+        return List.of(
+                SYNC_ML_PREDICT_TABLE_API,
+                ASYNC_ML_PREDICT_TABLE_API,
+                ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG,
+                ML_PREDICT_MODEL_API,
+                ASYNC_ML_PREDICT_MODEL_API,
+                ML_PREDICT_ANON_MODEL_API);
+    }
+}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java
index 7d430e143b21e..26d4903a3e41d 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java
@@ -18,11 +18,17 @@ package org.apache.flink.table.planner.plan.nodes.exec.stream;

+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.ModelDescriptor;
+import org.apache.flink.table.api.Schema;
 import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.planner.factories.TestValuesModelFactory;
 import org.apache.flink.table.test.program.ModelTestStep;
 import org.apache.flink.table.test.program.SinkTestStep;
 import org.apache.flink.table.test.program.SourceTestStep;
 import org.apache.flink.table.test.program.TableTestProgram;
+import org.apache.flink.types.ColumnList;
 import org.apache.flink.types.Row;
 import org.apache.flink.types.RowKind;
@@ -30,6 +36,8 @@ import java.util.List;
 import java.util.Map;

+import static org.apache.flink.table.api.Expressions.descriptor;
+
 /** Programs for verifying {@link StreamExecMLPredictTableFunction}.
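  *
  * <p>Each behavior is covered both via SQL, e.g.
  *
  * <pre>{@code
  * SELECT * FROM ML_PREDICT(TABLE features, MODEL chatgpt, DESCRIPTOR(feature))
  * }</pre>
  *
  * <p>and via the equivalent Table API calls ({@code fromCall("ML_PREDICT", ...)} and
  * {@code Model#predict}) defined below.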
  */
 public class MLPredictTestPrograms {
@@ -48,7 +56,13 @@ public class MLPredictTestPrograms {
                 Row.ofKind(RowKind.INSERT, 4, "Mysql"),
                 Row.ofKind(RowKind.INSERT, 5, "Postgres")
             };

-    static final SourceTestStep FEATURES_TABLE =
+    public static final SourceTestStep SIMPLE_FEATURES_SOURCE =
+            SourceTestStep.newBuilder("features")
+                    .addSchema(FEATURES_SCHEMA)
+                    .producedValues(FEATURES_BEFORE_DATA)
+                    .build();
+
+    static final SourceTestStep RESTORE_FEATURES_TABLE =
             SourceTestStep.newBuilder("features")
                     .addSchema(FEATURES_SCHEMA)
                     .producedBeforeRestore(FEATURES_BEFORE_DATA)
@@ -61,7 +75,7 @@ public class MLPredictTestPrograms {
     static final String[] MODEL_OUTPUT_SCHEMA = new String[] {"category STRING"};

     static final Map<Row, List<Row>> MODEL_DATA =
-            new HashMap<Row, List<Row>>() {
+            new HashMap<>() {
                 {
                     put(
                             Row.ofKind(RowKind.INSERT, "Flink"),
@@ -82,14 +96,14 @@ public class MLPredictTestPrograms {
                 }
             };

-    static final ModelTestStep SYNC_MODEL =
+    public static final ModelTestStep SYNC_MODEL =
             ModelTestStep.newBuilder("chatgpt")
                     .addInputSchema(MODEL_INPUT_SCHEMA)
                     .addOutputSchema(MODEL_OUTPUT_SCHEMA)
                     .data(MODEL_DATA)
                     .build();

-    static final ModelTestStep ASYNC_MODEL =
+    public static final ModelTestStep ASYNC_MODEL =
             ModelTestStep.newBuilder("chatgpt")
                     .addInputSchema(MODEL_INPUT_SCHEMA)
                     .addOutputSchema(MODEL_OUTPUT_SCHEMA)
@@ -102,7 +116,7 @@ public class MLPredictTestPrograms {
     static final String[] SINK_SCHEMA =
             new String[] {"id INT PRIMARY KEY NOT ENFORCED", "feature STRING", "category STRING"};

-    static final SinkTestStep SINK_TABLE =
+    static final SinkTestStep RESTORE_SINK_TABLE =
             SinkTestStep.newBuilder("sink_t")
                     .addSchema(SINK_SCHEMA)
                     .consumedBeforeRestore(
@@ -112,22 +126,31 @@ public class MLPredictTestPrograms {
                     .consumedAfterRestore("+I[4, Mysql, Database]", "+I[5, Postgres, Database]")
                     .build();

+    public static final SinkTestStep SIMPLE_SINK =
+            SinkTestStep.newBuilder("sink")
+                    .addSchema(SINK_SCHEMA)
+                    .consumedValues(
+                            "+I[1, Flink, Big Data]",
+                            "+I[2, Spark, Big Data]",
+                            "+I[3, Hive, Big Data]")
+                    .build();
+
     // -------------------------------------------------------------------------------------------

     public static final TableTestProgram SYNC_ML_PREDICT =
             TableTestProgram.of("sync-ml-predict", "ml-predict in sync mode.")
-                    .setupTableSource(FEATURES_TABLE)
+                    .setupTableSource(RESTORE_FEATURES_TABLE)
                     .setupModel(SYNC_MODEL)
-                    .setupTableSink(SINK_TABLE)
+                    .setupTableSink(RESTORE_SINK_TABLE)
                     .runSql(
                             "INSERT INTO sink_t SELECT * FROM ML_PREDICT(TABLE features, MODEL chatgpt, DESCRIPTOR(feature))")
                     .build();

     public static final TableTestProgram ASYNC_UNORDERED_ML_PREDICT =
             TableTestProgram.of("async-unordered-ml-predict", "ml-predict in async unordered mode.")
-                    .setupTableSource(FEATURES_TABLE)
+                    .setupTableSource(RESTORE_FEATURES_TABLE)
                     .setupModel(ASYNC_MODEL)
-                    .setupTableSink(SINK_TABLE)
+                    .setupTableSink(RESTORE_SINK_TABLE)
                     .setupConfig(
                             ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
                             ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
@@ -139,11 +162,148 @@ public class MLPredictTestPrograms {
             TableTestProgram.of(
                             "sync-ml-predict-with-runtime-options",
                             "ml-predict in sync mode with runtime config.")
-                    .setupTableSource(FEATURES_TABLE)
+                    .setupTableSource(RESTORE_FEATURES_TABLE)
                     .setupModel(ASYNC_MODEL)
-                    .setupTableSink(SINK_TABLE)
+                    .setupTableSink(RESTORE_SINK_TABLE)
                     .runSql(
                             "INSERT INTO sink_t SELECT * FROM ML_PREDICT(TABLE features, MODEL chatgpt, DESCRIPTOR(feature), MAP['async', 'false'])")
                     .build();
-    ;
+
+    public static final TableTestProgram
SYNC_ML_PREDICT_TABLE_API = + TableTestProgram.of( + "sync-ml-predict-table-api", "ml-predict in sync mode using Table API.") + .setupTableSource(SIMPLE_FEATURES_SOURCE) + .setupModel(SYNC_MODEL) + .setupTableSink(SIMPLE_SINK) + .runTableApi( + env -> + env.fromCall( + "ML_PREDICT", + env.from("features").asArgument("INPUT"), + env.fromModel("chatgpt").asArgument("MODEL"), + descriptor("feature").asArgument("ARGS")), + "sink") + .build(); + + public static final TableTestProgram ASYNC_ML_PREDICT_TABLE_API = + TableTestProgram.of( + "async-ml-predict-table-api", + "ml-predict in async mode using Table API.") + .setupTableSource(SIMPLE_FEATURES_SOURCE) + .setupModel(ASYNC_MODEL) + .setupTableSink(SIMPLE_SINK) + .setupConfig( + ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE, + ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED) + .runTableApi( + env -> + env.fromCall( + "ML_PREDICT", + env.from("features").asArgument("INPUT"), + env.fromModel("chatgpt").asArgument("MODEL"), + descriptor("feature").asArgument("ARGS"), + Expressions.lit( + Map.of("async", "true"), + DataTypes.MAP( + DataTypes.STRING(), + DataTypes.STRING()) + .notNull()) + .asArgument("CONFIG")), + "sink") + .build(); + + public static final TableTestProgram ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG = + TableTestProgram.of( + "async-ml-predict-table-api-map-expression-config", + "ml-predict in async mode using Table API and map expression.") + .setupTableSource(SIMPLE_FEATURES_SOURCE) + .setupModel(ASYNC_MODEL) + .setupTableSink(SIMPLE_SINK) + .setupConfig( + ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE, + ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED) + .runTableApi( + env -> + env.fromCall( + "ML_PREDICT", + env.from("features").asArgument("INPUT"), + env.fromModel("chatgpt").asArgument("MODEL"), + descriptor("feature").asArgument("ARGS"), + Expressions.map( + "async", + "true", + "max-concurrent-operations", + "10") + .asArgument("CONFIG")), + "sink") + .build(); + + public static final TableTestProgram ML_PREDICT_MODEL_API = + TableTestProgram.of("ml-predict-model-api", "ml-predict using model API") + .setupTableSource(SIMPLE_FEATURES_SOURCE) + .setupModel(SYNC_MODEL) + .setupTableSink(SIMPLE_SINK) + .runTableApi( + env -> + env.fromModel("chatgpt") + .predict( + env.from("features"), ColumnList.of("feature")), + "sink") + .build(); + + public static final TableTestProgram ASYNC_ML_PREDICT_MODEL_API = + TableTestProgram.of("async-ml-predict-model-api", "async ml-predict using model API") + .setupTableSource(SIMPLE_FEATURES_SOURCE) + .setupModel(ASYNC_MODEL) + .setupTableSink(SIMPLE_SINK) + .setupConfig( + ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE, + ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED) + .runTableApi( + env -> + env.fromModel("chatgpt") + .predict( + env.from("features"), + ColumnList.of("feature"), + Map.of( + "async", + "true", + "max-concurrent-operations", + "10")), + "sink") + .build(); + + public static final TableTestProgram ML_PREDICT_ANON_MODEL_API = + TableTestProgram.of( + "ml-predict-anonymous-model-api", + "ml-predict using anonymous model API") + .setupTableSource(SIMPLE_FEATURES_SOURCE) + .setupTableSink(SIMPLE_SINK) + .runTableApi( + env -> + env.from( + ModelDescriptor.forProvider("values") + .inputSchema( + Schema.newBuilder() + .column( + "feature", + "STRING") + .build()) + .outputSchema( + Schema.newBuilder() + .column( + "category", + "STRING") + .build()) + .option( + "data-id", + TestValuesModelFactory + 
.registerData(
+                                                                        SYNC_MODEL.data))
+                                                .build())
+                                .predict(
+                                        env.from("features"), ColumnList.of("feature")),
+                    "sink")
+                    .build();
 }
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java
index 9c65cec88c8d4..4ba4c3cc8651e 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java
@@ -22,10 +22,13 @@ import org.apache.flink.table.api.TableConfig;
 import org.apache.flink.table.api.TableEnvironment;
 import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.planner.factories.TestValuesModelFactory;
 import org.apache.flink.table.planner.factories.TestValuesTableFactory;
 import org.apache.flink.table.test.program.ConfigOptionTestStep;
 import org.apache.flink.table.test.program.FailingSqlTestStep;
+import org.apache.flink.table.test.program.FailingTableApiTestStep;
 import org.apache.flink.table.test.program.FunctionTestStep;
+import org.apache.flink.table.test.program.ModelTestStep;
 import org.apache.flink.table.test.program.SinkTestStep;
 import org.apache.flink.table.test.program.SourceTestStep;
 import org.apache.flink.table.test.program.SqlTestStep;
@@ -62,6 +65,7 @@ public abstract class SemanticTestBase implements TableTestProgramRunner {
     public EnumSet<TestKind> supportedSetupSteps() {
         return EnumSet.of(
                 TestKind.CONFIG,
+                TestKind.MODEL,
                 TestKind.SOURCE_WITH_DATA,
                 TestKind.SINK_WITH_DATA,
                 TestKind.FUNCTION,
@@ -70,7 +74,8 @@ public EnumSet<TestKind> supportedRunSteps() {
-        return EnumSet.of(TestKind.SQL, TestKind.FAILING_SQL, TestKind.TABLE_API);
+        return EnumSet.of(
+                TestKind.SQL, TestKind.FAILING_SQL, TestKind.TABLE_API, TestKind.FAILING_TABLE_API);
     }

     @AfterEach
@@ -145,6 +150,22 @@ protected void runStep(TestStep testStep, TableEnvironment env) throws Exception
                     sqlTestStep.apply(env);
                 }
                 break;
+            case FAILING_TABLE_API:
+                {
+                    final FailingTableApiTestStep tableApiTestStep =
+                            (FailingTableApiTestStep) testStep;
+                    tableApiTestStep.apply(env);
+                }
+                break;
+            case MODEL:
+                {
+                    final ModelTestStep modelTestStep = (ModelTestStep) testStep;
+                    final Map<String, String> options = new HashMap<>();
+                    options.put("provider", "values");
+                    options.put("data-id", TestValuesModelFactory.registerData(modelTestStep.data));
+                    modelTestStep.apply(env, options);
+                }
+                break;
             case TABLE_API:
                 {
                     final TableApiTestStep apiTestStep = (TableApiTestStep) testStep;
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala
index 6e9b9285bec6b..8bd7ca7869cab 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala
@@ -24,7 +24,7 @@ import org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches
 import org.apache.flink.sql.parser.error.SqlValidateException
 import org.apache.flink.streaming.api.environment.{LocalStreamEnvironment, StreamExecutionEnvironment}
 import org.apache.flink.table.api.bridge.scala._
-import
org.apache.flink.table.api.internal.TableEnvironmentInternal +import org.apache.flink.table.api.internal.{ModelImpl, TableEnvironmentInternal} import org.apache.flink.table.catalog._ import org.apache.flink.table.factories.{TableFactoryUtil, TableSourceFactoryContextImpl} import org.apache.flink.table.functions.TestGenericUDF @@ -3250,6 +3250,28 @@ class TableEnvironmentTest { checkData(util.Arrays.asList(Row.of("your_model")).iterator(), tableResult3.collect()) } + @Test + def testGetNonExistModel(): Unit = { + assertThatThrownBy(() => tableEnv.fromModelPath("MyModel")) + .hasMessageContaining("Model `MyModel` was not found") + .isInstanceOf[ValidationException] + } + + @Test + def testGetModel(): Unit = { + val inputSchema = Schema.newBuilder().column("feature", DataTypes.STRING()).build() + + val outputSchema = Schema.newBuilder().column("response", DataTypes.DOUBLE()).build() + tableEnv.createModel( + "MyModel", + ModelDescriptor + .forProvider("openai") + .inputSchema(inputSchema) + .outputSchema(outputSchema) + .build()) + assertThat(tableEnv.fromModelPath("MyModel")).isInstanceOf(classOf[ModelImpl]) + } + @Test def testTemporaryOperationListener(): Unit = { val listener = new ListenerCatalog("listener_cat") diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala index f3d7b1023ed41..e4fca40e7f54e 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala @@ -37,7 +37,7 @@ import org.apache.flink.table.api.bridge.java.{StreamTableEnvironment => JavaStr import org.apache.flink.table.api.bridge.scala.{StreamTableEnvironment => ScalaStreamTableEnv} import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} import org.apache.flink.table.api.config.OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE -import org.apache.flink.table.api.internal.{StatementSetImpl, TableEnvironmentImpl, TableEnvironmentInternal, TableImpl} +import org.apache.flink.table.api.internal._ import org.apache.flink.table.api.typeutils.CaseClassTypeInfo import org.apache.flink.table.catalog._ import org.apache.flink.table.connector.ChangelogMode