Skip to content

Commit ea29642

Browse files
serezakorotaevschauder
authored andcommittedJan 16, 2025
Add Stream support to JdbcAggregateOperations
See #1714 Original pull request #1963 Signed-off-by: Sergey Korotaev <[email protected]>
1 parent 4ef0538 commit ea29642

11 files changed

+475
-2
lines changed
 

‎spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateOperations.java

+44
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.List;
1919
import java.util.Optional;
20+
import java.util.stream.Stream;
2021

2122
import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException;
2223
import org.springframework.data.domain.Example;
@@ -35,6 +36,7 @@
3536
* @author Chirag Tailor
3637
* @author Diego Krupitza
3738
* @author Myeonghyeon Lee
39+
* @author Sergey Korotaev
3840
*/
3941
public interface JdbcAggregateOperations {
4042

@@ -165,6 +167,17 @@ public interface JdbcAggregateOperations {
165167
*/
166168
<T> List<T> findAllById(Iterable<?> ids, Class<T> domainType);
167169

170+
/**
171+
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
172+
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
173+
*
174+
* @param ids the Ids of the entities to load. Must not be {@code null}.
175+
* @param domainType the type of entities to load. Must not be {@code null}.
176+
* @param <T> type of entities to load.
177+
* @return the loaded entities. Guaranteed to be not {@code null}.
178+
*/
179+
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);
180+
168181
/**
169182
* Load all aggregates of a given type.
170183
*
@@ -174,6 +187,15 @@ public interface JdbcAggregateOperations {
174187
*/
175188
<T> List<T> findAll(Class<T> domainType);
176189

190+
/**
191+
* Load all aggregates of a given type to a {@link Stream}.
192+
*
193+
* @param domainType the type of the aggregate roots. Must not be {@code null}.
194+
* @param <T> the type of the aggregate roots. Must not be {@code null}.
195+
* @return Guaranteed to be not {@code null}.
196+
*/
197+
<T> Stream<T> streamAll(Class<T> domainType);
198+
177199
/**
178200
* Load all aggregates of a given type, sorted.
179201
*
@@ -185,6 +207,17 @@ public interface JdbcAggregateOperations {
185207
*/
186208
<T> List<T> findAll(Class<T> domainType, Sort sort);
187209

210+
/**
211+
* Loads all entities of the given type to a {@link Stream}, sorted.
212+
*
213+
* @param domainType the type of entities to load. Must not be {@code null}.
214+
* @param <T> the type of entities to load.
215+
* @param sort the sorting information. Must not be {@code null}.
216+
* @return Guaranteed to be not {@code null}.
217+
* @since 2.0
218+
*/
219+
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);
220+
188221
/**
189222
* Load a page of (potentially sorted) aggregates of a given type.
190223
*
@@ -218,6 +251,17 @@ public interface JdbcAggregateOperations {
218251
*/
219252
<T> List<T> findAll(Query query, Class<T> domainType);
220253

254+
/**
255+
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
256+
*
257+
* @param query must not be {@literal null}.
258+
* @param domainType the type of entities. Must not be {@code null}.
259+
* @return a non-null list with all the matching results.
260+
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
261+
* @since 3.0
262+
*/
263+
<T> Stream<T> streamAll(Query query, Class<T> domainType);
264+
221265
/**
222266
* Returns a {@link Page} of entities matching the given {@link Query}. In case no match could be found, an empty
223267
* {@link Page} is returned.

‎spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java

+34
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Optional;
2626
import java.util.function.Function;
2727
import java.util.stream.Collectors;
28+
import java.util.stream.Stream;
2829
import java.util.stream.StreamSupport;
2930

3031
import org.springframework.context.ApplicationContext;
@@ -68,6 +69,7 @@
6869
* @author Myeonghyeon Lee
6970
* @author Chirag Tailor
7071
* @author Diego Krupitza
72+
* @author Sergey Korotaev
7173
*/
7274
public class JdbcAggregateTemplate implements JdbcAggregateOperations {
7375

@@ -283,6 +285,16 @@ public <T> List<T> findAll(Class<T> domainType, Sort sort) {
283285
return triggerAfterConvert(all);
284286
}
285287

288+
@Override
289+
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
290+
291+
Assert.notNull(domainType, "Domain type must not be null");
292+
293+
Stream<T> allStreamable = accessStrategy.streamAll(domainType, sort);
294+
295+
return allStreamable.map(this::triggerAfterConvert);
296+
}
297+
286298
@Override
287299
public <T> Page<T> findAll(Class<T> domainType, Pageable pageable) {
288300

@@ -307,6 +319,11 @@ public <T> List<T> findAll(Query query, Class<T> domainType) {
307319
return triggerAfterConvert(all);
308320
}
309321

322+
@Override
323+
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
324+
return accessStrategy.streamAll(query, domainType).map(this::triggerAfterConvert);
325+
}
326+
310327
@Override
311328
public <T> Page<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
312329

