Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
# limitations under the License.
################################################################################

from pyflink.testing.test_case_utils import PythonAPICompletenessTestCase, PyFlinkTestCase
from pyflink.table import TableEnvironment
from pyflink.testing.test_case_utils import PythonAPICompletenessTestCase, \
PyFlinkTestCase


class EnvironmentAPICompletenessTests(PythonAPICompletenessTestCase, PyFlinkTestCase):
Expand All @@ -40,6 +41,7 @@ def excluded_methods(cls):
'getCompletionHints',
'fromValues',
'fromCall',
'fromModelPath',
# See FLINK-25986
'loadPlan',
'compilePlanSql',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def excluded_methods(cls):
"from",
"registerFunction",
"fromCall",
"fromModelPath",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Create a tickect about this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@ public static ApiExpression descriptor(String... columnNames) {
return new ApiExpression(valueLiteral(ColumnList.of(Arrays.asList(columnNames))));
}

/**
* Creates a literal describing an arbitrary, unvalidated list of column names.
*
* <p>Passing a column list can be useful for parameterizing a function. In particular, it
* enables declaring the {@code on_time} argument for {@link ProcessTableFunction} or the {@code
* inputColumns} for {@link Model#predict}.
*
* <p>The data type will be {@link DataTypes#DESCRIPTOR()}.
*/
public static ApiExpression descriptor(ColumnList columnList) {
return new ApiExpression(valueLiteral(columnList));
}

