Skip to content

Commit

Permalink
fix: enable #getSlice on SearchStreams (resolves gh-471)
Browse files Browse the repository at this point in the history
  • Loading branch information
bsbodden committed Jun 18, 2024
1 parent 48d35ce commit 1f2591a
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

import static com.redis.om.spring.metamodel.MetamodelUtils.getMetamodelForIdField;
import static com.redis.om.spring.util.ObjectUtils.floatArrayToByteArray;
import static com.redis.om.spring.util.ObjectUtils.pageFromSlice;
import static java.util.stream.Collectors.toCollection;

public class SearchStreamImpl<E> implements SearchStream<E> {
Expand Down Expand Up @@ -556,6 +557,10 @@ private Stream<E> resolveStream() {
return resolvedStream;
}

private boolean isStreamResolved() {
return resolvedStream != null;
}

@Override
public Class<E> getEntityClass() {
return entityClass;
Expand Down Expand Up @@ -695,13 +700,24 @@ public SearchOperations<String> getSearchOperations() {

@Override
public Slice<E> getSlice(Pageable pageable) {
resolvedStream = Stream.empty();
if (pageable.getClass().isAssignableFrom(AggregationPageable.class)) {
resolvedStream = Stream.empty();
AggregationPageable ap = (AggregationPageable) pageable;
AggregationResult ar = search.cursorRead(ap.getCursorId(), pageable.getPageSize());
return new AggregationPage<>(ar, pageable, entityClass, getGson(), mappingConverter, isDocument);
} else {
return Page.empty(pageable);
if (!isStreamResolved()) {
this.sorted(pageable.getSort()).limit(pageable.getPageSize()).skip(Math.toIntExact(pageable.getOffset()));
// issue a count query to answer the hasNext? question for the slice/page
Query countQuery = (rootNode.toString().isBlank()) ? new Query() : new Query(rootNode.toString());
countQuery.limit(Math.toIntExact(pageable.getOffset() + pageable.getPageSize()), pageable.getPageSize());
SearchResult searchResult = search.search(countQuery);

return new SliceImpl<>(this.resolveStream().toList(), pageable, !searchResult.getDocuments().isEmpty());
} else {
return new SliceImpl<E>(List.of());
}

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Slice;
import com.redis.om.spring.repository.query.Sort;
import org.springframework.data.geo.Distance;
import org.springframework.data.geo.Metrics;
import org.springframework.data.geo.Point;
Expand Down Expand Up @@ -2724,4 +2727,66 @@ void testOrElseAndUpdate() {

repository.delete(updatedCompany);
}

@Test
void testManualPagination() {
int PAGE_SIZE = 2;
int page = 0;

// get first page
List<Company> page0 = entityStream.of(Company.class) //
.sorted(Company$.NAME) //
.skip(page * PAGE_SIZE) //
.limit(PAGE_SIZE) //
.collect(Collectors.toList());

assertThat(page0).hasSize(PAGE_SIZE);

List<String> names0 = page0.stream().map(Company::getName).collect(Collectors.toList());
assertThat(names0).containsExactly("Microsoft", "RedisInc");

// get second page
page = 1;
List<Company> page1 = entityStream.of(Company.class) //
.sorted(Company$.NAME) //
.skip(page * PAGE_SIZE) //
.limit(PAGE_SIZE) //
.collect(Collectors.toList());

assertThat(page1).hasSize(1);

List<String> names1 = page1.stream().map(Company::getName).collect(Collectors.toList());
assertThat(names1).containsExactly("Tesla");
}

@Test
void testPageablePagination() {
int PAGE_SIZE = 2;
int page = 0;

var page0Request = PageRequest.of(page, PAGE_SIZE, Sort.by(Company$.NAME));

// get first page
Slice<Company> page0 = entityStream.of(Company.class) //
.getSlice(page0Request);


assertThat(page0).hasSize(PAGE_SIZE);
assertThat(page0.hasNext()).isTrue();

List<String> names0 = page0.stream().map(Company::getName).collect(Collectors.toList());
assertThat(names0).containsExactly("Microsoft", "RedisInc");

// get second page
page = 1;
var page1Request = PageRequest.of(page, PAGE_SIZE, Sort.by(Company$.NAME));
Slice<Company> page1 = entityStream.of(Company.class) //
.getSlice(page1Request);

assertThat(page1).hasSize(1);
assertThat(page1.hasNext()).isFalse();

List<String> names1 = page1.stream().map(Company::getName).collect(Collectors.toList());
assertThat(names1).containsExactly("Tesla");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
import com.redis.om.spring.fixtures.hash.model.Company$;
import com.redis.om.spring.fixtures.hash.repository.ASimpleHashRepository;
import com.redis.om.spring.fixtures.hash.repository.CompanyRepository;
import com.redis.om.spring.repository.query.Sort;
import com.redis.om.spring.tuple.Fields;
import com.redis.om.spring.tuple.Pair;
import com.redis.om.spring.tuple.Triple;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Slice;
import org.springframework.data.geo.Distance;
import org.springframework.data.geo.Metrics;
import org.springframework.data.geo.Point;
Expand Down Expand Up @@ -1731,4 +1734,66 @@ void testContainingPredicateOnFreeFormTextMiddleIncompleteWords() {

aSimpleHashRepository.delete(hash);
}

@Test
void testManualPagination() {
int PAGE_SIZE = 2;
int page = 0;

// get first page
List<Company> page0 = entityStream.of(Company.class) //
.sorted(Company$.NAME) //
.skip(page * PAGE_SIZE) //
.limit(PAGE_SIZE) //
.collect(Collectors.toList());

assertThat(page0).hasSize(PAGE_SIZE);

List<String> names0 = page0.stream().map(Company::getName).collect(Collectors.toList());
assertThat(names0).containsExactly("Microsoft", "RedisInc");

// get second page
page = 1;
List<Company> page1 = entityStream.of(Company.class) //
.sorted(Company$.NAME) //
.skip(page * PAGE_SIZE) //
.limit(PAGE_SIZE) //
.collect(Collectors.toList());

assertThat(page1).hasSize(1);

List<String> names1 = page1.stream().map(Company::getName).collect(Collectors.toList());
assertThat(names1).containsExactly("Tesla");
}

@Test
void testPageablePagination() {
int PAGE_SIZE = 2;
int page = 0;

var page0Request = PageRequest.of(page, PAGE_SIZE, Sort.by(Company$.NAME));

// get first page
Slice<Company> page0 = entityStream.of(Company.class) //
.getSlice(page0Request);


assertThat(page0).hasSize(PAGE_SIZE);
assertThat(page0.hasNext()).isTrue();

List<String> names0 = page0.stream().map(Company::getName).collect(Collectors.toList());
assertThat(names0).containsExactly("Microsoft", "RedisInc");

// get second page
page = 1;
var page1Request = PageRequest.of(page, PAGE_SIZE, Sort.by(Company$.NAME));
Slice<Company> page1 = entityStream.of(Company.class) //
.getSlice(page1Request);

assertThat(page1).hasSize(1);
assertThat(page1.hasNext()).isFalse();

List<String> names1 = page1.stream().map(Company::getName).collect(Collectors.toList());
assertThat(names1).containsExactly("Tesla");
}
}

0 comments on commit 1f2591a

Please sign in to comment.