@@ -325,6 +342,12 @@ public <T> List<T> findAll(Class<T> domainType) {
325342
return triggerAfterConvert(all);
326343
}
327344

345+
@Override
346+
public <T> Stream<T> streamAll(Class<T> domainType) {
347+
Iterable<T> items = triggerAfterConvert(accessStrategy.findAll(domainType));
348+
return StreamSupport.stream(items.spliterator(), false).map(this::triggerAfterConvert);
349+
}
350+
328351
@Override
329352
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
330353

@@ -335,6 +358,17 @@ public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
335358
return triggerAfterConvert(allById);
336359
}
337360

361+
@Override
362+
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
363+
364+
Assert.notNull(ids, "Ids must not be null");
365+
Assert.notNull(domainType, "Domain type must not be null");
366+
367+
Stream<T> allByIdStreamable = accessStrategy.streamAllByIds(ids, domainType);
368+
369+
return allByIdStreamable.map(this::triggerAfterConvert);
370+
}
371+
338372
@Override
339373
public <S> void delete(S aggregateRoot) {
340374

‎spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java

+22
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Optional;
2323
import java.util.function.Consumer;
2424
import java.util.function.Function;
25+
import java.util.stream.Stream;
2526

2627
import org.springframework.data.domain.Pageable;
2728
import org.springframework.data.domain.Sort;
@@ -42,6 +43,7 @@
4243
* @author Myeonghyeon Lee
4344
* @author Chirag Tailor
4445
* @author Diego Krupitza
46+
* @author Sergey Korotaev
4547
* @since 1.1
4648
*/
4749
public class CascadingDataAccessStrategy implements DataAccessStrategy {
@@ -132,11 +134,21 @@ public <T> Iterable<T> findAll(Class<T> domainType) {
132134
return collect(das -> das.findAll(domainType));
133135
}
134136

137+
@Override
138+
public <T> Stream<T> streamAll(Class<T> domainType) {
139+
return collect(das -> das.streamAll(domainType));
140+
}
141+
135142
@Override
136143
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
137144
return collect(das -> das.findAllById(ids, domainType));
138145
}
139146

147+
@Override
148+
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
149+
return collect(das -> das.streamAllByIds(ids, domainType));
150+
}
151+
140152
@Override
141153
public Iterable<Object> findAllByPath(Identifier identifier,
142154
PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
@@ -153,6 +165,11 @@ public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
153165
return collect(das -> das.findAll(domainType, sort));
154166
}
155167

168+
@Override
169+
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
170+
return collect(das -> das.streamAll(domainType, sort));
171+
}
172+
156173
@Override
157174
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
158175
return collect(das -> das.findAll(domainType, pageable));
@@ -168,6 +185,11 @@ public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
168185
return collect(das -> das.findAll(query, domainType));
169186
}
170187

188+
@Override
189+
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
190+
return collect(das -> das.streamAll(query, domainType));
191+
}
192+
171193
@Override
172194
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
173195
return collect(das -> das.findAll(query, domainType, pageable));

‎spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java

+48
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919
import java.util.Map;
2020
import java.util.Optional;
21+
import java.util.stream.Stream;
2122

