Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ public O visit(Expression.MultiOrList expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(Expression.NestedList expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(FieldReference expr, C context) throws E {
return visitFallback(expr, context);
Expand Down
41 changes: 41 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ default boolean nullable() {
}
}

interface Nested extends Expression {
@Value.Default
default boolean nullable() {
return false;
}
}

<R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E;

Expand Down Expand Up @@ -922,6 +929,40 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

/**
* A nested list expression with one or more elements.
*
* <p>Note: This class cannot be used to construct an empty list. To create an empty list, use
* {@link ExpressionCreator#emptyList(boolean, Type)} which returns an {@link EmptyListLiteral}.
*/
@Value.Immutable
abstract class NestedList implements Nested {
public abstract List<Expression> values();

@Value.Check
protected void check() {
assert !values().isEmpty() : "To specify an empty list, use ExpressionCreator.emptyList()";

assert values().stream().map(Expression::getType).distinct().count() <= 1
: "All values in NestedList must have the same type";
}

@Override
public Type getType() {
return Type.withNullability(nullable()).list(values().get(0).getType());
}

@Override
public <R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E {
return visitor.visit(this, context);
}

public static ImmutableExpression.NestedList.Builder builder() {
return ImmutableExpression.NestedList.builder();
}
}

@Value.Immutable
abstract class MultiOrListRecord {
public abstract List<Expression> values();
Expand Down
11 changes: 11 additions & 0 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,17 @@ public static Expression.StructLiteral struct(boolean nullable, Expression.Liter
return Expression.StructLiteral.builder().nullable(nullable).addFields(values).build();
}

/**
* Creates a nested list expression with one or more elements.
*
* <p>Note: This class cannot be used to construct an empty list. To create an empty list, use
* {@link ExpressionCreator#emptyList(boolean, Type)} which returns an {@link
* Expression.EmptyListLiteral}.
*/
public static Expression.NestedList nestedList(boolean nullable, List<Expression> values) {
return Expression.NestedList.builder().nullable(nullable).addAllValues(values).build();
}

public static Expression.StructLiteral struct(
boolean nullable, Iterable<? extends Expression.Literal> values) {
return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr

R visit(Expression.MultiOrList expr, C context) throws E;

R visit(Expression.NestedList expr, C context) throws E;

R visit(FieldReference expr, C context) throws E;

R visit(Expression.SetPredicate expr, C context) throws E;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,22 @@ public Expression visit(
.build();
}

@Override
public Expression visit(
io.substrait.expression.Expression.NestedList expr, EmptyVisitationContext context)
throws RuntimeException {

List<Expression> values =
expr.values().stream().map(this::toProto).collect(Collectors.toList());

return Expression.newBuilder()
.setNested(
Expression.Nested.newBuilder()
.setList(Expression.Nested.List.newBuilder().addAllValues(values))
.setNullable(expr.nullable()))
.build();
}

@Override
public Expression visit(FieldReference expr, EmptyVisitationContext context) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ public Expression from(io.substrait.proto.Expression expr) {
multiOrList.getValueList().stream().map(this::from).collect(Collectors.toList()))
.build();
}
case NESTED:
return from(expr.getNested());
case CAST:
return ExpressionCreator.cast(
protoTypeConverter.from(expr.getCast().getType()),
Expand Down Expand Up @@ -361,6 +363,18 @@ private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.B
}
}

public Expression.Nested from(io.substrait.proto.Expression.Nested nested) {
switch (nested.getNestedTypeCase()) {
case LIST:
List<Expression> list =
nested.getList().getValuesList().stream().map(this::from).collect(Collectors.toList());
return ExpressionCreator.nestedList(nested.getNullable(), list);
default:
throw new UnsupportedOperationException(
"Unimplemented nested type: " + nested.getNestedTypeCase());
}
}

public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
switch (literal.getLiteralTypeCase()) {
case BOOLEAN:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,16 @@ public Optional<Expression> visit(
.build());
}

@Override
public Optional<Expression> visit(Expression.NestedList expr, EmptyVisitationContext context)
throws E {
Optional<List<Expression>> expressions = visitExprList(expr.values(), context);

return expressions.map(
expressionList ->
Expression.NestedList.builder().from(expr).values(expressionList).build());
}

protected Optional<Expression.MultiOrListRecord> visitMultiOrListRecord(
Expression.MultiOrListRecord multiOrListRecord, EmptyVisitationContext context) throws E {
return visitExprList(multiOrListRecord.values(), context)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package io.substrait.type.proto;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;

import io.substrait.TestBase;
import io.substrait.expression.Expression;
import io.substrait.expression.ImmutableExpression;
import org.junit.jupiter.api.Test;

class NestedListExpressionTest extends TestBase {
io.substrait.expression.Expression literalExpression =
Expression.BoolLiteral.builder().value(true).build();
Expression.ScalarFunctionInvocation nonLiteralExpression = b.add(b.i32(7), b.i32(42));

@Test
void rejectNestedListWithElementsOfDifferentTypes() {
ImmutableExpression.NestedList.Builder builder =
Expression.NestedList.builder().addValues(literalExpression).addValues(b.i32(12));
assertThrows(AssertionError.class, builder::build);
}

@Test
void acceptNestedListWithElementsOfSameType() {
ImmutableExpression.NestedList.Builder builder =
Expression.NestedList.builder().addValues(nonLiteralExpression).addValues(b.i32(12));
assertDoesNotThrow(builder::build);

io.substrait.relation.Project project =
io.substrait.relation.Project.builder()
.addExpressions(builder.build())
.input(b.emptyScan())
.build();
verifyRoundTrip(project);
}

@Test
void rejectEmptyNestedListTest() {
ImmutableExpression.NestedList.Builder builder = Expression.NestedList.builder();
assertThrows(AssertionError.class, builder::build);
}

@Test
void literalNestedListTest() {
Expression.NestedList literalNestedList =
Expression.NestedList.builder()
.addValues(literalExpression)
.addValues(literalExpression)
.build();

io.substrait.relation.Project project =
io.substrait.relation.Project.builder()
.addExpressions(literalNestedList)
.input(b.emptyScan())
.build();

verifyRoundTrip(project);
}

@Test
void literalNullableNestedListTest() {
Expression.NestedList literalNestedList =
Expression.NestedList.builder()
.addValues(literalExpression)
.addValues(literalExpression)
.nullable(true)
.build();

io.substrait.relation.Project project =
io.substrait.relation.Project.builder()
.addExpressions(literalNestedList)
.input(b.emptyScan())
.build();

verifyRoundTrip(project);
}

@Test
void nonLiteralNestedListTest() {
Expression.NestedList nonLiteralNestedList =
Expression.NestedList.builder()
.addValues(nonLiteralExpression)
.addValues(nonLiteralExpression)
.build();

io.substrait.relation.Project project =
io.substrait.relation.Project.builder()
.addExpressions(nonLiteralNestedList)
.input(b.emptyScan())
.build();

verifyRoundTrip(project);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.examples.util;

import io.substrait.expression.Expression;
import io.substrait.expression.Expression.BinaryLiteral;
import io.substrait.expression.Expression.BoolLiteral;
import io.substrait.expression.Expression.Cast;
Expand Down Expand Up @@ -256,6 +257,12 @@ public String visit(MultiOrList expr, EmptyVisitationContext context) throws Run
return sb.toString();
}

@Override
public String visit(Expression.NestedList expr, EmptyVisitationContext context)
throws RuntimeException {
return "<NestedList>";
}

@Override
public String visit(FieldReference expr, EmptyVisitationContext context) throws RuntimeException {
StringBuilder sb = new StringBuilder("FieldRef#");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ public Rel visit(org.apache.calcite.rel.core.Project project) {
.map(this::toExpression)
.collect(java.util.stream.Collectors.toList());

// if there are no input fields, no remap is necessary
if (project.getInput().getRowType().getFieldCount() == 0) {
return Project.builder().expressions(expressions).input(apply(project.getInput())).build();
}

// todo: eliminate excessive projects. This should be done by converting rexinputrefs to remaps.
return Project.builder()
.remap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ public static List<CallConverter> defaults(TypeConverter typeConverter) {
CallConverters.CASE,
CallConverters.CAST.apply(typeConverter),
CallConverters.REINTERPRET.apply(typeConverter),
new LiteralConstructorConverter(typeConverter));
new SqlArrayValueConstructorCallConverter(typeConverter),
new SqlMapValueConstructorCallConverter());
}

public interface SimpleCallConverter extends CallConverter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,24 @@ public RexNode visit(Expression.ListLiteral expr, Context context) throws Runtim
return rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args);
}

@Override
public RexNode visit(Expression.NestedList expr, Context context) {
List<RexNode> args =
expr.values().stream().map(e -> e.accept(this, context)).collect(Collectors.toList());

// to preserve NestedList nullability
RelDataType elementType;
if (args.isEmpty()) {
throw new IllegalStateException("NestedList must have at least 1 element");
} else {
elementType = args.get(0).getType();
}
RelDataType nestedListType = typeFactory.createArrayType(elementType, -1);
nestedListType = typeFactory.createTypeWithNullability(nestedListType, expr.nullable());

return rexBuilder.makeCall(nestedListType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args);
}

@Override
public RexNode visit(Expression.EmptyListLiteral expr, Context context) throws RuntimeException {
RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.getType());
Expand Down
Loading
Loading