Skip to content

Commit 420cab8

Browse files
impl: Add @SQL annotation support for R2DBC in Spring tests
Signed-off-by: dev-jonghoonpark <[email protected]>
1 parent 121be15 commit 420cab8

File tree

11 files changed

+381
-2
lines changed

11 files changed

+381
-2
lines changed

spring-test/spring-test.gradle

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies {
88
optional(project(":spring-beans"))
99
optional(project(":spring-context"))
1010
optional(project(":spring-jdbc"))
11+
optional(project(":spring-r2dbc"))
1112
optional(project(":spring-orm"))
1213
optional(project(":spring-tx"))
1314
optional(project(":spring-web"))
@@ -80,6 +81,7 @@ dependencies {
8081
testImplementation("org.hibernate.orm:hibernate-core")
8182
testImplementation("org.hibernate.validator:hibernate-validator")
8283
testImplementation("org.hsqldb:hsqldb")
84+
testImplementation("io.r2dbc:r2dbc-h2")
8385
testImplementation("org.junit.platform:junit-platform-testkit")
8486
testRuntimeOnly("com.sun.xml.bind:jaxb-core")
8587
testRuntimeOnly("com.sun.xml.bind:jaxb-impl")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright 2002-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.jdbc;
18+
19+
import java.util.List;
20+
21+
import io.r2dbc.spi.ConnectionFactory;
22+
import reactor.core.publisher.Mono;
23+
24+
import org.springframework.core.io.Resource;
25+
import org.springframework.r2dbc.connection.init.ResourceDatabasePopulator;
26+
27+
/**
28+
* R2dbcPopulatorUtils is a separate class to avoid name conflicts with existing
29+
* jdbc-related classes.
30+
*
31+
* <p><b>NOTE:</b> In the current architecture, MergedSqlConfig is implemented
32+
* as a package-private method, so it has been placed in
33+
* org.springframework.test.context.jdbc.
34+
*
35+
* @author jonghoon park
36+
* @since 7.0
37+
* @see SqlScriptsTestExecutionListener
38+
* @see MergedSqlConfig
39+
*/
40+
public abstract class R2dbcPopulatorUtils {
41+
42+
static void execute(MergedSqlConfig mergedSqlConfig, ConnectionFactory connectionFactory, List<Resource> scriptResources) {
43+
ResourceDatabasePopulator populator = createResourceDatabasePopulator(mergedSqlConfig);
44+
populator.setScripts(scriptResources.toArray(new Resource[0]));
45+
46+
Mono.from(connectionFactory.create())
47+
.flatMap(populator::populate)
48+
.block();
49+
}
50+
51+
private static ResourceDatabasePopulator createResourceDatabasePopulator(MergedSqlConfig mergedSqlConfig) {
52+
ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
53+
populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
54+
populator.setSeparator(mergedSqlConfig.getSeparator());
55+
populator.setCommentPrefixes(mergedSqlConfig.getCommentPrefixes());
56+
populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
57+
populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
58+
populator.setContinueOnError(mergedSqlConfig.getErrorMode() == SqlConfig.ErrorMode.CONTINUE_ON_ERROR);
59+
populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == SqlConfig.ErrorMode.IGNORE_FAILED_DROPS);
60+
return populator;
61+
}
62+
}

spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import javax.sql.DataSource;
2626

27+
import io.r2dbc.spi.ConnectionFactory;
2728
import org.apache.commons.logging.Log;
2829
import org.apache.commons.logging.LogFactory;
2930
import org.jspecify.annotations.Nullable;
@@ -45,6 +46,7 @@
4546
import org.springframework.test.context.jdbc.SqlMergeMode.MergeMode;
4647
import org.springframework.test.context.support.AbstractTestExecutionListener;
4748
import org.springframework.test.context.transaction.TestContextTransactionUtils;
49+
import org.springframework.test.context.transaction.reactive.TestContextReactiveTransactionUtils;
4850
import org.springframework.test.context.util.TestContextResourceUtils;
4951
import org.springframework.transaction.PlatformTransactionManager;
5052
import org.springframework.transaction.TransactionDefinition;
@@ -332,8 +334,13 @@ else if (logger.isDebugEnabled()) {
332334
Assert.state(!newTxRequired, () -> String.format("Failed to execute SQL scripts for test context %s: " +
333335
"cannot execute SQL scripts using Transaction Mode " +
334336
"[%s] without a PlatformTransactionManager.", testContext, TransactionMode.ISOLATED));
335-
Assert.state(dataSource != null, () -> String.format("Failed to execute SQL scripts for test context %s: " +
336-
"supply at least a DataSource or PlatformTransactionManager.", testContext));
337+
if (dataSource == null) {
338+
ConnectionFactory connectionFactory = TestContextReactiveTransactionUtils.retrieveConnectionFactory(testContext);
339+
Assert.state(connectionFactory != null, () -> String.format("Failed to execute SQL scripts for test context %s: " +
340+
"supply at least a DataSource or PlatformTransactionManager or ConnectionFactory.", testContext));
341+
R2dbcPopulatorUtils.execute(mergedSqlConfig, connectionFactory, scriptResources);
342+
return;
343+
}
337344
// Execute scripts directly against the DataSource
338345
populator.execute(dataSource);
339346
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright 2002-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.transaction.reactive;
18+
19+
import java.util.Map;
20+
21+
import io.r2dbc.spi.Connection;
22+
import io.r2dbc.spi.ConnectionFactory;
23+
import org.apache.commons.logging.Log;
24+
import org.apache.commons.logging.LogFactory;
25+
import org.jspecify.annotations.Nullable;
26+
27+
import org.springframework.beans.BeansException;
28+
import org.springframework.beans.factory.BeanFactory;
29+
import org.springframework.beans.factory.BeanFactoryUtils;
30+
import org.springframework.beans.factory.ListableBeanFactory;
31+
import org.springframework.test.context.TestContext;
32+
import org.springframework.transaction.PlatformTransactionManager;
33+
import org.springframework.util.Assert;
34+
35+
/**
36+
* Utility methods for working with transactions and data access related beans
37+
* within the <em>Spring TestContext Framework</em>.
38+
*
39+
* <p>Mainly for internal use within the framework.
40+
*
41+
* @author jonghoon park
42+
* @since 7.0
43+
*/
44+
public abstract class TestContextReactiveTransactionUtils {
45+
46+
/**
47+
* Default bean name for a {@link ConnectionFactory}:
48+
* {@code "connectionFactory"}.
49+
*/
50+
public static final String DEFAULT_CONNECTION_FACTORY_NAME = "connectionFactory";
51+
52+
53+
private static final Log logger = LogFactory.getLog(TestContextReactiveTransactionUtils.class);
54+
55+
/**
56+
* Retrieve the {@link ConnectionFactory} to use for the supplied {@linkplain TestContext
57+
* test context}.
58+
* <p>The following algorithm is used to retrieve the {@code ConnectionFactory} from
59+
* the {@link org.springframework.context.ApplicationContext ApplicationContext}
60+
* of the supplied test context:
61+
* <ol>
62+
* <li>Attempt to look up the single {@code ConnectionFactory} by type.
63+
* <li>Attempt to look up the <em>primary</em> {@code ConnectionFactory} by type.
64+
* <li>Attempt to look up the {@code ConnectionFactory} by type and the
65+
* {@linkplain #DEFAULT_CONNECTION_FACTORY_NAME default data source name}.
66+
* </ol>
67+
* @param testContext the test context for which the {@code ConnectionFactory}
68+
* should be retrieved; never {@code null}
69+
* @return the {@code DataSource} to use, or {@code null} if not found
70+
*/
71+
@Nullable
72+
public static ConnectionFactory retrieveConnectionFactory(TestContext testContext) {
73+
Assert.notNull(testContext, "TestContext must not be null");
74+
BeanFactory bf = testContext.getApplicationContext().getAutowireCapableBeanFactory();
75+
76+
try {
77+
if (bf instanceof ListableBeanFactory lbf) {
78+
// Look up single bean by type
79+
Map<String, ConnectionFactory> ConnectionFactories =
80+
BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, ConnectionFactory.class);
81+
if (ConnectionFactories.size() == 1) {
82+
return ConnectionFactories.values().iterator().next();
83+
}
84+
85+
try {
86+
// look up single bean by type, with support for 'primary' beans
87+
return bf.getBean(ConnectionFactory.class);
88+
}
89+
catch (BeansException ex) {
90+
logBeansException(testContext, ex, PlatformTransactionManager.class);
91+
}
92+
}
93+
94+
// look up by type and default name
95+
return bf.getBean(DEFAULT_CONNECTION_FACTORY_NAME, ConnectionFactory.class);
96+
}
97+
catch (BeansException ex) {
98+
logBeansException(testContext, ex, Connection.class);
99+
return null;
100+
}
101+
}
102+
103+
private static void logBeansException(TestContext testContext, BeansException ex, Class<?> beanType) {
104+
if (logger.isTraceEnabled()) {
105+
logger.trace("Caught exception while retrieving %s for test context %s"
106+
.formatted(beanType.getSimpleName(), testContext), ex);
107+
}
108+
}
109+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
/**
2+
* JDBC support classes for the <em>Spring TestContext Framework</em>,
3+
* including support for declarative SQL script execution via {@code @Sql}.
4+
*/
5+
@NullMarked
6+
package org.springframework.test.context.transaction.reactive;
7+
8+
import org.jspecify.annotations.NullMarked;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright 2002-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.r2dbc;
18+
19+
import java.util.Objects;
20+
21+
import io.r2dbc.spi.ConnectionFactory;
22+
import org.jspecify.annotations.Nullable;
23+
import reactor.core.publisher.Mono;
24+
25+
import org.springframework.r2dbc.core.DatabaseClient;
26+
import org.springframework.util.StringUtils;
27+
28+
/**
29+
* {@code R2dbcTestUtils} is a collection of R2DBC related utility functions
30+
* intended to simplify standard database testing scenarios.
31+
*
32+
* @author jonghoon park
33+
* @since 7.0
34+
* @see org.springframework.r2dbc.core.DatabaseClient
35+
*/
36+
public abstract class R2dbcTestUtils {
37+
38+
/**
39+
* Count the rows in the given table.
40+
* @param connectionFactory the {@link ConnectionFactory} with which to perform R2DBC
41+
* operations
42+
* @param tableName name of the table to count rows in
43+
* @return the number of rows in the table
44+
*/
45+
public static Mono<Integer> countRowsInTable(ConnectionFactory connectionFactory, String tableName) {
46+
return countRowsInTable(DatabaseClient.create(connectionFactory), tableName);
47+
}
48+
49+
/**
50+
* Count the rows in the given table.
51+
* @param databaseClient the {@link DatabaseClient} with which to perform R2DBC
52+
* operations
53+
* @param tableName name of the table to count rows in
54+
* @return the number of rows in the table
55+
*/
56+
public static Mono<Integer> countRowsInTable(DatabaseClient databaseClient, String tableName) {
57+
return countRowsInTableWhere(databaseClient, tableName, null);
58+
}
59+
60+
/**
61+
* Count the rows in the given table, using the provided {@code WHERE} clause.
62+
* <p>If the provided {@code WHERE} clause contains text, it will be prefixed
63+
* with {@code " WHERE "} and then appended to the generated {@code SELECT}
64+
* statement. For example, if the provided table name is {@code "person"} and
65+
* the provided where clause is {@code "name = 'Bob' and age > 25"}, the
66+
* resulting SQL statement to execute will be
67+
* {@code "SELECT COUNT(0) FROM person WHERE name = 'Bob' and age > 25"}.
68+
* @param databaseClient the {@link DatabaseClient} with which to perform JDBC
69+
* operations
70+
* @param tableName the name of the table to count rows in
71+
* @param whereClause the {@code WHERE} clause to append to the query
72+
* @return the number of rows in the table that match the provided
73+
* {@code WHERE} clause
74+
*/
75+
public static Mono<Integer> countRowsInTableWhere(
76+
DatabaseClient databaseClient, String tableName, @Nullable String whereClause) {
77+
78+
String sql = "SELECT COUNT(0) FROM " + tableName;
79+
if (StringUtils.hasText(whereClause)) {
80+
sql += " WHERE " + whereClause;
81+
}
82+
return databaseClient.sql(sql)
83+
.map(row -> Objects.requireNonNull(row.get(0, Long.class)).intValue())
84+
.one();
85+
}
86+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
/**
2+
* Support classes for tests based on R2DBC.
3+
*/
4+
@NullMarked
5+
package org.springframework.test.r2dbc;
6+
7+
import org.jspecify.annotations.NullMarked;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright 2002-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.aot.samples.r2dbc;
18+
19+
import io.r2dbc.spi.ConnectionFactory;
20+
import org.junit.jupiter.api.Test;
21+
import reactor.test.StepVerifier;
22+
23+
import org.springframework.beans.factory.annotation.Autowired;
24+
import org.springframework.test.annotation.DirtiesContext;
25+
import org.springframework.test.context.TestPropertySource;
26+
import org.springframework.test.context.jdbc.Sql;
27+
import org.springframework.test.context.jdbc.SqlMergeMode;
28+
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
29+
import org.springframework.test.context.reactive.EmptyReactiveDatabaseConfig;
30+
31+
import static org.assertj.core.api.Assertions.assertThat;
32+
import static org.springframework.test.context.jdbc.SqlMergeMode.MergeMode.MERGE;
33+
import static org.springframework.test.r2dbc.R2dbcTestUtils.countRowsInTable;
34+
35+
/**
36+
* @author jonghoon park
37+
* @since 7.0
38+
*/
39+
@SpringJUnitConfig(EmptyReactiveDatabaseConfig.class)
40+
@SqlMergeMode(MERGE)
41+
@Sql("/org/springframework/test/context/r2dbc/schema.sql")
42+
@DirtiesContext
43+
@TestPropertySource(properties = "test.engine = jupiter")
44+
public class R2dbcSqlScriptsSpringJupiterTests {
45+
46+
@Test
47+
@Sql // default script --> org/springframework/test/context/aot/samples/r2dbc/R2dbcSqlScriptsSpringJupiterTests.test.sql
48+
void test(@Autowired ConnectionFactory connectionFactory) {
49+
StepVerifier.create(countRowsInTable(connectionFactory, "users"))
50+
.assertNext(count -> assertThat(count).isEqualTo(1))
51+
.verifyComplete();
52+
}
53+
54+
}

0 commit comments

Comments
 (0)