2223
import org.springframework.dao.OptimisticLockingFailureException;
2324
import org.springframework.data.domain.Pageable;
@@ -41,6 +42,7 @@
4142
* @author Myeonghyeon Lee
4243
* @author Chirag Tailor
4344
* @author Diego Krupitza
45+
* @author Sergey Korotaev
4446
*/
4547
public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationResolver {
4648

@@ -252,6 +254,16 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
252254
@Override
253255
<T> Iterable<T> findAll(Class<T> domainType);
254256

257+
/**
258+
* Loads all entities of the given type to a {@link Stream}.
259+
*
260+
* @param domainType the type of entities to load. Must not be {@code null}.
261+
* @param <T> the type of entities to load.
262+
* @return Guaranteed to be not {@code null}.
263+
*/
264+
@Override
265+
<T> Stream<T> streamAll(Class<T> domainType);
266+
255267
/**
256268
* Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids
257269
* passed in matches the number of entities returned.
@@ -264,6 +276,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
264276
@Override
265277
<T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType);
266278

279+
/**
280+
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
281+
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
282+
*
283+
* @param ids the Ids of the entities to load. Must not be {@code null}.
284+
* @param domainType the type of entities to load. Must not be {@code null}.
285+
* @param <T> type of entities to load.
286+
* @return the loaded entities. Guaranteed to be not {@code null}.
287+
*/
288+
@Override
289+
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);
290+
267291
@Override
268292
Iterable<Object> findAllByPath(Identifier identifier,
269293
PersistentPropertyPath<? extends RelationalPersistentProperty> path);
@@ -280,6 +304,18 @@ Iterable<Object> findAllByPath(Identifier identifier,
280304
@Override
281305
<T> Iterable<T> findAll(Class<T> domainType, Sort sort);
282306

307+
/**
308+
* Loads all entities of the given type to a {@link Stream}, sorted.
309+
*
310+
* @param domainType the type of entities to load. Must not be {@code null}.
311+
* @param <T> the type of entities to load.
312+
* @param sort the sorting information. Must not be {@code null}.
313+
* @return Guaranteed to be not {@code null}.
314+
* @since 2.0
315+
*/
316+
@Override
317+
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);
318+
283319
/**
284320
* Loads all entities of the given type, paged and sorted.
285321
*
@@ -316,6 +352,18 @@ Iterable<Object> findAllByPath(Identifier identifier,
316352
@Override
317353
<T> Iterable<T> findAll(Query query, Class<T> domainType);
318354

355+
/**
356+
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
357+
*
358+
* @param query must not be {@literal null}.
359+
* @param domainType the type of entities. Must not be {@code null}.
360+
* @return a non-null list with all the matching results.
361+
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
362+
* @since 3.0
363+
*/
364+
@Override
365+
<T> Stream<T> streamAll(Query query, Class<T> domainType);
366+
319367
/**
320368
* Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable}
321369
* to the result.

‎spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

+34
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Collections;
2323
import java.util.List;
2424
import java.util.Optional;
25+
import java.util.stream.Stream;
2526

2627
import org.springframework.dao.EmptyResultDataAccessException;
2728
import org.springframework.dao.OptimisticLockingFailureException;
@@ -60,6 +61,7 @@
6061
* @author Radim Tlusty
6162
* @author Chirag Tailor
6263
* @author Diego Krupitza
64+
* @author Sergey Korotaev
6365
* @since 1.1
6466
*/
6567
public class DefaultDataAccessStrategy implements DataAccessStrategy {
@@ -276,6 +278,11 @@ public <T> List<T> findAll(Class<T> domainType) {
276278
return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType));
277279
}
278280

281+
@Override
282+
public <T> Stream<T> streamAll(Class<T> domainType) {
283+
return operations.queryForStream(sql(domainType).getFindAll(), new MapSqlParameterSource(), getEntityRowMapper(domainType));
284+
}
285+
279286
@Override
280287
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
281288

@@ -288,6 +295,19 @@ public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
288295
return operations.query(findAllInListSql, parameterSource, getEntityRowMapper(domainType));
289296
}
290297

