scanInternal(UnresolvedIdentifier identifier) {
ObjectIdentifier tableIdentifier = catalogManager.qualifyIdentifier(identifier);
@@ -1487,6 +1514,10 @@ public TableImpl createTable(QueryOperation tableOperation) {
functionCatalog.asLookup(getParser()::parseIdentifier));
}
+ public ModelImpl createModel(ContextResolvedModel model) {
+ return ModelImpl.createModel(this, model);
+ }
+
@Override
public String explainPlan(InternalPlan compiledPlan, ExplainDetail... extraDetails) {
return planner.explainPlan(compiledPlan, extraDetails);
@@ -1531,5 +1562,16 @@ public Void visit(TableReferenceExpression tableRef) {
}
return null;
}
+
+ @Override
+ public Void visit(ModelReferenceExpression modelRef) {
+ super.visit(modelRef);
+ if (modelRef.getTableEnvironment() != null
+ && modelRef.getTableEnvironment() != TableEnvironmentImpl.this) {
+ throw new ValidationException(
+ "All model references must use the same TableEnvironment.");
+ }
+ return null;
+ }
}
}
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java
index dbfba0b7a584e..74854135de757 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java
@@ -19,12 +19,14 @@
package org.apache.flink.table.catalog;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.factories.FactoryUtil;
import org.apache.flink.util.Preconditions;
import javax.annotation.Nullable;
import java.util.Objects;
import java.util.Optional;
+import java.util.concurrent.atomic.AtomicInteger;
/**
* This class contains information about a model and its relationship with a {@link Catalog}, if
@@ -48,28 +50,46 @@
@Internal
public final class ContextResolvedModel {
+ private static final AtomicInteger uniqueId = new AtomicInteger(0);
private final ObjectIdentifier objectIdentifier;
private final @Nullable Catalog catalog;
private final ResolvedCatalogModel resolvedModel;
+ private final boolean anonymous;
public static ContextResolvedModel permanent(
ObjectIdentifier identifier, Catalog catalog, ResolvedCatalogModel resolvedModel) {
return new ContextResolvedModel(
- identifier, Preconditions.checkNotNull(catalog), resolvedModel);
+ identifier, Preconditions.checkNotNull(catalog), resolvedModel, false);
}
public static ContextResolvedModel temporary(
ObjectIdentifier identifier, ResolvedCatalogModel resolvedModel) {
- return new ContextResolvedModel(identifier, null, resolvedModel);
+ return new ContextResolvedModel(identifier, null, resolvedModel, false);
+ }
+
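+ /**
+ * Creates a {@link ContextResolvedModel} for an inline model that is not registered in any
+ * catalog. Usage sketch (the generated identifier shown is illustrative; its exact format is
+ * not guaranteed):
+ *
+ * <pre>{@code
+ * ContextResolvedModel model = ContextResolvedModel.anonymous(resolvedModel);
+ * model.isAnonymous();   // true
+ * model.getIdentifier(); // e.g. *anonymous_openai$1* (provider taken from the model options)
+ * }</pre>
+ */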
+ public static ContextResolvedModel anonymous(ResolvedCatalogModel resolvedModel) {
+ return anonymous(null, resolvedModel);
+ }
+
+ public static ContextResolvedModel anonymous(
+ @Nullable String hint, ResolvedCatalogModel resolvedModel) {
+ return new ContextResolvedModel(
+ ObjectIdentifier.ofAnonymous(
+ generateAnonymousStringIdentifier(hint, resolvedModel)),
+ null,
+ resolvedModel,
+ true);
}
private ContextResolvedModel(
ObjectIdentifier objectIdentifier,
@Nullable Catalog catalog,
- ResolvedCatalogModel resolvedModel) {
+ ResolvedCatalogModel resolvedModel,
+ boolean anonymous) {
this.objectIdentifier = Preconditions.checkNotNull(objectIdentifier);
this.catalog = catalog;
this.resolvedModel = Preconditions.checkNotNull(resolvedModel);
+ this.anonymous = anonymous;
}
/**
@@ -83,6 +103,10 @@ public boolean isPermanent() {
return !isTemporary();
}
+ public boolean isAnonymous() {
+ return anonymous;
+ }
+
public ObjectIdentifier getIdentifier() {
return objectIdentifier;
}
@@ -116,13 +140,39 @@ public boolean equals(Object o) {
return false;
}
ContextResolvedModel that = (ContextResolvedModel) o;
- return Objects.equals(objectIdentifier, that.objectIdentifier)
+ return anonymous == that.anonymous
+ && Objects.equals(objectIdentifier, that.objectIdentifier)
&& Objects.equals(catalog, that.catalog)
&& Objects.equals(resolvedModel, that.resolvedModel);
}
@Override
public int hashCode() {
- return Objects.hash(objectIdentifier, catalog, resolvedModel);
+ return Objects.hash(objectIdentifier, catalog, resolvedModel, anonymous);
+ }
+
+ /**
+ * Generates a unique string identifier for an anonymous model, embedding the provider name
+ * when available to make the identifier more helpful for debugging. The return value is
+ * intended only for debugging and must not be relied upon.
+ */
+ private static String generateAnonymousStringIdentifier(
+ @Nullable String hint, ResolvedCatalogModel resolvedModel) {
+ // The planner can apply optimizations that squash two sources together in the same
+ // operator. Because this logic is string-based, anonymous models still need a unique
+ // string-based identifier that the planner can use later.
+ if (hint == null) {
+ try {
+ hint = resolvedModel.getOptions().get(FactoryUtil.PROVIDER.key());
+ } catch (Exception ignored) {
+ }
+ }
+
+ int id = uniqueId.incrementAndGet();
+ if (hint == null) {
+ return "*anonymous$" + id + "*";
+ }
+
+ return "*anonymous_" + hint + "$" + id + "*";
}
}
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java
index 8748fafceb799..bcad8cccfce79 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java
@@ -21,9 +21,11 @@
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.ApiExpression;
import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Model;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.api.internal.ModelImpl;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.catalog.ContextResolvedFunction;
import org.apache.flink.table.functions.BuiltInFunctionDefinition;
@@ -301,6 +303,11 @@ public static TableReferenceExpression tableRef(
return new TableReferenceExpression(name, queryOperation, env);
}
+ public static ModelReferenceExpression modelRef(String name, Model model) {
+ return new ModelReferenceExpression(
+ name, ((ModelImpl) model).getModel(), ((ModelImpl) model).getTableEnvironment());
+ }
+
public static LookupCallExpression lookupCall(String name, Expression... args) {
return new LookupCallExpression(
name,
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java
index f0b34bbddfaba..b64b77d24fba0 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java
@@ -29,6 +29,8 @@ public final R visit(Expression other) {
return visit((UnresolvedReferenceExpression) other);
} else if (other instanceof TableReferenceExpression) {
return visit((TableReferenceExpression) other);
+ } else if (other instanceof ModelReferenceExpression) {
+ return visit((ModelReferenceExpression) other);
} else if (other instanceof LocalReferenceExpression) {
return visit((LocalReferenceExpression) other);
} else if (other instanceof LookupCallExpression) {
@@ -49,6 +51,8 @@ public final R visit(Expression other) {
public abstract R visit(TableReferenceExpression tableReference);
+ public abstract R visit(ModelReferenceExpression modelReferenceExpression);
+
public abstract R visit(LocalReferenceExpression localReference);
/** For resolved expressions created by the planner. */
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ModelReferenceExpression.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ModelReferenceExpression.java
new file mode 100644
index 0000000000000..c3c225528b204
--- /dev/null
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ModelReferenceExpression.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.expressions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.Model;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.ContextResolvedModel;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.utils.DataTypeUtils;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * A reference to a {@link Model} in an expression context.
+ *
+ * <p>This expression is used when a model needs to be passed as an argument to functions or
+ * operations that accept model references. It wraps a model object and provides the necessary
+ * expression interface for use in the Table API expression system.
+ *
+ * <p>The expression carries a string representation of the model and uses a special data type to
+ * indicate that this is a model reference rather than a regular data value.
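+ *
+ * <p>A usage sketch (names are illustrative):
+ *
+ * <pre>{@code
+ * ModelReferenceExpression ref =
+ *         new ModelReferenceExpression("m", contextResolvedModel, tableEnv);
+ * ref.getInputDataType();  // row type derived from the model's input schema
+ * ref.getOutputDataType(); // row type derived from the model's output schema
+ * }</pre>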
+ */
+@Internal
+public final class ModelReferenceExpression implements ResolvedExpression {
+
+ private final String name;
+ private final ContextResolvedModel model;
+ // The environment is optional but serves validation purposes
+ // to ensure that all referenced models belong to the same
+ // environment.
+ private final @Nullable TableEnvironment env;
+
+ public ModelReferenceExpression(String name, ContextResolvedModel model, @Nullable TableEnvironment env) {
+ this.name = Preconditions.checkNotNull(name);
+ this.model = Preconditions.checkNotNull(model);
+ this.env = env;
+ }
+
+ /**
+ * Returns the name of this model reference.
+ *
+ * @return the model reference name
+ */
+ public String getName() {
+ return name;
+ }
+
+ /**
+ * Returns the ContextResolvedModel associated with this model reference.
+ *
+ * @return the query context resolved model
+ */
+ public ContextResolvedModel getModel() {
+ return model;
+ }
+
+ public @Nullable TableEnvironment getTableEnvironment() {
+ return env;
+ }
+
+ /**
+ * Returns the input data type expected by this model reference.
+ *
+ * <p>This method extracts the input data type from the model's input schema, which describes
+ * the structure and data types that the model expects for inference operations.
+ *
+ * @return the input data type expected by the model
+ */
+ public DataType getInputDataType() {
+ return DataTypeUtils.fromResolvedSchemaPreservingTimeAttributes(
+ model.getResolvedModel().getResolvedInputSchema());
+ }
+
+ @Override
+ public DataType getOutputDataType() {
+ return DataTypeUtils.fromResolvedSchemaPreservingTimeAttributes(
+ model.getResolvedModel().getResolvedOutputSchema());
+ }
+
+ @Override
+ public List<ResolvedExpression> getResolvedChildren() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public String asSerializableString(SqlFactory sqlFactory) {
+ if (model.isAnonymous()) {
+ throw new ValidationException("Anonymous models cannot be serialized.");
+ }
+
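+ // e.g. "MODEL `cat`.`db`.`m`"; quoting follows ObjectIdentifier#asSerializableString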
+ return "MODEL " + model.getIdentifier().asSerializableString();
+ }
+
+ @Override
+ public String asSummaryString() {
+ return name;
+ }
+
+ @Override
+ public List<Expression> getChildren() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public <R> R accept(ExpressionVisitor<R> visitor) {
+ return visitor.visit(this);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ ModelReferenceExpression that = (ModelReferenceExpression) o;
+ return Objects.equals(name, that.name)
+ && Objects.equals(model, that.model)
+ // Effectively means reference equality
+ && Objects.equals(env, that.env);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(name, model, env);
+ }
+
+ @Override
+ public String toString() {
+ return asSummaryString();
+ }
+}
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java
index 42ee52d98d51c..88a15e3929d8a 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java
@@ -32,6 +32,8 @@ public abstract class ResolvedExpressionVisitor<R> implements ExpressionVisitor<R>
public final R visit(Expression other) {
if (other instanceof TableReferenceExpression) {
return visit((TableReferenceExpression) other);
+ } else if (other instanceof ModelReferenceExpression) {
+ return visit((ModelReferenceExpression) other);
} else if (other instanceof LocalReferenceExpression) {
return visit((LocalReferenceExpression) other);
} else if (other instanceof ResolvedExpression) {
@@ -42,6 +44,8 @@ public final R visit(Expression other) {
public abstract R visit(TableReferenceExpression tableReference);
+ public abstract R visit(ModelReferenceExpression modelReferenceExpression);
+
public abstract R visit(LocalReferenceExpression localReference);
/** For resolved expressions created by the planner. */
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java
index f5c585087dd01..e298953f69e3e 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java
@@ -28,6 +28,7 @@
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionUtils;
+import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.TableReferenceExpression;
import org.apache.flink.table.expressions.TypeLiteralExpression;
@@ -38,6 +39,7 @@
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.FunctionIdentifier;
import org.apache.flink.table.functions.FunctionKind;
+import org.apache.flink.table.functions.ModelSemantics;
import org.apache.flink.table.functions.ScalarFunctionDefinition;
import org.apache.flink.table.functions.TableAggregateFunctionDefinition;
import org.apache.flink.table.functions.TableFunctionDefinition;
@@ -651,6 +653,22 @@ public Optional<TableSemantics> getTableSemantics(int pos) {
return Optional.of(semantics);
}
+ @Override
+ public Optional<ModelSemantics> getModelSemantics(int pos) {
+ final StaticArgument staticArg =
+ Optional.ofNullable(staticArguments).map(args -> args.get(pos)).orElse(null);
+ if (staticArg == null || !staticArg.is(StaticArgumentTrait.MODEL)) {
+ return Optional.empty();
+ }
+ final ResolvedExpression arg = getArgument(pos);
+ if (!(arg instanceof ModelReferenceExpression)) {
+ return Optional.empty();
+ }
+ final ModelReferenceExpression modelRef = (ModelReferenceExpression) arg;
+ final ModelSemantics semantics = new TableApiModelSemantics(modelRef);
+ return Optional.of(semantics);
+ }
+
@Override
public String getName() {
return functionName;
@@ -732,4 +750,23 @@ public Optional<ChangelogMode> changelogMode() {
return Optional.empty();
}
}
+
+ private static class TableApiModelSemantics implements ModelSemantics {
+
+ private final ModelReferenceExpression modelRef;
+
+ private TableApiModelSemantics(ModelReferenceExpression modelRef) {
+ this.modelRef = modelRef;
+ }
+
+ @Override
+ public DataType inputDataType() {
+ return modelRef.getInputDataType();
+ }
+
+ @Override
+ public DataType outputDataType() {
+ return modelRef.getOutputDataType();
+ }
+ }
}
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java
index 9797f53a331c2..6e0303454094d 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java
@@ -25,6 +25,7 @@
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
import org.apache.flink.table.expressions.LookupCallExpression;
+import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.SqlCallExpression;
@@ -76,6 +77,11 @@ public T visit(TableReferenceExpression tableReference) {
return defaultMethod(tableReference);
}
+ @Override
+ public T visit(ModelReferenceExpression modelReference) {
+ return defaultMethod(modelReference);
+ }
+
@Override
public T visit(LocalReferenceExpression localReference) {
return defaultMethod(localReference);
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java
index 3bf93880d7d56..841ff5a03393f 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java
@@ -22,6 +22,7 @@
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
+import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.ResolvedExpressionVisitor;
@@ -41,6 +42,10 @@ public T visit(TableReferenceExpression tableReference) {
return defaultMethod(tableReference);
}
+ @Override
+ public T visit(ModelReferenceExpression modelReferenceExpression) {
+ return defaultMethod(modelReferenceExpression);
+ }
+
@Override
public T visit(LocalReferenceExpression localReference) {
return defaultMethod(localReference);
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java
index 4d7af03afa851..4f39f3ecb8410 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java
@@ -24,6 +24,7 @@
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
import org.apache.flink.table.expressions.LookupCallExpression;
+import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.TableReferenceExpression;
import org.apache.flink.table.expressions.UnresolvedCallExpression;
@@ -278,6 +279,11 @@ public Optional<String> visit(TableReferenceExpression tableReference) {
return Optional.of(tableReference.getName());
}
+ @Override
+ public Optional<String> visit(ModelReferenceExpression modelReference) {
+ return Optional.of(modelReference.getName());
+ }
+
@Override
public Optional visit(FieldReferenceExpression fieldReference) {
return Optional.of(fieldReference.getName());
diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/FailingTableApiTestStep.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/FailingTableApiTestStep.java
new file mode 100644
index 0000000000000..83b90f8946b0c
--- /dev/null
+++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/FailingTableApiTestStep.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.test.program;
+
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.api.TableRuntimeException;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.test.program.TableApiTestStep.TableEnvAccessor;
+import org.apache.flink.util.Preconditions;
+
+import java.util.function.Function;
+
+import static org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Test step for executing a Table API query that will eventually fail with either {@link
+ * ValidationException} (during planning time) or {@link TableRuntimeException} (during execution
+ * time).
+ *
+ * <p>Similar to {@link FailingSqlTestStep} but uses Table API instead of SQL.
+ */
+public final class FailingTableApiTestStep implements TestStep {
+
+ private final Function<TableEnvAccessor, Table> tableQuery;
+ private final String sinkName;
+ public final Class<? extends Exception> expectedException;
+ public final String expectedErrorMessage;
+
+ FailingTableApiTestStep(
+ Function<TableEnvAccessor, Table> tableQuery,
+ String sinkName,
+ Class<? extends Exception> expectedException,
+ String expectedErrorMessage) {
+ Preconditions.checkArgument(
+ expectedException == ValidationException.class
+ || expectedException == TableRuntimeException.class,
+ "Usually a Table API query should fail with either validation or runtime exception. "
+ + "Otherwise this might require an update to the exception design.");
+ this.tableQuery = tableQuery;
+ this.sinkName = sinkName;
+ this.expectedException = expectedException;
+ this.expectedErrorMessage = expectedErrorMessage;
+ }
+
+ @Override
+ public TestKind getKind() {
+ return TestKind.FAILING_TABLE_API;
+ }
+
+ public Table toTable(TableEnvironment env) {
+ return tableQuery.apply(
+ new TableEnvAccessor() {
+ @Override
+ public Table from(String path) {
+ return env.from(path);
+ }
+
+ @Override
+ public Table fromCall(String path, Object... arguments) {
+ return env.fromCall(path, arguments);
+ }
+
+ @Override
+ public Table fromCall(
+ Class<? extends org.apache.flink.table.functions.UserDefinedFunction>
+ function,
+ Object... arguments) {
+ return env.fromCall(function, arguments);
+ }
+
+ @Override
+ public Table fromValues(Object... values) {
+ return env.fromValues(values);
+ }
+
+ @Override
+ public Table fromValues(
+ org.apache.flink.table.types.AbstractDataType<?> dataType,
+ Object... values) {
+ return env.fromValues(dataType, values);
+ }
+
+ @Override
+ public Table sqlQuery(String query) {
+ return env.sqlQuery(query);
+ }
+
+ @Override
+ public org.apache.flink.table.api.Model fromModel(String modelPath) {
+ return env.fromModelPath(modelPath);
+ }
+
+ @Override
+ public org.apache.flink.table.api.Model from(
+ org.apache.flink.table.api.ModelDescriptor modelDescriptor) {
+ return env.from(modelDescriptor);
+ }
+ });
+ }
+
+ public void apply(TableEnvironment env) {
+ assertThatThrownBy(
+ () -> {
+ final Table table = toTable(env);
+ table.executeInsert(sinkName).await();
+ })
+ .satisfies(anyCauseMatches(expectedException, expectedErrorMessage));
+ }
+
+ public void applyAsSql(TableEnvironment env) {
+ assertThatThrownBy(
+ () -> {
+ final Table table = toTable(env);
+ final String query =
+ table.getQueryOperation()
+ .asSerializableString(
+ org.apache.flink.table.expressions
+ .DefaultSqlFactory.INSTANCE);
+ env.executeSql(String.format("INSERT INTO %s %s", sinkName, query))
+ .await();
+ })
+ .satisfies(anyCauseMatches(expectedException, expectedErrorMessage));
+ }
+}
diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java
index 4a375ce4f5932..4cfb19128e12e 100644
--- a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java
+++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java
@@ -18,6 +18,8 @@
package org.apache.flink.table.test.program;
+import org.apache.flink.table.api.Model;
+import org.apache.flink.table.api.ModelDescriptor;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.TableResult;
@@ -75,6 +77,16 @@ public Table fromValues(AbstractDataType<?> dataType, Object... values) {
public Table sqlQuery(String query) {
return env.sqlQuery(query);
}
+
+ @Override
+ public Model fromModel(String modelPath) {
+ return env.fromModelPath(modelPath);
+ }
+
+ @Override
+ public Model from(ModelDescriptor modelDescriptor) {
+ return env.from(modelDescriptor);
+ }
});
}
@@ -111,5 +123,11 @@ public interface TableEnvAccessor {
/** See {@link TableEnvironment#sqlQuery(String)}. */
Table sqlQuery(String query);
+
+ /** See {@link TableEnvironment#fromModelPath(String)}. */
+ Model fromModel(String modelPath);
+
+ /** See {@link TableEnvironment#from(ModelDescriptor)}. */
+ Model from(ModelDescriptor modelDescriptor);
}
}
diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java
index bb8ca296cae5a..d37c4c8dc957e 100644
--- a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java
+++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java
@@ -21,6 +21,7 @@
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableRuntimeException;
+import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.test.program.FunctionTestStep.FunctionBehavior;
@@ -355,6 +356,22 @@ public Builder runFailingSql(
return this;
}
+ /**
+ * Run step for executing a Table API query that will eventually fail with either {@link
+ * ValidationException} (during planning time) or {@link TableRuntimeException} (during
+ * execution time).
+ */
+ public Builder runFailingTableApi(
+ Function<TableApiTestStep.TableEnvAccessor, Table> toTable,
+ String sinkName,
+ Class<? extends Exception> expectedException,
+ String expectedErrorMessage) {
+ this.runSteps.add(
+ new FailingTableApiTestStep(
+ toTable, sinkName, expectedException, expectedErrorMessage));
+ return this;
+ }
+
public Builder runTableApi(Function<TableApiTestStep.TableEnvAccessor, Table> toTable, String sinkName) {
this.runSteps.add(new TableApiTestStep(toTable, sinkName));
return this;
diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java
index fc4245df79ff9..db2fe754f360d 100644
--- a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java
+++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java
@@ -51,7 +51,8 @@ enum TestKind {
SINK_WITHOUT_DATA,
SINK_WITH_DATA,
SINK_WITH_RESTORE_DATA,
- FAILING_SQL
+ FAILING_SQL,
+ FAILING_TABLE_API
}
TestKind getKind();
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java
index d2d02bc498f70..bc70a660662f3 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java
@@ -61,6 +61,10 @@ public static ColumnList of(List<String> names) {
return of(names, List.of());
}
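+
+ /** Convenience variant of {@link #of(List)} without values, e.g. {@code ColumnList.of("a", "b")}. */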
+ public static ColumnList of(String... names) {
+ return of(List.of(names));
+ }
+
/**
* Returns a list of column names.
*
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java
index b4c0cf396266a..7b44d97cbf901 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java
@@ -19,26 +19,42 @@
package org.apache.flink.table.planner.expressions.converter;
import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.catalog.Catalog;
+import org.apache.flink.table.catalog.ContextResolvedModel;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.data.DecimalData;
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.Expression;
-import org.apache.flink.table.expressions.ExpressionVisitor;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
+import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;
+import org.apache.flink.table.expressions.ResolvedExpression;
+import org.apache.flink.table.expressions.ResolvedExpressionVisitor;
+import org.apache.flink.table.expressions.TableReferenceExpression;
import org.apache.flink.table.expressions.TimeIntervalUnit;
import org.apache.flink.table.expressions.TimePointUnit;
import org.apache.flink.table.expressions.TypeLiteralExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;
+import org.apache.flink.table.factories.FactoryUtil;
+import org.apache.flink.table.factories.ModelProviderFactory;
+import org.apache.flink.table.ml.ModelProvider;
+import org.apache.flink.table.module.Module;
+import org.apache.flink.table.operations.PartitionQueryOperation;
+import org.apache.flink.table.planner.calcite.FlinkContext;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.calcite.RexFieldVariable;
+import org.apache.flink.table.planner.calcite.RexModelCall;
+import org.apache.flink.table.planner.calcite.RexTableArgCall;
import org.apache.flink.table.planner.expressions.RexNodeExpression;
import org.apache.flink.table.planner.expressions.converter.CallExpressionConvertRule.ConvertContext;
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.TimeType;
+import org.apache.flink.table.types.utils.DataTypeUtils;
import org.apache.flink.types.ColumnList;
import org.apache.calcite.avatica.util.ByteString;
@@ -60,6 +76,7 @@
import java.time.Period;
import java.time.ZoneOffset;
import java.time.temporal.ChronoField;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
@@ -69,19 +86,22 @@
import static org.apache.flink.table.planner.utils.ShortcutUtils.unwrapContext;
import static org.apache.flink.table.planner.utils.TimestampStringUtils.fromLocalDateTime;
import static org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType;
+import static org.apache.flink.util.OptionalUtils.firstPresent;
/** Visit expression to generate {@link RexNode}. */
-public class ExpressionConverter implements ExpressionVisitor<RexNode> {
+public class ExpressionConverter extends ResolvedExpressionVisitor<RexNode> {
private final RelBuilder relBuilder;
private final FlinkTypeFactory typeFactory;
private final DataTypeFactory dataTypeFactory;
+ private final List<RelNode> inputStack;
public ExpressionConverter(RelBuilder relBuilder) {
this.relBuilder = relBuilder;
this.typeFactory = (FlinkTypeFactory) relBuilder.getRexBuilder().getTypeFactory();
this.dataTypeFactory =
unwrapContext(relBuilder.getCluster()).getCatalogManager().getDataTypeFactory();
+ this.inputStack = new ArrayList<>();
}
private List<CallExpressionConvertRule> getFunctionConvertChain(boolean isBatchMode) {
@@ -235,27 +255,96 @@ public RexNode visit(TypeLiteralExpression typeLiteral) {
}
@Override
- public RexNode visit(Expression other) {
+ public RexNode visit(TableReferenceExpression tableRef) {
+ final LogicalType tableArgType = tableRef.getOutputDataType().getLogicalType();
+ final RelDataType rowType = typeFactory.buildRelNodeRowType((RowType) tableArgType);
+ final int[] partitionKeys;
+ if (tableRef.getQueryOperation() instanceof PartitionQueryOperation) {
+ final PartitionQueryOperation partitionOperation =
+ (PartitionQueryOperation) tableRef.getQueryOperation();
+ partitionKeys = partitionOperation.getPartitionKeys();
+ } else {
+ partitionKeys = new int[0];
+ }
+ final RexTableArgCall tableArgCall =
+ new RexTableArgCall(rowType, inputStack.size(), partitionKeys, new int[0]);
+ inputStack.add(relBuilder.build());
+ return tableArgCall;
+ }
+
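+ // Table arguments are collected as a side effect of visit(TableReferenceExpression):
+ // every such visit pushes the current relational input (relBuilder.build()) onto this
+ // stack. Callers convert all arguments first and then drain the stack via
+ // copyInputStack()/clearInputStack(), as QueryOperationConverter does for
+ // FunctionQueryOperation.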
+ public List<RelNode> copyInputStack() {
+ return new ArrayList<>(inputStack);
+ }
+
+ public void clearInputStack() {
+ inputStack.clear();
+ }
+
+ @Override
+ public RexNode visit(ModelReferenceExpression modelRef) {
+ final ContextResolvedModel contextResolvedModel = modelRef.getModel();
+ final FlinkContext flinkContext = ShortcutUtils.unwrapContext(relBuilder);
+
+ final Optional<ModelProviderFactory> factoryFromCatalog =
+ contextResolvedModel
+ .getCatalog()
+ .flatMap(Catalog::getFactory)
+ .map(
+ f ->
+ f instanceof ModelProviderFactory
+ ? (ModelProviderFactory) f
+ : null);
+
+ final Optional<ModelProviderFactory> factoryFromModule =
+ flinkContext.getModuleManager().getFactory(Module::getModelProviderFactory);
+
+ // Since the catalog is more specific, we give it precedence over a factory
+ // provided by any modules.
+ final ModelProviderFactory factory =
+ firstPresent(factoryFromCatalog, factoryFromModule).orElse(null);
+
+ final ModelProvider modelProvider =
+ FactoryUtil.createModelProvider(
+ factory,
+ contextResolvedModel.getIdentifier(),
+ contextResolvedModel.getResolvedModel(),
+ flinkContext.getTableConfig(),
+ flinkContext.getClassLoader(),
+ contextResolvedModel.isTemporary());
+ final LogicalType modelOutputType =
+ DataTypeUtils.fromResolvedSchemaPreservingTimeAttributes(
+ contextResolvedModel.getResolvedModel().getResolvedOutputSchema())
+ .getLogicalType();
+ final RelDataType modelOutputRelDataType =
+ typeFactory.buildRelNodeRowType((RowType) modelOutputType);
+
+ return new RexModelCall(modelOutputRelDataType, contextResolvedModel, modelProvider);
+ }
+
+ @Override
+ public RexNode visit(LocalReferenceExpression local) {
+ // check whether the local field reference can actually be resolved to an existing
+ // field otherwise preserve the locality attribute
+ RelNode inputNode;
+ try {
+ inputNode = relBuilder.peek();
+ } catch (Throwable t) {
+ inputNode = null;
+ }
+ if (inputNode != null && inputNode.getRowType().getFieldNames().contains(local.getName())) {
+ return relBuilder.field(local.getName());
+ }
+ return new RexFieldVariable(
+ local.getName(),
+ typeFactory.createFieldTypeFromLogicalType(
+ fromDataTypeToLogicalType(local.getOutputDataType())));
+ }
+
+ @Override
+ public RexNode visit(ResolvedExpression other) {
if (other instanceof RexNodeExpression) {
return ((RexNodeExpression) other).getRexNode();
- } else if (other instanceof LocalReferenceExpression) {
- final LocalReferenceExpression local = (LocalReferenceExpression) other;
- // check whether the local field reference can actually be resolved to an existing
- // field otherwise preserve the locality attribute
- RelNode inputNode;
- try {
- inputNode = relBuilder.peek();
- } catch (Throwable t) {
- inputNode = null;
- }
- if (inputNode != null
- && inputNode.getRowType().getFieldNames().contains(local.getName())) {
- return relBuilder.field(local.getName());
- }
- return new RexFieldVariable(
- local.getName(),
- typeFactory.createFieldTypeFromLogicalType(
- fromDataTypeToLogicalType(local.getOutputDataType())));
} else {
throw new UnsupportedOperationException(
other.getClass().getSimpleName() + ":" + other.toString());
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java
index 08ebe72f9e425..09c3cf159dd97 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java
@@ -36,7 +36,6 @@
import org.apache.flink.table.expressions.ExpressionDefaultVisitor;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
-import org.apache.flink.table.expressions.TableReferenceExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
@@ -68,7 +67,6 @@
import org.apache.flink.table.planner.calcite.FlinkContext;
import org.apache.flink.table.planner.calcite.FlinkRelBuilder;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
-import org.apache.flink.table.planner.calcite.RexTableArgCall;
import org.apache.flink.table.planner.connectors.DynamicSourceUtils;
import org.apache.flink.table.planner.expressions.RexNodeExpression;
import org.apache.flink.table.planner.expressions.SqlAggFunctionVisitor;
@@ -302,43 +300,14 @@ public RelNode visit(FunctionQueryOperation functionTable) {
final RelDataType outputRelDataType =
typeFactory.buildRelNodeRowType((RowType) outputType);
- final List<RelNode> inputStack = new ArrayList<>();
final List<RexNode> rexNodeArgs =
resolvedArgs.stream()
- .map(
- resolvedArg -> {
- if (resolvedArg instanceof TableReferenceExpression) {
- final TableReferenceExpression tableRef =
- (TableReferenceExpression) resolvedArg;
- final LogicalType tableArgType =
- tableRef.getOutputDataType().getLogicalType();
- final RelDataType rowType =
- typeFactory.buildRelNodeRowType(
- (RowType) tableArgType);
- final int[] partitionKeys;
- if (tableRef.getQueryOperation()
- instanceof PartitionQueryOperation) {
- final PartitionQueryOperation partitionOperation =
- (PartitionQueryOperation)
- tableRef.getQueryOperation();
- partitionKeys =
- partitionOperation.getPartitionKeys();
- } else {
- partitionKeys = new int[0];
- }
- final RexTableArgCall tableArgCall =
- new RexTableArgCall(
- rowType,
- inputStack.size(),
- partitionKeys,
- new int[0]);
- inputStack.add(relBuilder.build());
- return tableArgCall;
- }
- return convertExprToRexNode(resolvedArg);
- })
+ .map(QueryOperationConverter.this::convertExprToRexNode)
.collect(Collectors.toList());
+ final List<RelNode> inputStack = expressionConverter.copyInputStack();
+ expressionConverter.clearInputStack();
+
// relBuilder.build() works in LIFO fashion, this restores the original input order
Collections.reverse(inputStack);
@@ -547,7 +516,6 @@ else if (other instanceof DataStreamQueryOperation) {
dataStreamQueryOperation.getResolvedSchema(),
dataStreamQueryOperation.getIdentifier());
}
-
throw new TableException("Unknown table operation: " + other);
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
index d0b294ed5ea70..0016ec9893d4a 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
@@ -114,6 +114,10 @@ public RelWriter explainTerms(RelWriter pw) {
.item("rowType", getRowType());
}
+ public RexNode getMLPredictCall() {
+ return scan.getCall();
+ }
+
private MLPredictSpec buildMLPredictSpec(Map<String, String> runtimeConfig) {
RexTableArgCall tableCall = extractOperand(operand -> operand instanceof RexTableArgCall);
RexCall descriptorCall =
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java
index 875a088ec591f..c34dda9ea6b79 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java
@@ -43,6 +43,7 @@
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalLegacyTableSourceScan;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalLimit;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalLookupJoin;
+import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMLPredictTableFunction;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMatch;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMiniBatchAssigner;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMultiJoin;
@@ -198,6 +199,9 @@ public StreamPhysicalRel visit(
(StreamPhysicalWindowTableFunction) rel, requireDeterminism);
} else if (rel instanceof StreamPhysicalDeltaJoin) {
return visitDeltaJoin((StreamPhysicalDeltaJoin) rel, requireDeterminism);
+ } else if (rel instanceof StreamPhysicalMLPredictTableFunction) {
+ return visitMLPredictTableFunction(
+ (StreamPhysicalMLPredictTableFunction) rel, requireDeterminism);
} else if (rel instanceof StreamPhysicalChangelogNormalize
|| rel instanceof StreamPhysicalDropUpdateBefore
|| rel instanceof StreamPhysicalMiniBatchAssigner
@@ -328,6 +332,16 @@ private StreamPhysicalRel visitCalc(
}
}
+ private StreamPhysicalRel visitMLPredictTableFunction(
+ final StreamPhysicalMLPredictTableFunction predictTableFunction,
+ final ImmutableBitSet requireDeterminism) {
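+ // An ML_PREDICT call cannot provide determinism guarantees, so it is rejected when
+ // its input contains updates and downstream operators require determinism.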
+ if (!inputInsertOnly(predictTableFunction) && !requireDeterminism.isEmpty()) {
+ throwNonDeterministicConditionError(
+ "ML_PREDICT", predictTableFunction.getMLPredictCall(), predictTableFunction);
+ }
+ return transmitDeterminismRequirement(predictTableFunction, NO_REQUIRED_DETERMINISM);
+ }
+
private StreamPhysicalRel visitCorrelate(
final StreamPhysicalCorrelateBase correlate, final ImmutableBitSet requireDeterminism) {
if (inputInsertOnly(correlate) || requireDeterminism.isEmpty()) {
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java
index 18091c4ae1341..5c9f26032d8c3 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java
@@ -35,9 +35,11 @@
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeName;
+import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
import java.util.HashMap;
import java.util.List;
@@ -45,10 +47,15 @@
import java.util.Objects;
import static org.apache.calcite.sql.SqlKind.MAP_VALUE_CONSTRUCTOR;
+import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
+import static org.apache.flink.table.types.logical.LogicalTypeFamily.CHARACTER_STRING;
/** Common utils for function call, e.g. ML_PREDICT and Lookup Join. */
public abstract class FunctionCallUtil {
+ private static final String CONFIG_ERROR_MESSAGE =
+ "Config parameter should be a MAP data type consisting String literals.";
+
/** A field used as an equal condition when querying content from a dimension table. */
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type")
@JsonSubTypes({
@@ -225,18 +232,41 @@ public static Map<String, String> convert(RexCall mapConstructor) {
for (int i = 0; i < mapConstructor.getOperands().size(); i += 2) {
RexNode keyNode = mapConstructor.getOperands().get(i);
RexNode valueNode = mapConstructor.getOperands().get(i + 1);
- // Both key and value should be string literals
- if (!(keyNode instanceof RexLiteral) || !(valueNode instanceof RexLiteral)) {
- throw new ValidationException(
- "Config parameter should be a MAP data type consisting String literals.");
- }
- String key = RexLiteral.stringValue(keyNode);
- String value = RexLiteral.stringValue(valueNode);
+ String key = getStringLiteral(keyNode);
+ String value = getStringLiteral(valueNode);
reducedConfig.put(key, value);
}
return reducedConfig;
}
+ private static String getStringLiteral(RexNode node) {
+ // A cast from string to string is used when Expressions.lit(Map(...)) is used as the
+ // config map from the Table API
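+ // e.g. a key may arrive as CAST('async' AS VARCHAR) rather than a plain literal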
+ if (node instanceof RexCall && node.getKind() == SqlKind.CAST) {
+ final RexCall castCall = (RexCall) node;
+ // Unwrap CAST if present
+ final RexNode castOperand = castCall.getOperands().get(0);
+ if (!(castOperand instanceof RexLiteral)) {
+ throw new ValidationException(CONFIG_ERROR_MESSAGE);
+ }
+ final RelDataType operandType = castOperand.getType();
+ if (!toLogicalType(operandType).is(CHARACTER_STRING)) {
+ throw new ValidationException(CONFIG_ERROR_MESSAGE);
+ }
+ final RelDataType castType = castCall.getType();
+ if (!toLogicalType(castType).is(CHARACTER_STRING)) {
+ throw new ValidationException(CONFIG_ERROR_MESSAGE);
+ }
+ return RexLiteral.stringValue(castOperand);
+ }
+ // Both key and value should be string literals
+ if (!(node instanceof RexLiteral)) {
+ throw new ValidationException(CONFIG_ERROR_MESSAGE);
+ }
+
+ return RexLiteral.stringValue(node);
+ }
+
public static String explainFunctionParam(FunctionParam param, List<String> fieldNames) {
if (param instanceof Constant) {
return RelExplainUtil.literalToString(((Constant) param).literal);
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java
index 73ef260d56bd0..aeb8c7028a87c 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java
@@ -20,6 +20,7 @@
import org.apache.flink.table.operations.QueryOperation;
import org.apache.flink.table.planner.plan.nodes.exec.testutils.SemanticTestBase;
+import org.apache.flink.table.test.program.FailingTableApiTestStep;
import org.apache.flink.table.test.program.TableApiTestStep;
import org.apache.flink.table.test.program.TableTestProgram;
import org.apache.flink.table.test.program.TestStep;
@@ -57,7 +58,11 @@ public List<TableTestProgram> programs() {
QueryOperationTestPrograms.OVER_WINDOW_LAG,
QueryOperationTestPrograms.ACCESSING_NESTED_COLUMN,
QueryOperationTestPrograms.ROW_SEMANTIC_TABLE_PTF,
- QueryOperationTestPrograms.SET_SEMANTIC_TABLE_PTF);
+ QueryOperationTestPrograms.SET_SEMANTIC_TABLE_PTF,
+ QueryOperationTestPrograms.ML_PREDICT_MODEL_API,
+ QueryOperationTestPrograms.ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG,
+ QueryOperationTestPrograms.ASYNC_ML_PREDICT_MODEL_API,
+ QueryOperationTestPrograms.ML_PREDICT_ANON_MODEL_API);
}
@Override
@@ -65,6 +70,9 @@ protected void runStep(TestStep testStep, TableEnvironment env) throws Exception
if (testStep instanceof TableApiTestStep) {
final TableApiTestStep tableApiStep = (TableApiTestStep) testStep;
tableApiStep.applyAsSql(env).await();
+ } else if (testStep instanceof FailingTableApiTestStep) {
+ final FailingTableApiTestStep failingTableApiStep = (FailingTableApiTestStep) testStep;
+ failingTableApiStep.applyAsSql(env);
} else {
super.runStep(testStep, env);
}
@@ -72,6 +80,6 @@ protected void runStep(TestStep testStep, TableEnvironment env) throws Exception
@Override
public EnumSet<TestKind> supportedRunSteps() {
- return EnumSet.of(TestKind.TABLE_API, TestKind.SQL);
+ return EnumSet.of(TestKind.TABLE_API, TestKind.SQL, TestKind.FAILING_TABLE_API);
}
}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
index 53315269c2f08..7f8bb251f931b 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
@@ -19,9 +19,11 @@
package org.apache.flink.table.api;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.operations.QueryOperation;
+import org.apache.flink.table.planner.factories.TestValuesModelFactory;
import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.ChainedReceivingFunction;
import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.ChainedSendingFunction;
import org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.RowSemanticTableFunction;
@@ -30,6 +32,7 @@
import org.apache.flink.table.test.program.SinkTestStep;
import org.apache.flink.table.test.program.SourceTestStep;
import org.apache.flink.table.test.program.TableTestProgram;
+import org.apache.flink.types.ColumnList;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
@@ -39,6 +42,7 @@
import java.time.LocalTime;
import java.time.ZoneId;
import java.util.Collections;
+import java.util.Map;
import static org.apache.flink.table.api.Expressions.$;
import static org.apache.flink.table.api.Expressions.UNBOUNDED_ROW;
@@ -49,6 +53,10 @@
import static org.apache.flink.table.api.Expressions.lit;
import static org.apache.flink.table.api.Expressions.nullOf;
import static org.apache.flink.table.api.Expressions.row;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_MODEL;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SIMPLE_FEATURES_SOURCE;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SIMPLE_SINK;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SYNC_MODEL;
import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.BASE_SINK_SCHEMA;
import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.BASIC_VALUES;
import static org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.KEYED_TIMED_BASE_SINK_SCHEMA;
@@ -1117,6 +1125,97 @@ private static Instant dayOfSeconds(int second) {
"sink")
.build();
+ public static final TableTestProgram ML_PREDICT_MODEL_API =
+ TableTestProgram.of("ml-predict-model-api", "ml-predict using model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(SYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .runTableApi(
+ env ->
+ env.fromModel("chatgpt")
+ .predict(
+ env.from("features"), ColumnList.of("feature")),
+ "sink")
+ .build();
+
+ public static final TableTestProgram ASYNC_ML_PREDICT_MODEL_API =
+ TableTestProgram.of("async-ml-predict-model-api", "async ml-predict using model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(ASYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .setupConfig(
+ ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
+ ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
+ .runTableApi(
+ env ->
+ env.fromModel("chatgpt")
+ .predict(
+ env.from("features"),
+ ColumnList.of("feature"),
+ Map.of(
+ "async",
+ "true",
+ "max-concurrent-operations",
+ "10")),
+ "sink")
+ .build();
+
+ public static final TableTestProgram ML_PREDICT_ANON_MODEL_API =
+ TableTestProgram.of(
+ "ml-predict-anonymous-model-api",
+ "ml-predict using anonymous model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .runFailingTableApi(
+ env ->
+ env.from(
+ ModelDescriptor.forProvider("values")
+ .inputSchema(
+ Schema.newBuilder()
+ .column(
+ "feature",
+ "STRING")
+ .build())
+ .outputSchema(
+ Schema.newBuilder()
+ .column(
+ "category",
+ "STRING")
+ .build())
+ .option(
+ "data-id",
+ TestValuesModelFactory
+ .registerData(
+ SYNC_MODEL
+ .data))
+ .build())
+ .predict(
+ env.from("features"), ColumnList.of("feature")),
+ "sink",
+ ValidationException.class,
+ "Anonymous models cannot be serialized.")
+ .build();
+
+ public static final TableTestProgram ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG =
+ TableTestProgram.of(
+ "async-ml-predict-table-api-map-expression-config",
+ "ml-predict in async mode using Table API and map expression.")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(ASYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .setupConfig(
+ ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
+ ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
+ .runTableApi(
+ env ->
+ env.fromCall(
+ "ML_PREDICT",
+ env.from("features").asArgument("INPUT"),
+ env.fromModel("chatgpt").asArgument("MODEL"),
+ descriptor("feature").asArgument("ARGS"),
+ Expressions.map("async", "true").asArgument("CONFIG")),
+ "sink")
+ .build();
+
/**
* A function that will be used as an inline function in {@link #INLINE_FUNCTION_SERIALIZATION}.
*/
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictSemanticTests.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictSemanticTests.java
new file mode 100644
index 0000000000000..cb15ff4533d8e
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictSemanticTests.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.nodes.exec.stream;
+
+import org.apache.flink.table.planner.plan.nodes.exec.testutils.SemanticTestBase;
+import org.apache.flink.table.test.program.TableTestProgram;
+
+import java.util.List;
+
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_ML_PREDICT_MODEL_API;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_ML_PREDICT_TABLE_API;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ML_PREDICT_ANON_MODEL_API;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ML_PREDICT_MODEL_API;
+import static org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SYNC_ML_PREDICT_TABLE_API;
+
+/** Semantic tests for {@link StreamExecMLPredictTableFunction} using Table API. */
+public class MLPredictSemanticTests extends SemanticTestBase {
+
+ @Override
+    public List<TableTestProgram> programs() {
+ return List.of(
+ SYNC_ML_PREDICT_TABLE_API,
+ ASYNC_ML_PREDICT_TABLE_API,
+ ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG,
+ ML_PREDICT_MODEL_API,
+ ASYNC_ML_PREDICT_MODEL_API,
+ ML_PREDICT_ANON_MODEL_API);
+ }
+}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java
index 7d430e143b21e..26d4903a3e41d 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java
@@ -18,11 +18,17 @@
package org.apache.flink.table.planner.plan.nodes.exec.stream;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.ModelDescriptor;
+import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.planner.factories.TestValuesModelFactory;
import org.apache.flink.table.test.program.ModelTestStep;
import org.apache.flink.table.test.program.SinkTestStep;
import org.apache.flink.table.test.program.SourceTestStep;
import org.apache.flink.table.test.program.TableTestProgram;
+import org.apache.flink.types.ColumnList;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
@@ -30,6 +36,8 @@
import java.util.List;
import java.util.Map;
+import static org.apache.flink.table.api.Expressions.descriptor;
+
/** Programs for verifying {@link StreamExecMLPredictTableFunction}. */
public class MLPredictTestPrograms {
@@ -48,7 +56,13 @@ public class MLPredictTestPrograms {
Row.ofKind(RowKind.INSERT, 4, "Mysql"), Row.ofKind(RowKind.INSERT, 5, "Postgres")
};
- static final SourceTestStep FEATURES_TABLE =
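+    // Single-phase source for the Table API programs (no before/after-restore split).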
+ public static final SourceTestStep SIMPLE_FEATURES_SOURCE =
+ SourceTestStep.newBuilder("features")
+ .addSchema(FEATURES_SCHEMA)
+ .producedValues(FEATURES_BEFORE_DATA)
+ .build();
+
+ static final SourceTestStep RESTORE_FEATURES_TABLE =
SourceTestStep.newBuilder("features")
.addSchema(FEATURES_SCHEMA)
.producedBeforeRestore(FEATURES_BEFORE_DATA)
@@ -61,7 +75,7 @@ public class MLPredictTestPrograms {
static final String[] MODEL_OUTPUT_SCHEMA = new String[] {"category STRING"};
    static final Map<Row, List<Row>> MODEL_DATA =
-            new HashMap<Row, List<Row>>() {
+ new HashMap<>() {
{
put(
Row.ofKind(RowKind.INSERT, "Flink"),
@@ -82,14 +96,14 @@ public class MLPredictTestPrograms {
}
};
- static final ModelTestStep SYNC_MODEL =
+ public static final ModelTestStep SYNC_MODEL =
ModelTestStep.newBuilder("chatgpt")
.addInputSchema(MODEL_INPUT_SCHEMA)
.addOutputSchema(MODEL_OUTPUT_SCHEMA)
.data(MODEL_DATA)
.build();
- static final ModelTestStep ASYNC_MODEL =
+ public static final ModelTestStep ASYNC_MODEL =
ModelTestStep.newBuilder("chatgpt")
.addInputSchema(MODEL_INPUT_SCHEMA)
.addOutputSchema(MODEL_OUTPUT_SCHEMA)
@@ -102,7 +116,7 @@ public class MLPredictTestPrograms {
static final String[] SINK_SCHEMA =
new String[] {"id INT PRIMARY KEY NOT ENFORCED", "feature STRING", "category STRING"};
- static final SinkTestStep SINK_TABLE =
+ static final SinkTestStep RESTORE_SINK_TABLE =
SinkTestStep.newBuilder("sink_t")
.addSchema(SINK_SCHEMA)
.consumedBeforeRestore(
@@ -112,22 +126,31 @@ public class MLPredictTestPrograms {
.consumedAfterRestore("+I[4, Mysql, Database]", "+I[5, Postgres, Database]")
.build();
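+    // Single-phase sink counterpart to RESTORE_SINK_TABLE for the Table API programs.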
+ public static final SinkTestStep SIMPLE_SINK =
+ SinkTestStep.newBuilder("sink")
+ .addSchema(SINK_SCHEMA)
+ .consumedValues(
+ "+I[1, Flink, Big Data]",
+ "+I[2, Spark, Big Data]",
+ "+I[3, Hive, Big Data]")
+ .build();
+
// -------------------------------------------------------------------------------------------
public static final TableTestProgram SYNC_ML_PREDICT =
TableTestProgram.of("sync-ml-predict", "ml-predict in sync mode.")
- .setupTableSource(FEATURES_TABLE)
+ .setupTableSource(RESTORE_FEATURES_TABLE)
.setupModel(SYNC_MODEL)
- .setupTableSink(SINK_TABLE)
+ .setupTableSink(RESTORE_SINK_TABLE)
.runSql(
"INSERT INTO sink_t SELECT * FROM ML_PREDICT(TABLE features, MODEL chatgpt, DESCRIPTOR(feature))")
.build();
public static final TableTestProgram ASYNC_UNORDERED_ML_PREDICT =
TableTestProgram.of("async-unordered-ml-predict", "ml-predict in async unordered mode.")
- .setupTableSource(FEATURES_TABLE)
+ .setupTableSource(RESTORE_FEATURES_TABLE)
.setupModel(ASYNC_MODEL)
- .setupTableSink(SINK_TABLE)
+ .setupTableSink(RESTORE_SINK_TABLE)
.setupConfig(
ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
@@ -139,11 +162,148 @@ public class MLPredictTestPrograms {
TableTestProgram.of(
"sync-ml-predict-with-runtime-options",
"ml-predict in sync mode with runtime config.")
- .setupTableSource(FEATURES_TABLE)
+ .setupTableSource(RESTORE_FEATURES_TABLE)
.setupModel(ASYNC_MODEL)
- .setupTableSink(SINK_TABLE)
+ .setupTableSink(RESTORE_SINK_TABLE)
.runSql(
"INSERT INTO sink_t SELECT * FROM ML_PREDICT(TABLE features, MODEL chatgpt, DESCRIPTOR(feature), MAP['async', 'false'])")
.build();
- ;
+
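+    // ML_PREDICT invoked through TableEnvironment#fromCall with named arguments.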
+ public static final TableTestProgram SYNC_ML_PREDICT_TABLE_API =
+ TableTestProgram.of(
+ "sync-ml-predict-table-api", "ml-predict in sync mode using Table API.")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(SYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .runTableApi(
+ env ->
+ env.fromCall(
+ "ML_PREDICT",
+ env.from("features").asArgument("INPUT"),
+ env.fromModel("chatgpt").asArgument("MODEL"),
+ descriptor("feature").asArgument("ARGS")),
+ "sink")
+ .build();
+
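+    // Async variant where the CONFIG argument is a typed MAP literal.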
+ public static final TableTestProgram ASYNC_ML_PREDICT_TABLE_API =
+ TableTestProgram.of(
+ "async-ml-predict-table-api",
+ "ml-predict in async mode using Table API.")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(ASYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .setupConfig(
+ ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
+ ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
+ .runTableApi(
+ env ->
+ env.fromCall(
+ "ML_PREDICT",
+ env.from("features").asArgument("INPUT"),
+ env.fromModel("chatgpt").asArgument("MODEL"),
+ descriptor("feature").asArgument("ARGS"),
+ Expressions.lit(
+ Map.of("async", "true"),
+ DataTypes.MAP(
+ DataTypes.STRING(),
+ DataTypes.STRING())
+ .notNull())
+ .asArgument("CONFIG")),
+ "sink")
+ .build();
+
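+    // Async variant where the CONFIG argument is built with Expressions.map.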
+ public static final TableTestProgram ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG =
+ TableTestProgram.of(
+ "async-ml-predict-table-api-map-expression-config",
+ "ml-predict in async mode using Table API and map expression.")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(ASYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .setupConfig(
+ ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
+ ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
+ .runTableApi(
+ env ->
+ env.fromCall(
+ "ML_PREDICT",
+ env.from("features").asArgument("INPUT"),
+ env.fromModel("chatgpt").asArgument("MODEL"),
+ descriptor("feature").asArgument("ARGS"),
+ Expressions.map(
+ "async",
+ "true",
+ "max-concurrent-operations",
+ "10")
+ .asArgument("CONFIG")),
+ "sink")
+ .build();
+
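+    // Same pipeline expressed through the fluent Model API instead of fromCall("ML_PREDICT", ...).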
+ public static final TableTestProgram ML_PREDICT_MODEL_API =
+ TableTestProgram.of("ml-predict-model-api", "ml-predict using model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(SYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .runTableApi(
+ env ->
+ env.fromModel("chatgpt")
+ .predict(
+ env.from("features"), ColumnList.of("feature")),
+ "sink")
+ .build();
+
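+    // Model API variant with async runtime options passed as a plain Map.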
+ public static final TableTestProgram ASYNC_ML_PREDICT_MODEL_API =
+ TableTestProgram.of("async-ml-predict-model-api", "async ml-predict using model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(ASYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .setupConfig(
+ ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
+ ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
+ .runTableApi(
+ env ->
+ env.fromModel("chatgpt")
+ .predict(
+ env.from("features"),
+ ColumnList.of("feature"),
+ Map.of(
+ "async",
+ "true",
+ "max-concurrent-operations",
+ "10")),
+ "sink")
+ .build();
+
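+    // Uses an inline ModelDescriptor; expected to succeed here since semantic tests execute the
+    // program directly, without plan serialization.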
+ public static final TableTestProgram ML_PREDICT_ANON_MODEL_API =
+ TableTestProgram.of(
+ "ml-predict-anonymous-model-api",
+ "ml-predict using anonymous model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupTableSink(SIMPLE_SINK)
+ .runTableApi(
+ env ->
+ env.from(
+ ModelDescriptor.forProvider("values")
+ .inputSchema(
+ Schema.newBuilder()
+ .column(
+ "feature",
+ "STRING")
+ .build())
+ .outputSchema(
+ Schema.newBuilder()
+ .column(
+ "category",
+ "STRING")
+ .build())
+ .option(
+ "data-id",
+ TestValuesModelFactory
+ .registerData(
+ SYNC_MODEL
+ .data))
+ .build())
+ .predict(
+ env.from("features"), ColumnList.of("feature")),
+ "sink")
+ .build();
}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java
index 9c65cec88c8d4..4ba4c3cc8651e 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java
@@ -22,10 +22,13 @@
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.planner.factories.TestValuesModelFactory;
import org.apache.flink.table.planner.factories.TestValuesTableFactory;
import org.apache.flink.table.test.program.ConfigOptionTestStep;
import org.apache.flink.table.test.program.FailingSqlTestStep;
+import org.apache.flink.table.test.program.FailingTableApiTestStep;
import org.apache.flink.table.test.program.FunctionTestStep;
+import org.apache.flink.table.test.program.ModelTestStep;
import org.apache.flink.table.test.program.SinkTestStep;
import org.apache.flink.table.test.program.SourceTestStep;
import org.apache.flink.table.test.program.SqlTestStep;
@@ -62,6 +65,7 @@ public abstract class SemanticTestBase implements TableTestProgramRunner {
public EnumSet<TestKind> supportedSetupSteps() {
return EnumSet.of(
TestKind.CONFIG,
+ TestKind.MODEL,
TestKind.SOURCE_WITH_DATA,
TestKind.SINK_WITH_DATA,
TestKind.FUNCTION,
@@ -70,7 +74,8 @@ public EnumSet supportedSetupSteps() {
@Override
public EnumSet supportedRunSteps() {
- return EnumSet.of(TestKind.SQL, TestKind.FAILING_SQL, TestKind.TABLE_API);
+ return EnumSet.of(
+ TestKind.SQL, TestKind.FAILING_SQL, TestKind.TABLE_API, TestKind.FAILING_TABLE_API);
}
@AfterEach
@@ -145,6 +150,22 @@ protected void runStep(TestStep testStep, TableEnvironment env) throws Exception
sqlTestStep.apply(env);
}
break;
+ case FAILING_TABLE_API:
+ {
+ final FailingTableApiTestStep tableApiTestStep =
+ (FailingTableApiTestStep) testStep;
+ tableApiTestStep.apply(env);
+ }
+ break;
+ case MODEL:
+ {
+ final ModelTestStep modelTestStep = (ModelTestStep) testStep;
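+                    // Back the model with the values connector by registering its data with the test factory.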
+                    final Map<String, String> options = new HashMap<>();
+ options.put("provider", "values");
+ options.put("data-id", TestValuesModelFactory.registerData(modelTestStep.data));
+ modelTestStep.apply(env, options);
+ }
+ break;
case TABLE_API:
{
final TableApiTestStep apiTestStep = (TableApiTestStep) testStep;
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala
index 6e9b9285bec6b..8bd7ca7869cab 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala
@@ -24,7 +24,7 @@ import org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches
import org.apache.flink.sql.parser.error.SqlValidateException
import org.apache.flink.streaming.api.environment.{LocalStreamEnvironment, StreamExecutionEnvironment}
import org.apache.flink.table.api.bridge.scala._
-import org.apache.flink.table.api.internal.TableEnvironmentInternal
+import org.apache.flink.table.api.internal.{ModelImpl, TableEnvironmentInternal}
import org.apache.flink.table.catalog._
import org.apache.flink.table.factories.{TableFactoryUtil, TableSourceFactoryContextImpl}
import org.apache.flink.table.functions.TestGenericUDF
@@ -3250,6 +3250,28 @@ class TableEnvironmentTest {
checkData(util.Arrays.asList(Row.of("your_model")).iterator(), tableResult3.collect())
}
+ @Test
+ def testGetNonExistModel(): Unit = {
+ assertThatThrownBy(() => tableEnv.fromModelPath("MyModel"))
+ .hasMessageContaining("Model `MyModel` was not found")
+ .isInstanceOf[ValidationException]
+ }
+
+ @Test
+ def testGetModel(): Unit = {
+ val inputSchema = Schema.newBuilder().column("feature", DataTypes.STRING()).build()
+
+ val outputSchema = Schema.newBuilder().column("response", DataTypes.DOUBLE()).build()
+ tableEnv.createModel(
+ "MyModel",
+ ModelDescriptor
+ .forProvider("openai")
+ .inputSchema(inputSchema)
+ .outputSchema(outputSchema)
+ .build())
+ assertThat(tableEnv.fromModelPath("MyModel")).isInstanceOf(classOf[ModelImpl])
+ }
+
@Test
def testTemporaryOperationListener(): Unit = {
val listener = new ListenerCatalog("listener_cat")
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
index f3d7b1023ed41..e4fca40e7f54e 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
@@ -37,7 +37,7 @@ import org.apache.flink.table.api.bridge.java.{StreamTableEnvironment => JavaStr
import org.apache.flink.table.api.bridge.scala.{StreamTableEnvironment => ScalaStreamTableEnv}
import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions}
import org.apache.flink.table.api.config.OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE
-import org.apache.flink.table.api.internal.{StatementSetImpl, TableEnvironmentImpl, TableEnvironmentInternal, TableImpl}
+import org.apache.flink.table.api.internal._
import org.apache.flink.table.api.typeutils.CaseClassTypeInfo
import org.apache.flink.table.catalog._
import org.apache.flink.table.connector.ChangelogMode