Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
return switch (connectorBehavior) {
case SUPPORTS_UPDATE -> true;
case SUPPORTS_ADD_COLUMN_WITH_POSITION,
SUPPORTS_CREATE_MATERIALIZED_VIEW,
SUPPORTS_CREATE_VIEW,
SUPPORTS_DEFAULT_COLUMN_VALUE,
SUPPORTS_MERGE,
SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN,
SUPPORTS_ROW_LEVEL_UPDATE -> false;
SUPPORTS_CREATE_MATERIALIZED_VIEW,
SUPPORTS_CREATE_VIEW,
SUPPORTS_DEFAULT_COLUMN_VALUE,
SUPPORTS_MERGE,
SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN,
SUPPORTS_ROW_LEVEL_UPDATE -> false;
// Dynamic filters can be pushed down only if predicate push down is supported.
// It is possible for a connector to have predicate push down support but not push down dynamic filters.
// TODO default SUPPORTS_DYNAMIC_FILTER_PUSHDOWN to SUPPORTS_PREDICATE_PUSHDOWN
Expand Down Expand Up @@ -621,15 +621,15 @@ public void testNumericAggregationPushdown()
assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
assertNumericAveragePushdown(emptyTable);
}

try (TestTable testTable = createAggregationTestTable(schemaName + ".test_num_agg_pd",
ImmutableList.of("100.000, 100000000.000000000, 100.000, 100000000", "123.321, 123456789.987654321, 123.321, 123456789"))) {
assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + testTable.getName())).isFullyPushedDown();
assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + testTable.getName())).isFullyPushedDown();
assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + testTable.getName())).isFullyPushedDown();
assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTable.getName())).isFullyPushedDown();
assertNumericAveragePushdown(testTable);

// smoke testing of more complex cases
// WHERE on aggregation column
Expand All @@ -647,6 +647,11 @@ public void testNumericAggregationPushdown()
}
}

protected void assertNumericAveragePushdown(TestTable testTable)
{
assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTable.getName())).isFullyPushedDown();
}