298+
@Override
299+
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
300+
301+
if (!ids.iterator().hasNext()) {
302+
return Stream.empty();
303+
}
304+
305+
SqlParameterSource parameterSource = sqlParametersFactory.forQueryByIds(ids, domainType);
306+
String findAllInListSql = sql(domainType).getFindAllInList();
307+
308+
return operations.queryForStream(findAllInListSql, parameterSource, getEntityRowMapper(domainType));
309+
}
310+
291311
@Override
292312
@SuppressWarnings("unchecked")
293313
public List<Object> findAllByPath(Identifier identifier,
@@ -342,6 +362,11 @@ public <T> List<T> findAll(Class<T> domainType, Sort sort) {
342362
return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType));
343363
}
344364

365+
@Override
366+
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
367+
return operations.queryForStream(sql(domainType).getFindAll(sort), new MapSqlParameterSource(), getEntityRowMapper(domainType));
368+
}
369+
345370
@Override
346371
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
347372
return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType));
@@ -369,6 +394,15 @@ public <T> List<T> findAll(Query query, Class<T> domainType) {
369394
return operations.query(sqlQuery, parameterSource, getEntityRowMapper(domainType));
370395
}
371396

397+
@Override
398+
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
399+
400+
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
401+
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource);
402+
403+
return operations.queryForStream(sqlQuery, parameterSource, getEntityRowMapper(domainType));
404+
}
405+
372406
@Override
373407
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
374408

‎spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java

+22
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.List;
1919
import java.util.Optional;
20+
import java.util.stream.Stream;
2021

2122
import org.springframework.data.domain.Pageable;
2223
import org.springframework.data.domain.Sort;
@@ -37,6 +38,7 @@
3738
* @author Myeonghyeon Lee
3839
* @author Chirag Tailor
3940
* @author Diego Krupitza
41+
* @author Sergey Korotaev
4042
* @since 1.1
4143
*/
4244
public class DelegatingDataAccessStrategy implements DataAccessStrategy {
@@ -135,11 +137,21 @@ public <T> Iterable<T> findAll(Class<T> domainType) {
135137
return delegate.findAll(domainType);
136138
}
137139

140+
@Override
141+
public <T> Stream<T> streamAll(Class<T> domainType) {
142+
return delegate.streamAll(domainType);
143+
}
144+
138145
@Override
139146
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
140147
return delegate.findAllById(ids, domainType);
141148
}
142149

150+
@Override
151+
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
152+
return delegate.streamAllByIds(ids, domainType);
153+
}
154+
143155
@Override
144156
public Iterable<Object> findAllByPath(Identifier identifier,
145157
PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
@@ -156,6 +168,11 @@ public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
156168
return delegate.findAll(domainType, sort);
157169
}
158170

171+
@Override
172+
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
173+
return delegate.streamAll(domainType, sort);
174+
}
175+
159176
@Override
160177
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
161178
return delegate.findAll(domainType, pageable);
@@ -171,6 +188,11 @@ public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
171188
return delegate.findAll(query, domainType);
172189
}
173190

191+
@Override
192+
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
193+
return delegate.streamAll(query, domainType);
194+
}
195+
174196
@Override
175197
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
176198
return delegate.findAll(query, domainType, pageable);

‎spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ReadingDataAccessStrategy.java

+44
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.data.jdbc.core.convert;
1818

1919
import java.util.Optional;
20+
import java.util.stream.Stream;
2021