/**
* Indicates a range from 'start' to 'end', which can be used in columns selection.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.api;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.types.ColumnList;

import java.util.Map;

/**
* The {@link Model} object is the core abstraction for ML model resources in the Table API.
*
* <p>A {@link Model} object describes a machine learning model resource that can be used for
* inference operations. It provides methods to perform prediction on data tables.
*
* <p>The {@link Model} interface offers main operations:
*
* <ul>
* <li>{@link #predict(Table, ColumnList)} - Applies the model to make predictions on input data
* </ul>
*
* <p>{@code ml_predict} operation supports runtime options for configuring execution parameters
* such as asynchronous execution mode.
*
* <p>Every {@link Model} object has input and output schemas that describe the expected data
* structure for model operations, available through {@link #getResolvedInputSchema()} and {@link
* #getResolvedOutputSchema()}.
*
* <p>Example usage:
*
* <pre>{@code
* Model model = tableEnv.fromModelPath("my_model");
*
* // Simple prediction
* Table predictions = model.predict(inputTable, ColumnList.of("feature1", "feature2"));
*
* // Prediction with options
* Map<String, String> options = Map.of("max-concurrent-operations", "100", "timeout", "30s", "async", "true");
* Table predictions = model.predict(inputTable, ColumnList.of("feature1", "feature2"), options);
* }</pre>
*/
@PublicEvolving
public interface Model {

/**
* Returns the resolved input schema of this model.
*
* <p>The input schema describes the structure and data types of the input columns that the
* model expects for inference operations.
*
* @return the resolved input schema.
*/
ResolvedSchema getResolvedInputSchema();

/**
* Returns the resolved output schema of this model.
*
* <p>The output schema describes the structure and data types of the output columns that the
* model produces during inference operations.
*
* @return the resolved output schema.
*/
ResolvedSchema getResolvedOutputSchema();

/**
* Performs prediction on the given table using specified input columns.
*
* <p>This method applies the model to the input data to generate predictions. The input columns
* must match the model's expected input schema.
*
* <p>Example:
*
* <pre>{@code
* Table predictions = model.predict(inputTable, ColumnList.of("feature1", "feature2"));
* }</pre>
*
* @param table the input table containing data for prediction
* @param inputColumns the columns from the input table to use as model input
* @return a table containing the input data along with prediction results
*/
Table predict(Table table, ColumnList inputColumns);

/**
* Performs prediction on the given table using specified input columns with runtime options.
*
* <p>This method applies the model to the input data to generate predictions with additional
* runtime configuration options such as max-concurrent-operations, timeout, and execution mode
* settings.
*
* <p>For Common runtime options, see {@link MLPredictRuntimeConfigOptions}.
*
* <p>Example:
*
* <pre>{@code
* Map<String, String> options = Map.of("max-concurrent-operations", "100", "timeout", "30s", "async", "true");
* Table predictions = model.predict(inputTable,
* ColumnList.of("feature1", "feature2"), options);
* }</pre>
*
* @param table the input table containing data for prediction
* @param inputColumns the columns from the input table to use as model input
* @param options runtime options for configuring the prediction operation
* @return a table containing the input data along with prediction results
*/
Table predict(Table table, ColumnList inputColumns, Map<String, String> options);

/**
* Converts this model object into a named argument.
*
* <p>This method is intended for use in function calls that accept model arguments,
* particularly in process table functions (PTFs) or other operations that work with models.
*
* <p>Example:
*
* <pre>{@code
* env.fromCall(
* "ML_PREDICT",
* inputTable.asArgument("INPUT"),
* model.asArgument("MODEL"),
* Expressions.descriptor(ColumnList.of("feature1", "feature2")).asArgument("ARGS")
* )
* }</pre>
*
* @param name the name to assign to this model argument
* @return an expression that can be passed to functions expecting model arguments
*/
ApiExpression asArgument(String name);
}
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,49 @@ void createTemporarySystemFunction(
*/
Table fromCall(Class<? extends UserDefinedFunction> function, Object... arguments);

/**
* Returns a {@link Model} object that is backed by the specified model path.
*
* <p>This method creates a {@link Model} object from a given model path in the catalog. The
* model path can be fully or partially qualified (e.g., "catalog.db.model" or just "model"),
* depending on the current catalog and database context.
*
* <p>The returned {@link Model} object can be used for further transformations or as input to
* other operations in the Table API.
*
* <p>Example:
*
* <pre>{@code
* Model model = tableEnv.fromModelPath("my_model");
* }</pre>
*
* @param modelPath The path of the model in the catalog.
* @return The {@link Model} object describing the model resource.
*/
Model fromModelPath(String modelPath);

/**
* Returns a {@link Model} object that is backed by the specified {@link ModelDescriptor}.
*
* <p>This method creates a {@link Model} object using the provided {@link ModelDescriptor},
* which contains the necessary information to identify and configure the model resource in the
* catalog.
*
* <p>The returned {@link Model} object can be used for further transformations or as input to
* other operations in the Table API.
*
* <p>Example:
*
* <pre>{@code
* ModelDescriptor descriptor = ...;
* Model model = tableEnv.from(descriptor);
* }</pre>
*
* @param descriptor The {@link ModelDescriptor} describing the model resource.
* @return The {@link Model} object representing the model resource.
*/
Model from(ModelDescriptor descriptor);

/**
* Gets the names of all catalogs registered in this environment.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.api.internal;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.ApiExpression;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Model;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.catalog.ContextResolvedModel;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.table.expressions.ApiExpressionUtils;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.types.ColumnList;

import java.util.ArrayList;
import java.util.Map;

import static org.apache.flink.table.api.Expressions.lit;

/** Implementation of {@link Model} that works with the Table API. */
@Internal
public class ModelImpl implements Model {

private final TableEnvironmentInternal tableEnvironment;
private final ContextResolvedModel model;

private ModelImpl(TableEnvironmentInternal tableEnvironment, ContextResolvedModel model) {
this.tableEnvironment = tableEnvironment;
this.model = model;
}

public static ModelImpl createModel(
TableEnvironmentInternal tableEnvironment, ContextResolvedModel model) {
return new ModelImpl(tableEnvironment, model);
}

public ContextResolvedModel getModel() {
return model;
}

@Override
public ResolvedSchema getResolvedInputSchema() {
return model.getResolvedModel().getResolvedInputSchema();
}

@Override
public ResolvedSchema getResolvedOutputSchema() {
return model.getResolvedModel().getResolvedOutputSchema();
}

public TableEnvironment getTableEnv() {
return tableEnvironment;
}

@Override
public Table predict(Table table, ColumnList inputColumns) {
return predict(table, inputColumns, Map.of());
}

@Override
public Table predict(Table table, ColumnList inputColumns, Map<String, String> options) {
// Use Expressions.map() instead of Expressions.lit() to create a MAP literal since
// lit() is not serializable to sql.
if (options.isEmpty()) {
return tableEnvironment.fromCall(
BuiltInFunctionDefinitions.ML_PREDICT.getName(),
table.asArgument("INPUT"),
this.asArgument("MODEL"),
Expressions.descriptor(inputColumns).asArgument("ARGS"));
}
ArrayList<String> configKVs = new ArrayList<>();
options.forEach(
(k, v) -> {
configKVs.add(k);
configKVs.add(v);
});
return tableEnvironment.fromCall(
BuiltInFunctionDefinitions.ML_PREDICT.getName(),
table.asArgument("INPUT"),
this.asArgument("MODEL"),
Expressions.descriptor(inputColumns).asArgument("ARGS"),
Expressions.map(
configKVs.get(0),
configKVs.get(1),
configKVs.subList(2, configKVs.size()).toArray())
.asArgument("CONFIG"));
}

@Override
public ApiExpression asArgument(String name) {
return new ApiExpression(
ApiExpressionUtils.unresolvedCall(
BuiltInFunctionDefinitions.ASSIGNMENT,
lit(name),
ApiExpressionUtils.modelRef(name, this)));
}

public TableEnvironment getTableEnvironment() {
return tableEnvironment;
}
}
Loading