Skip to content

Commit 4834d08

Browse files
committed
Add StatementFilterFunction to R2dbcEntityTemplate.
See #1652
1 parent 90b6d8e commit 4834d08

File tree

2 files changed

+47
-8
lines changed

2 files changed

+47
-8
lines changed

spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java

+29-6
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw
112112

113113
private @Nullable ReactiveEntityCallbacks entityCallbacks;
114114

115+
private Function<Statement, Statement> statementFilterFunction = Function.identity();
116+
115117
/**
116118
* Create a new {@link R2dbcEntityTemplate} given {@link ConnectionFactory}.
117119
*
@@ -174,6 +176,19 @@ public R2dbcEntityTemplate(DatabaseClient databaseClient, ReactiveDataAccessStra
174176
this.projectionFactory = new SpelAwareProxyProjectionFactory();
175177
}
176178

179+
/**
180+
* Set a {@link Function Statement Filter Function} that is applied to every {@link Statement}.
181+
*
182+
* @param statementFilterFunction must not be {@literal null}.
183+
* @since 3.4
184+
*/
185+
public void setStatementFilterFunction(Function<Statement, Statement> statementFilterFunction) {
186+
187+
Assert.notNull(statementFilterFunction, "StatementFilterFunction must not be null");
188+
189+
this.statementFilterFunction = statementFilterFunction;
190+
}
191+
177192
@Override
178193
public DatabaseClient getDatabaseClient() {
179194
return this.databaseClient;
@@ -274,6 +289,7 @@ Mono<Long> doCount(Query query, Class<?> entityClass, SqlIdentifier tableName) {
274289
PreparedOperation<?> operation = statementMapper.getMappedObject(selectSpec);
275290

276291
return this.databaseClient.sql(operation) //
292+
.filter(statementFilterFunction) //
277293
.map((r, md) -> r.get(0, Long.class)) //
278294
.first() //
279295
.defaultIfEmpty(0L);
@@ -302,6 +318,7 @@ Mono<Boolean> doExists(Query query, Class<?> entityClass, SqlIdentifier tableNam
302318
PreparedOperation<?> operation = statementMapper.getMappedObject(selectSpec);
303319

304320
return this.databaseClient.sql(operation) //
321+
.filter(statementFilterFunction) //
305322
.map((r, md) -> r) //
306323
.first() //
307324
.hasElement();
@@ -362,7 +379,7 @@ private <T> RowsFetchSpec<T> doSelect(Query query, Class<?> entityType, SqlIdent
362379
PreparedOperation<?> operation = statementMapper.getMappedObject(selectSpec);
363380

364381
return getRowsFetchSpec(
365-
databaseClient.sql(operation).filter(filterFunction),
382+
databaseClient.sql(operation).filter(statementFilterFunction.andThen(filterFunction)),
366383
entityType,
367384
returnType
368385
);
@@ -397,7 +414,7 @@ Mono<Long> doUpdate(Query query, Update update, Class<?> entityClass, SqlIdentif
397414
}
398415

399416
PreparedOperation<?> operation = statementMapper.getMappedObject(selectSpec);
400-
return this.databaseClient.sql(operation).fetch().rowsUpdated();
417+
return this.databaseClient.sql(operation).filter(statementFilterFunction).fetch().rowsUpdated();
401418
}
402419

403420
@Override
@@ -422,7 +439,7 @@ Mono<Long> doDelete(Query query, Class<?> entityClass, SqlIdentifier tableName)
422439
}
423440

424441
PreparedOperation<?> operation = statementMapper.getMappedObject(deleteSpec);
425-
return this.databaseClient.sql(operation).fetch().rowsUpdated().defaultIfEmpty(0L);
442+
return this.databaseClient.sql(operation).filter(statementFilterFunction).fetch().rowsUpdated().defaultIfEmpty(0L);
426443
}
427444

428445
// -------------------------------------------------------------------------
@@ -441,7 +458,8 @@ public <T> RowsFetchSpec<T> query(PreparedOperation<?> operation, Class<?> entit
441458
Assert.notNull(operation, "PreparedOperation must not be null");
442459
Assert.notNull(entityClass, "Entity class must not be null");
443460

444-
return new EntityCallbackAdapter<>(getRowsFetchSpec(databaseClient.sql(operation), entityClass, resultType),
461+
return new EntityCallbackAdapter<>(
462+
getRowsFetchSpec(databaseClient.sql(operation).filter(statementFilterFunction), entityClass, resultType),
445463
getTableNameOrEmpty(entityClass));
446464
}
447465

@@ -451,7 +469,8 @@ public <T> RowsFetchSpec<T> query(PreparedOperation<?> operation, BiFunction<Row
451469
Assert.notNull(operation, "PreparedOperation must not be null");
452470
Assert.notNull(rowMapper, "Row mapper must not be null");
453471

454-
return new EntityCallbackAdapter<>(databaseClient.sql(operation).map(rowMapper), SqlIdentifier.EMPTY);
472+
return new EntityCallbackAdapter<>(databaseClient.sql(operation).filter(statementFilterFunction).map(rowMapper),
473+
SqlIdentifier.EMPTY);
455474
}
456475

457476
@Override
@@ -462,7 +481,8 @@ public <T> RowsFetchSpec<T> query(PreparedOperation<?> operation, Class<?> entit
462481
Assert.notNull(entityClass, "Entity class must not be null");
463482
Assert.notNull(rowMapper, "Row mapper must not be null");
464483

465-
return new EntityCallbackAdapter<>(databaseClient.sql(operation).map(rowMapper), getTableNameOrEmpty(entityClass));
484+
return new EntityCallbackAdapter<>(databaseClient.sql(operation).filter(statementFilterFunction).map(rowMapper),
485+
getTableNameOrEmpty(entityClass));
466486
}
467487

468488
// -------------------------------------------------------------------------
@@ -541,6 +561,8 @@ private <T> Mono<T> doInsert(T entity, SqlIdentifier tableName, OutboundRow outb
541561
return this.databaseClient.sql(operation) //
542562
.filter(statement -> {
543563

564+
statement = statementFilterFunction.apply(statement);
565+
544566
if (identifierColumns.isEmpty()) {
545567
return statement.returnGeneratedValues();
546568
}
@@ -632,6 +654,7 @@ private <T> Mono<T> doUpdate(T entity, SqlIdentifier tableName, RelationalPersis
632654
PreparedOperation<?> operation = mapper.getMappedObject(updateSpec);
633655

634656
return this.databaseClient.sql(operation) //
657+
.filter(statementFilterFunction) //
635658
.fetch() //
636659
.rowsUpdated() //
637660
.handle((rowsUpdated, sink) -> {

spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java

+18-2
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,6 @@ void shouldProjectCountResultWithoutId() {
206206
@Test // GH-469
207207
void shouldExistsByCriteria() {
208208

209-
MockRowMetadata metadata = MockRowMetadata.builder()
210-
.columnMetadata(MockColumnMetadata.builder().name("name").type(R2dbcType.VARCHAR).build()).build();
211209
MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Long.class, 1L).build()).build();
212210

213211
recorder.addStubbing(s -> s.startsWith("SELECT"), result);
@@ -654,6 +652,24 @@ void projectDtoShouldReadPropertiesOnce() {
654652
}).verifyComplete();
655653
}
656654

655+
@Test // GH-1652
656+
void shouldConsiderFilterFunction() {
657+
658+
MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Long.class, 1L).build()).build();
659+
660+
recorder.addStubbing(s -> s.startsWith("SELECT"), result);
661+
662+
entityTemplate.setStatementFilterFunction(statement -> statement.fetchSize(10));
663+
entityTemplate.count(Query.empty(), Person.class) //
664+
.as(StepVerifier::create) //
665+
.expectNext(1L) //
666+
.verifyComplete();
667+
668+
StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT"));
669+
670+
assertThat(statement.getFetchSize()).isEqualTo(10);
671+
}
672+
657673
@ReadingConverter
658674
static class PkConverter implements Converter<ByteBuffer, DoubleHolder> {
659675

0 commit comments

Comments
 (0)