2122
import org.springframework.data.domain.Pageable;
2223
import org.springframework.data.domain.Sort;
@@ -27,6 +28,7 @@
2728
* The finding methods of a {@link DataAccessStrategy}.
2829
*
2930
* @author Jens Schauder
31+
* @author Sergey Korotaev
3032
* @since 3.2
3133
*/
3234
interface ReadingDataAccessStrategy {
@@ -51,6 +53,15 @@ interface ReadingDataAccessStrategy {
5153
*/
5254
<T> Iterable<T> findAll(Class<T> domainType);
5355

56+
/**
57+
* Loads all entities of the given type to a {@link Stream}.
58+
*
59+
* @param domainType the type of entities to load. Must not be {@code null}.
60+
* @param <T> the type of entities to load.
61+
* @return Guaranteed to be not {@code null}.
62+
*/
63+
<T> Stream<T> streamAll(Class<T> domainType);
64+
5465
/**
5566
* Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids
5667
* passed in matches the number of entities returned.
@@ -62,6 +73,17 @@ interface ReadingDataAccessStrategy {
6273
*/
6374
<T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType);
6475

76+
/**
77+
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
78+
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
79+
*
80+
* @param ids the Ids of the entities to load. Must not be {@code null}.
81+
* @param domainType the type of entities to load. Must not be {@code null}.
82+
* @param <T> type of entities to load.
83+
* @return the loaded entities. Guaranteed to be not {@code null}.
84+
*/
85+
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);
86+
6587
/**
6688
* Loads all entities of the given type, sorted.
6789
*
@@ -73,6 +95,17 @@ interface ReadingDataAccessStrategy {
7395
*/
7496
<T> Iterable<T> findAll(Class<T> domainType, Sort sort);
7597

98+
/**
99+
* Loads all entities of the given type to a {@link Stream}, sorted.
100+
*
101+
* @param domainType the type of entities to load. Must not be {@code null}.
102+
* @param <T> the type of entities to load.
103+
* @param sort the sorting information. Must not be {@code null}.
104+
* @return Guaranteed to be not {@code null}.
105+
* @since 2.0
106+
*/
107+
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);
108+
76109
/**
77110
* Loads all entities of the given type, paged and sorted.
78111
*
@@ -106,6 +139,17 @@ interface ReadingDataAccessStrategy {
106139
*/
107140
<T> Iterable<T> findAll(Query query, Class<T> domainType);
108141

142+
/**
143+
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
144+
*
145+
* @param query must not be {@literal null}.
146+
* @param domainType the type of entities. Must not be {@code null}.
147+
* @return a non-null list with all the matching results.
148+
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
149+
* @since 3.0
150+
*/
151+
<T> Stream<T> streamAll(Query query, Class<T> domainType);
152+
109153
/**
110154
* Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable}
111155
* to the result.

‎spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java

+22
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.List;
2020
import java.util.Optional;
21+
import java.util.stream.Stream;
2122

2223
import org.springframework.data.domain.Pageable;
2324
import org.springframework.data.domain.Sort;
@@ -32,6 +33,7 @@
3233
*
3334
* @author Jens Schauder
3435
* @author Mark Paluch
36+
* @author Sergey Korotaev
3537
* @since 3.2
3638
*/
3739
class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy {
@@ -56,16 +58,31 @@ public <T> List<T> findAll(Class<T> domainType) {
5658
return aggregateReader.findAll(getPersistentEntity(domainType));
5759
}
5860

61+
@Override
62+
public <T> Stream<T> streamAll(Class<T> domainType) {
63+
throw new UnsupportedOperationException();
64+
}
65+
5966
@Override
6067
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
6168
return aggregateReader.findAllById(ids, getPersistentEntity(domainType));
6269
}
6370

71+
@Override
72+
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
73+
throw new UnsupportedOperationException();
74+
}
75+
6476
@Override
6577
public <T> List<T> findAll(Class<T> domainType, Sort sort) {
6678
throw new UnsupportedOperationException();
6779
}
6880

81+
@Override
82+
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
83+
throw new UnsupportedOperationException();
84+
}
85+
6986
@Override
7087
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
7188
throw new UnsupportedOperationException();
@@ -81,6 +98,11 @@ public <T> List<T> findAll(Query query, Class<T> domainType) {
8198
return aggregateReader.findAll(query, getPersistentEntity(domainType));
8299
}
83100

101+
@Override
102+
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
103+
throw new UnsupportedOperationException();
104+
}
105+
84106
@Override
85107
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
86108
throw new UnsupportedOperationException();

‎spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java