@Test
public void testCountDistinctWithStringTypes()
{
Expand Down Expand Up @@ -1153,12 +1158,12 @@ public void testArithmeticPredicatePushdown()

assertThat(query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % nationkey = 2"))
.isFullyPushedDown()
.matches("VALUES (BIGINT '3', CAST('CANADA' AS varchar(25)), BIGINT '1')");
.matches(getArithmeticPredicatePushdownExpectedValues());

// some databases calculate remainder instead of modulus when one of the values is negative
assertThat(query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % -nationkey = 2"))
.isFullyPushedDown()
.matches("VALUES (BIGINT '3', CAST('CANADA' AS varchar(25)), BIGINT '1')");
.matches(getArithmeticPredicatePushdownExpectedValues());

assertThat(query("SELECT nationkey, name, regionkey FROM nation WHERE nationkey > 0 AND (nationkey - regionkey) % 0 = 2"))
.failure().hasMessageContaining("by zero");
Expand All @@ -1170,6 +1175,11 @@ public void testArithmeticPredicatePushdown()
// TODO add coverage for other arithmetic pushdowns https://github.com/trinodb/trino/issues/14808
}

protected String getArithmeticPredicatePushdownExpectedValues()
{
return "VALUES (BIGINT '3', CAST('CANADA' AS varchar(25)), BIGINT '1')";
}

@Test
public void testCaseSensitiveTopNPushdown()
{
Expand Down Expand Up @@ -1307,7 +1317,8 @@ public void testJoinPushdown()
assertThat(query(session, format("SELECT n.name FROM nation n %s orders o ON DATE '2025-03-19' = o.orderdate", joinOperator))).joinIsNotFullyPushedDown();

// no projection on the probe side, only filter
assertJoinConditionallyPushedDown(session, format("SELECT n.name FROM nation n %s orders o ON n.regionkey = 1", joinOperator),
// reduced the size of the join table to make the test faster: instead of joining on the large orders table, it is joined on only one record
assertJoinConditionallyPushedDown(session, format("SELECT n.name FROM nation n %s (SELECT * FROM orders WHERE orderkey = 1) o ON n.regionkey = 1", joinOperator),
expectJoinPushdownOnEmptyProjection(joinOperator));

// pushdown when using USING
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import com.google.common.collect.ImmutableSet;
import com.google.inject.Inject;
import io.airlift.slice.Slices;
import io.trino.plugin.base.aggregation.AggregateFunctionRewriter;
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.plugin.base.mapping.IdentifierMapping;
import io.trino.plugin.jdbc.BaseJdbcClient;
import io.trino.plugin.jdbc.BaseJdbcConfig;
Expand All @@ -35,6 +38,25 @@
import io.trino.plugin.jdbc.SliceWriteFunction;
import io.trino.plugin.jdbc.WriteFunction;
import io.trino.plugin.jdbc.WriteMapping;
import io.trino.plugin.jdbc.aggregation.ImplementAvgDecimal;
import io.trino.plugin.jdbc.aggregation.ImplementAvgFloatingPoint;
import io.trino.plugin.jdbc.aggregation.ImplementCorr;
import io.trino.plugin.jdbc.aggregation.ImplementCount;
import io.trino.plugin.jdbc.aggregation.ImplementCountAll;
import io.trino.plugin.jdbc.aggregation.ImplementCountDistinct;
import io.trino.plugin.jdbc.aggregation.ImplementCovariancePop;
import io.trino.plugin.jdbc.aggregation.ImplementCovarianceSamp;
import io.trino.plugin.jdbc.aggregation.ImplementMinMax;
import io.trino.plugin.jdbc.aggregation.ImplementRegrIntercept;
import io.trino.plugin.jdbc.aggregation.ImplementRegrSlope;
import io.trino.plugin.jdbc.aggregation.ImplementStddevPop;
import io.trino.plugin.jdbc.aggregation.ImplementStddevSamp;
import io.trino.plugin.jdbc.aggregation.ImplementSum;
import io.trino.plugin.jdbc.aggregation.ImplementVariancePop;
import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp;
import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.plugin.jdbc.expression.RewriteIn;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.AggregateFunction;
Expand All @@ -43,7 +65,10 @@
import io.trino.spi.connector.ColumnPosition;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;

import java.sql.Connection;
import java.sql.Date;
Expand All @@ -64,7 +89,10 @@
import static io.trino.plugin.jdbc.StandardColumnMappings.defaultVarcharColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.doubleColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.integerColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.shortDecimalWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.smallintColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.varcharWriteFunction;
import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.getUnsupportedTypeHandling;
import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
Expand All @@ -83,6 +111,9 @@ public class ExasolClient
.add("EXA_STATISTICS")
.add("SYS")
.build();
public static final int MAX_EXASOL_DECIMAL_PRECISION = 36;
private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter;
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;

@Inject
public ExasolClient(
Expand All @@ -93,6 +124,57 @@ public ExasolClient(
RemoteQueryModifier queryModifier)
{
super("\"", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, false);
// Basic implementation required to enable JOIN and AGGREGATION pushdown support
// It is covered by "testJoinPushdown" and "testAggregationPushdown" integration tests.
// More detailed test case scenarios are covered by Unit tests in "TestExasolClient"
this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
.addStandardRules(this::quoted)
.add(new RewriteIn())
.withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double"))
.map("$equal(left, right)").to("left = right")
.map("$not_equal(left, right)").to("left <> right")
// Exasol doesn't support "IS NOT DISTINCT FROM" expression,
// so "$identical(left, right)" is rewritten with equivalent "(left = right OR (left IS NULL AND right IS NULL))" expression
.map("$identical(left, right)").to("(left = right OR (left IS NULL AND right IS NULL))")
.map("$less_than(left, right)").to("left < right")
.map("$less_than_or_equal(left, right)").to("left <= right")
.map("$greater_than(left, right)").to("left > right")
.map("$greater_than_or_equal(left, right)").to("left >= right")
.map("$not($is_null(value))").to("value IS NOT NULL")
.map("$not(value: boolean)").to("NOT value")
.map("$is_null(value)").to("value IS NULL")
.map("$add(left: numeric_type, right: numeric_type)").to("left + right")
.map("$subtract(left: numeric_type, right: numeric_type)").to("left - right")
.map("$multiply(left: numeric_type, right: numeric_type)").to("left * right")
.map("$divide(left: numeric_type, right: numeric_type)").to("left / right")
.map("$modulus(left: numeric_type, right: numeric_type)").to("mod(left, right)")
.map("$negate(value: numeric_type)").to("-value")
.map("$like(value: varchar, pattern: varchar): boolean").to("value LIKE pattern")
.map("$like(value: varchar, pattern: varchar, escape: varchar(1)): boolean").to("value LIKE pattern ESCAPE escape")
.map("$nullif(first, second)").to("NULLIF(first, second)")
.build();
JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
this.connectorExpressionRewriter,
ImmutableSet.<AggregateFunctionRule<JdbcExpression, ParameterizedExpression>>builder()
.add(new ImplementCountAll(bigintTypeHandle))
.add(new ImplementMinMax(true))
.add(new ImplementCount(bigintTypeHandle))
.add(new ImplementCountDistinct(bigintTypeHandle, true))
.add(new ImplementSum(ExasolClient::toSumTypeHandle))
.add(new ImplementAvgFloatingPoint())
.add(new ImplementAvgDecimal())
.add(new ImplementExasolAvgBigInt())
.add(new ImplementStddevSamp())
.add(new ImplementStddevPop())
.add(new ImplementVarianceSamp())
.add(new ImplementVariancePop())
.add(new ImplementCovarianceSamp())
.add(new ImplementCovariancePop())
.add(new ImplementCorr())
.add(new ImplementRegrIntercept())
.add(new ImplementRegrSlope())
.build());
}

@Override
Expand Down Expand Up @@ -194,18 +276,34 @@ protected void renameTable(ConnectorSession session, Connection connection, Stri
throw new TrinoException(NOT_SUPPORTED, "This connector does not support renaming tables");
}

@Override
public Optional<ParameterizedExpression> convertPredicate(ConnectorSession session, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
{
return connectorExpressionRewriter.rewrite(session, expression, assignments);
}

@Override
protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition)
{
// Deactivated because test 'testJoinPushdown()' requires write access which is not implemented for Exasol
return false;
return true;
}

@Override
public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder)
{
return true;
}

@Override
public Optional<JdbcExpression> implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map<String, ColumnHandle> assignments)
{
// Deactivated because test 'testCaseSensitiveAggregationPushdown()' requires write access which is not implemented for Exasol
return Optional.empty();
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

@Override
public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets)
{
return true;
}

@Override
Expand All @@ -230,9 +328,9 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
case Types.DOUBLE:
return Optional.of(doubleColumnMapping());
case Types.DECIMAL:
int decimalDigits = typeHandle.requiredDecimalDigits();
int columnSize = typeHandle.requiredColumnSize();
return Optional.of(decimalColumnMapping(createDecimalType(columnSize, decimalDigits)));
int precision = typeHandle.requiredColumnSize();
int scale = typeHandle.requiredDecimalDigits();
return Optional.of(decimalColumnMapping(createDecimalType(precision, scale)));
case Types.CHAR:
return Optional.of(defaultCharColumnMapping(typeHandle.requiredColumnSize(), true));
case Types.VARCHAR:
Expand All @@ -256,6 +354,12 @@ private boolean isHashType(JdbcTypeHandle typeHandle)
&& typeHandle.jdbcTypeName().get().equalsIgnoreCase("HASHTYPE");
}

private static Optional<JdbcTypeHandle> toSumTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(Types.DECIMAL, Optional.of("decimal"),
Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
}

private static ColumnMapping dateColumnMapping()
{
// Exasol driver does not support LocalDate
Expand Down Expand Up @@ -310,7 +414,25 @@ private static SliceWriteFunction hashTypeWriteFunction()
@Override
public WriteMapping toWriteMapping(ConnectorSession session, Type type)
{
throw new TrinoException(NOT_SUPPORTED, "This connector does not support writing");
if (type instanceof DecimalType decimalType) {
String dataType = "decimal(%s, %s)".formatted(decimalType.getPrecision(), decimalType.getScale());
if (decimalType.isShort()) {
return WriteMapping.longMapping(dataType, shortDecimalWriteFunction(decimalType));
}
return WriteMapping.objectMapping(dataType, longDecimalWriteFunction(decimalType));
}
if (type instanceof VarcharType varcharType) {
String dataType;
if (varcharType.isUnbounded()) {
dataType = "varchar";
}
else {
dataType = "varchar(" + varcharType.getBoundedLength() + ")";
}
return WriteMapping.sliceMapping(dataType, varcharWriteFunction());
}

throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName());
}

@Override
Expand Down Expand Up @@ -357,10 +479,4 @@ public boolean isLimitGuaranteed(ConnectorSession session)
{
return true;
}

@Override
public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder)
{
return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.exasol;

import io.trino.plugin.jdbc.aggregation.BaseImplementAvgBigint;

public class ImplementExasolAvgBigInt
extends BaseImplementAvgBigint
{
@Override
protected String getRewriteFormatExpression()
{
return "avg(CAST(%s AS double))";
}
}
Loading