diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/CriteriaUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/CriteriaUnitTests.java index e32ea8c840..47249ee836 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/CriteriaUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/CriteriaUnitTests.java @@ -19,6 +19,8 @@ import static org.springframework.data.relational.core.query.Criteria.*; import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; import org.assertj.core.api.SoftAssertions; import org.junit.jupiter.api.Test; @@ -31,6 +33,7 @@ * * @author Mark Paluch * @author Mingyuan Wu + * @author Zhengyu Wu */ class CriteriaUnitTests { @@ -290,4 +293,17 @@ void shouldBuildIsFalseCriteria() { assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.IS_FALSE); } + + @Test + void shouldBuildCustomCriteria() { + List<String> values = List.of("bar", "baz", "qux"); + Criteria criteria = where("foo").custom("@>", + value -> "ARRAY[" + values.stream().map(s -> "'" + s + "'").collect(Collectors.joining(",")) + "]", + List.of("bar", "baz", "qux")); + + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); + assertThat(criteria.getComparator()).isEqualTo(Comparator.CUSTOM); + assertThat(criteria.toString()).isEqualTo("foo @> ARRAY['bar','baz','qux']"); + } + } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/Criteria.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/Criteria.java index d4a6687ff1..27ccc6c738 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/Criteria.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/Criteria.java @@ -21,7 +21,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.StringJoiner; +import java.util.Objects; +import java.util.function.Function; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.relational.core.sql.IdentifierProcessing; @@ -52,6 +53,7 @@ * @author Oliver Drotbohm * @author Roman Chigvintsev * @author Jens Schauder + * @author Zhengyu Wu * @since 2.0 */ public class Criteria implements CriteriaDefinition { @@ -456,47 +458,8 @@ private void render(CriteriaDefinition criteria, StringBuilder stringBuilder) { stringBuilder.append(criteria.getColumn().toSql(IdentifierProcessing.NONE)).append(' ') .append(criteria.getComparator().getComparator()); - switch (criteria.getComparator()) { - case BETWEEN: - case NOT_BETWEEN: - Pair<Object, Object> pair = (Pair<Object, Object>) criteria.getValue(); - stringBuilder.append(' ').append(pair.getFirst()).append(" AND ").append(pair.getSecond()); - break; - - case IS_NULL: - case IS_NOT_NULL: - case IS_TRUE: - case IS_FALSE: - break; - - case IN: - case NOT_IN: - stringBuilder.append(" (").append(renderValue(criteria.getValue())).append(')'); - break; - - default: - stringBuilder.append(' ').append(renderValue(criteria.getValue())); - } - } - - private static String renderValue(@Nullable Object value) { - - if (value instanceof Number) { - return value.toString(); - } - - if (value instanceof Collection) { - - StringJoiner joiner = new StringJoiner(", "); - ((Collection<?>) value).forEach(o -> joiner.add(renderValue(o))); - return joiner.toString(); - } - - if (value != null) { - return String.format("'%s'", value); - } - - return "null"; + String renderValue = criteria.getComparator().render(criteria.getValue()); + stringBuilder.append(renderValue); } /** @@ -630,6 +593,8 @@ public interface CriteriaStep { * @return a new {@link Criteria} object */ Criteria isFalse(); + + Criteria custom(String comparator, Function<Object, String> renderFunc, Object value); } /** @@ -789,6 +754,11 @@ public Criteria isFalse() { return createCriteria(Comparator.IS_FALSE, false); } + @Override + public Criteria custom(String comparator, Function<Object, String> renderFunc, Object value) { + return createCriteria(Comparator.CUSTOM.setCustomComparator(comparator, renderFunc), value); + } + protected Criteria createCriteria(Comparator comparator, @Nullable Object value) { return new Criteria(this.property, comparator, value); } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/CriteriaDefinition.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/CriteriaDefinition.java index 042755a425..ffc0cea155 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/CriteriaDefinition.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/CriteriaDefinition.java @@ -16,9 +16,13 @@ package org.springframework.data.relational.core.query; import java.util.Arrays; +import java.util.Collection; import java.util.List; +import java.util.StringJoiner; +import java.util.function.Function; import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.data.util.Pair; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -28,6 +32,7 @@ * * @author Mark Paluch * @author Jens Schauder + * @author Zhengyu Wu * @since 2.0 */ public interface CriteriaDefinition { @@ -139,17 +144,89 @@ enum Combinator { enum Comparator { INITIAL(""), EQ("="), NEQ("!="), BETWEEN("BETWEEN"), NOT_BETWEEN("NOT BETWEEN"), LT("<"), LTE("<="), GT(">"), GTE( - ">="), IS_NULL("IS NULL"), IS_NOT_NULL("IS NOT NULL"), LIKE( - "LIKE"), NOT_LIKE("NOT LIKE"), NOT_IN("NOT IN"), IN("IN"), IS_TRUE("IS TRUE"), IS_FALSE("IS FALSE"); + ">="), IS_NULL("IS NULL"), IS_NOT_NULL("IS NOT NULL"), LIKE("LIKE"), NOT_LIKE( + "NOT LIKE"), NOT_IN("NOT IN"), IN("IN"), IS_TRUE("IS TRUE"), IS_FALSE("IS FALSE"), CUSTOM("CUSTOM"); private final String comparator; + private Function<Object, String> customRenderValueFunc = Object::toString; + private String customComparator = ""; Comparator(String comparator) { this.comparator = comparator; } + Comparator setCustomComparator(String customComparator, Function<Object, String> customRenderValueFunc) { + if (this == CUSTOM) { + this.customComparator = customComparator; + this.customRenderValueFunc = customRenderValueFunc; + } else { + throw new UnsupportedOperationException("Only CUSTOM comparator can be customized."); + } + return this; + } + public String getComparator() { + if (this == CUSTOM) { + if (customComparator.isEmpty()) { + throw new UnsupportedOperationException("CUSTOM comparator must be customized."); + } + return customComparator; + } return comparator; } + + private static String renderValue(@Nullable Object value) { + + if (value instanceof Number) { + return value.toString(); + } + + if (value instanceof Collection) { + + StringJoiner joiner = new StringJoiner(", "); + ((Collection<?>) value).forEach(o -> joiner.add(renderValue(o))); + return joiner.toString(); + } + + if (value != null) { + return String.format("'%s'", value); + } + + return "null"; + } + + public String render(Object value) { + StringBuilder stringBuilder = new StringBuilder(); + switch (this) { + case BETWEEN: + case NOT_BETWEEN: + if (value instanceof Pair<?, ?> pair) { + stringBuilder.append(pair.getFirst()).append(" AND ").append(pair.getSecond()); + } else { + throw new IllegalArgumentException("Value must be a Pair"); + } + break; + + case IS_NULL: + case IS_NOT_NULL: + case IS_TRUE: + case IS_FALSE: + break; + + case IN: + case NOT_IN: + stringBuilder.append(" (").append(renderValue(value)).append(')'); + break; + case CUSTOM: + stringBuilder.append(' ').append(customRenderValueFunc.apply(value)); + break; + + default: + stringBuilder.append(' ').append(renderValue(value)); + } + return stringBuilder.toString(); + } + } + } diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/query/CriteriaUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/query/CriteriaUnitTests.java index ce94b2f6a3..ceef79edd4 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/query/CriteriaUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/query/CriteriaUnitTests.java @@ -20,6 +20,8 @@ import static org.springframework.data.relational.core.query.Criteria.*; import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; import org.junit.jupiter.api.Test; @@ -31,6 +33,7 @@ * @author Mark Paluch * @author Jens Schauder * @author Roman Chigvintsev + * @author Zhengyu Wu */ class CriteriaUnitTests { @@ -297,4 +300,17 @@ void shouldBuildIsFalseCriteria() { assertThat(criteria.getComparator()).isEqualTo(CriteriaDefinition.Comparator.IS_FALSE); assertThat(criteria.getValue()).isEqualTo(false); } + + @Test + void shouldBuildCustomCriteria() { + List<String> values = List.of("bar", "baz", "qux"); + Criteria criteria = where("foo").custom("@>", + value -> "ARRAY[" + values.stream().map(s -> "'" + s + "'").collect(Collectors.joining(",")) + "]", + List.of("bar", "baz", "qux")); + + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); + assertThat(criteria.getComparator()).isEqualTo(Comparator.CUSTOM); + assertThat(criteria.toString()).isEqualTo("foo @> ARRAY['bar','baz','qux']"); + } + }