+38
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
import java.util.List;
2323
import java.util.Map;
2424
import java.util.Optional;
25+
import java.util.stream.Stream;
26+
import java.util.stream.StreamSupport;
2527

28+
import org.apache.ibatis.cursor.Cursor;
2629
import org.apache.ibatis.session.SqlSession;
2730
import org.mybatis.spring.SqlSessionTemplate;
2831
import org.springframework.dao.EmptyResultDataAccessException;
@@ -59,6 +62,7 @@
5962
* @author Chirag Tailor
6063
* @author Christopher Klein
6164
* @author Mikhail Polivakha
65+
* @author Sergey Korotaev
6266
*/
6367
public class MyBatisDataAccessStrategy implements DataAccessStrategy {
6468

@@ -263,12 +267,28 @@ public <T> List<T> findAll(Class<T> domainType) {
263267
return sqlSession().selectList(statement, parameter);
264268
}
265269

270+
@Override
271+
public <T> Stream<T> streamAll(Class<T> domainType) {
272+
String statement = namespace(domainType) + ".streamAll";
273+
MyBatisContext parameter = new MyBatisContext(null, null, domainType, Collections.emptyMap());
274+
Cursor<T> cursor = sqlSession().selectCursor(statement, parameter);
275+
return StreamSupport.stream(cursor.spliterator(), false);
276+
}
277+
266278
@Override
267279
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
268280
return sqlSession().selectList(namespace(domainType) + ".findAllById",
269281
new MyBatisContext(ids, null, domainType, Collections.emptyMap()));
270282
}
271283

284+
@Override
285+
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
286+
String statement = namespace(domainType) + ".streamAllByIds";
287+
MyBatisContext parameter = new MyBatisContext(ids, null, domainType, Collections.emptyMap());
288+
Cursor<T> cursor = sqlSession().selectCursor(statement, parameter);
289+
return StreamSupport.stream(cursor.spliterator(), false);
290+
}
291+
272292
@Override
273293
public List<Object> findAllByPath(Identifier identifier,
274294
PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
@@ -296,6 +316,19 @@ public <T> List<T> findAll(Class<T> domainType, Sort sort) {
296316
new MyBatisContext(null, null, domainType, additionalContext));
297317
}
298318

319+
@Override
320+
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
321+
322+
Map<String, Object> additionalContext = new HashMap<>();
323+
additionalContext.put("sort", sort);
324+
325+
String statement = namespace(domainType) + ".streamAllSorted";
326+
MyBatisContext parameter = new MyBatisContext(null, null, domainType, additionalContext);
327+
328+
Cursor<T> cursor = sqlSession().selectCursor(statement, parameter);
329+
return StreamSupport.stream(cursor.spliterator(), false);
330+
}
331+
299332
@Override
300333
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
301334

@@ -315,6 +348,11 @@ public <T> List<T> findAll(Query query, Class<T> probeType) {
315348
throw new UnsupportedOperationException("Not implemented");
316349
}
317350

351+
@Override
352+
public <T> Stream<T> streamAll(Query query, Class<T> probeType) {
353+
throw new UnsupportedOperationException("Not implemented");
354+
}
355+
318356
@Override
319357
public <T> List<T> findAll(Query query, Class<T> probeType, Pageable pageable) {
320358
throw new UnsupportedOperationException("Not implemented");

‎spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java

+45-1
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
import java.util.ArrayList;
2828
import java.util.function.Function;
2929
import java.util.stream.IntStream;
30+
import java.util.stream.Stream;
3031

31-
import org.assertj.core.api.SoftAssertions;
3232
import org.junit.jupiter.api.Test;
3333
import org.springframework.beans.factory.annotation.Autowired;
3434
import org.springframework.context.ApplicationEventPublisher;
@@ -81,6 +81,7 @@
8181
* @author Mikhail Polivakha
8282
* @author Chirag Tailor
8383
* @author Vincent Galloy
84+
* @author Sergey Korotaev
8485
*/
8586
@IntegrationTest
8687
abstract class AbstractJdbcAggregateTemplateIntegrationTests {
@@ -309,6 +310,18 @@ void saveAndLoadManyEntitiesWithReferencedEntity() {
309310
.containsExactly(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
310311
}
311312

313+
@Test // GH-1714
314+
void saveAndLoadManeEntitiesWithReferenceEntityLikeStream() {
315+
316+
template.save(legoSet);
317+
318+
Stream<LegoSet> streamable = template.streamAll(LegoSet.class);
319+
320+
assertThat(streamable)
321+
.extracting("id", "manual.id", "manual.content") //
322+
.containsExactly(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
323+
}
324+
312325
@Test // DATAJDBC-101
313326
void saveAndLoadManyEntitiesWithReferencedEntitySorted() {
314327

@@ -323,6 +336,20 @@ void saveAndLoadManyEntitiesWithReferencedEntitySorted() {
323336
.containsExactly("Frozen", "Lava", "Star");
324337
}
325338

339+
@Test // GH-1714
340+
void saveAndLoadManyEntitiesWithReferencedEntitySortedLikeStream() {
341+
342+
template.save(createLegoSet("Lava"));
343+
template.save(createLegoSet("Star"));
344+
template.save(createLegoSet("Frozen"));
345+
346+
Stream<LegoSet> reloadedLegoSets = template.streamAll(LegoSet.class, Sort.by("name"));
347+
348+
assertThat(reloadedLegoSets) //
349+
.extracting("name") //
350+
.containsExactly("Frozen", "Lava", "Star");
351+
}
352+
326353
@Test // DATAJDBC-101
327354
void saveAndLoadManyEntitiesWithReferencedEntitySortedAndPaged() {
328355

@@ -360,6 +387,12 @@ void findByNonPropertySortFails() {
360387
.isInstanceOf(InvalidPersistentPropertyPath.class);
361388
}
362389

390+
@Test // GH-1714
391+
void findByNonPropertySortLikeStreamFails() {
392+
assertThatThrownBy(() -> template.streamAll(LegoSet.class, Sort.by("somethingNotExistant")))
393+
.isInstanceOf(InvalidPersistentPropertyPath.class);
394+
}
395+
363396
@Test // DATAJDBC-112
364397
void saveAndLoadManyEntitiesByIdWithReferencedEntity() {
365398

@@ -371,6 +404,17 @@ void saveAndLoadManyEntitiesByIdWithReferencedEntity() {
371404
.contains(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
372405
}
373406

407+
@Test // GH-1714
408+
void saveAndLoadManyEntitiesByIdWithReferencedEntityLikeStream() {
409+
410+
template.save(legoSet);
411+
412+
Stream<LegoSet> reloadedLegoSets = template.streamAllByIds(singletonList(legoSet.id), LegoSet.class);
413+
414+
assertThat(reloadedLegoSets).hasSize(1).extracting("id", "manual.id", "manual.content")
415+
.contains(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
416+
}
417+
374418
@Test // DATAJDBC-112
375419
void saveAndLoadAnEntityWithReferencedNullEntity() {
376420

‎spring-data-jdbc/src/test/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategyUnitTests.java

+122-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
import static org.mockito.Mockito.*;
2323
import static org.springframework.data.relational.core.sql.SqlIdentifier.*;
2424

25+
import java.util.Iterator;
26+
import java.util.List;
27+
import java.util.stream.Stream;
28+
import org.apache.ibatis.cursor.Cursor;
2529
import org.apache.ibatis.session.SqlSession;
30+
import org.jetbrains.annotations.NotNull;
2631
import org.junit.jupiter.api.BeforeEach;
2732
import org.junit.jupiter.api.Test;
2833
import org.mockito.ArgumentCaptor;
@@ -43,6 +48,7 @@
4348
* @author Mark Paluch
4449
* @author Tyler Van Gorder
4550
* @author Chirag Tailor
51+
* @author Sergey Korotaev
4652
*/
4753
public class MyBatisDataAccessStrategyUnitTests {
4854

@@ -241,6 +247,36 @@ public void findAll() {
241247
);
242248
}
243249

250+
@Test
251+
public void streamAll() {
252+
253+
String value = "some answer";
254+
255+
Cursor<String> cursor = getCursor(value);
256+
257+
when(session.selectCursor(anyString(), any())).then(answer -> cursor);
258+
259+
Stream<String> streamable = accessStrategy.streamAll(String.class);
260+
261+
verify(session).selectCursor(eq("java.lang.StringMapper.streamAll"), captor.capture());
262+
263+
assertThat(streamable).isNotNull().containsExactly(value);
264+
265+
assertThat(captor.getValue()) //
266+
.isNotNull() //
267+
.extracting( //
268+
MyBatisContext::getInstance, //
269+
MyBatisContext::getId, //
270+
MyBatisContext::getDomainType, //
271+
c -> c.get("key") //
272+
).containsExactly( //
273+
null, //
274+
null, //
275+
String.class, //
276+
null //
277+
);
278+
}
279+
244280
@Test // DATAJDBC-123
245281
public void findAllById() {
246282

@@ -263,6 +299,33 @@ public void findAllById() {
263299
);
264300
}
265301

302+
@Test
303+
public void streamAllByIds() {
304+
305+
String value = "some answer 2";
306+
Cursor<String> cursor = getCursor(value);
307+
308+
when(session.selectCursor(anyString(), any())).then(answer -> cursor);
309+
310+
accessStrategy.streamAllByIds(asList("id1", "id2"), String.class);
311+
312+
verify(session).selectCursor(eq("java.lang.StringMapper.streamAllByIds"), captor.capture());
313+
314+
assertThat(captor.getValue()) //
315+
.isNotNull() //
316+
.extracting( //
317+
MyBatisContext::getInstance, //
318+
MyBatisContext::getId, //
319+
MyBatisContext::getDomainType, //
320+
c -> c.get("key") //
321+
).containsExactly( //
322+
null, //
323+
asList("id1", "id2"), //
324+
String.class, //
325+
null //
326+
);
327+
}
328+
266329
@SuppressWarnings("unchecked")
267330
@Test // DATAJDBC-384
268331
public void findAllByPath() {
@@ -367,6 +430,33 @@ public void findAllSorted() {
367430
);
368431
}
369432

433+
@Test
434+
public void streamAllSorted() {
435+
436+
String value = "some answer 3";
437+
Cursor<String> cursor = getCursor(value);
438+
439+
when(session.selectCursor(anyString(), any())).then(answer -> cursor);
440+
441+
accessStrategy.streamAll(String.class, Sort.by("length"));
442+
443+
verify(session).selectCursor(eq("java.lang.StringMapper.streamAllSorted"), captor.capture());
444+
445+
assertThat(captor.getValue()) //
446+
.isNotNull() //
447+
.extracting( //
448+
MyBatisContext::getInstance, //
449+
MyBatisContext::getId, //
450+
MyBatisContext::getDomainType, //
451+
c -> c.get("sort") //
452+
).containsExactly( //
453+
null, //
454+
null, //
455+
String.class, //
456+
Sort.by("length") //
457+
);
458+
}
459+
370460
@Test // DATAJDBC-101
371461
public void findAllPaged() {
372462

@@ -399,5 +489,36 @@ private static class ChildOne {
399489
ChildTwo two;
400490
}
401491

402-
private static class ChildTwo {}
492+
private static class ChildTwo {
493+
}
494+
495+
private Cursor<String> getCursor(String value) {
496+
return new Cursor<>() {
497+
@Override
498+
public boolean isOpen() {
499+
return false;
500+
}
501+
502+
@Override
503+
public boolean isConsumed() {
504+
return false;
505+
}
506+
507+
@Override
508+
public int getCurrentIndex() {
509+
return 0;
510+
}
511+
512+
@Override
513+
public void close() {
514+
515+
}
516+
517+
@NotNull
518+
@Override
519+
public Iterator<String> iterator() {
520+
return List.of(value).iterator();
521+
}
522+
};
523+
}
403524
}

0 commit comments

Comments
 (0)
Please sign in to comment.