diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/ParallelSimilarityExporter.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/ParallelSimilarityExporter.java index c74bb8d1b..78313d482 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/ParallelSimilarityExporter.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/ParallelSimilarityExporter.java @@ -90,6 +90,11 @@ public int export(Stream similarityPairs, long batchSize) { .stream(); int queueSize = dssResult.getSetCount(); + + if(queueSize == 0) { + return 0; + } + log.info("ParallelSimilarityExporter: Relationships to be created: %d, Partitions found: %d", numberOfRelationships[0], queueSize); ArrayBlockingQueue> outQueue = new ArrayBlockingQueue<>(queueSize); diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/SequentialSimilarityExporter.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SequentialSimilarityExporter.java index b753a65a7..787ba39c8 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/SequentialSimilarityExporter.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SequentialSimilarityExporter.java @@ -42,7 +42,7 @@ public class SequentialSimilarityExporter extends StatementApi implements Simila public SequentialSimilarityExporter(GraphDatabaseAPI api, Log log, String relationshipType, - String propertyName) { + String propertyName, int nodeCount) { super(api); this.log = log; propertyId = getOrCreatePropertyId(propertyName); diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java index d6ba6a87c..b48dc42c8 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java @@ -161,7 +161,7 @@ Stream writeAndAggregateResults(Stream similarityExporterFactory; + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList( + new Object[]{SequentialSimilarityExporter.class, "Sequential"}, + new Object[]{ParallelSimilarityExporter.class, "Parallel"} + ); + } - @Test - public void createNothing() { - GraphDatabaseAPI api = DB.getGraphDatabaseAPI(); - createNodes(api, 2); + @Before + public void setup() { + api = DB.getGraphDatabaseAPI(); + } + + public SimilarityExporterTest(Class similarityExporterFactory, + String ignoreParamOnlyForTestNaming) throws Throwable { + + this.similarityExporterFactory = similarityExporterFactory; + } - SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME); + public SimilarityExporter load(Class factoryType, int nodeCount) throws Throwable { + final MethodHandle constructor = findConstructor(factoryType); + return (SimilarityExporter) constructor.invoke(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME, nodeCount); + } + + private MethodHandle findConstructor(Class factoryType) { + try { + return LOOKUP.findConstructor(factoryType, CTOR_METHOD); + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + @Test + public void createNothing() throws Throwable { + int nodeCount = 2; + createNodes(api, nodeCount); + exporter = load(similarityExporterFactory, nodeCount); Stream similarityPairs = Stream.empty(); @@ -43,11 +96,10 @@ public void createNothing() { } @Test - public void createOneRelationship() { - GraphDatabaseAPI api = DB.getGraphDatabaseAPI(); - createNodes(api, 2); - - SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME); + public void createOneRelationship() throws Throwable { + int nodeCount = 2; + createNodes(api, nodeCount); + exporter = load(similarityExporterFactory, nodeCount); Stream similarityPairs = Stream.of(new SimilarityResult(0, 1, -1, -1, -1, 0.5)); @@ -62,11 +114,12 @@ public void createOneRelationship() { } @Test - public void multipleBatches() { - GraphDatabaseAPI api = DB.getGraphDatabaseAPI(); - createNodes(api, 4); + public void multipleBatches() throws Throwable { + int nodeCount = 4; + createNodes(api, nodeCount); + exporter = load(similarityExporterFactory, nodeCount); - SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME); + SimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME, 4); Stream similarityPairs = Stream.of( new SimilarityResult(0, 1, -1, -1, -1, 0.5), @@ -86,16 +139,16 @@ public void multipleBatches() { } @Test - public void smallerThanBatchSize() { - GraphDatabaseAPI api = DB.getGraphDatabaseAPI(); - createNodes(api, 5); - - SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME); + public void smallerThanBatchSize() throws Throwable { + int nodeCount = 5; + createNodes(api, nodeCount); + exporter = load(similarityExporterFactory, nodeCount); Stream similarityPairs = Stream.of( new SimilarityResult(0, 1, -1, -1, -1, 0.5), + new SimilarityResult(1, 2, -1, -1, -1, 0.6), new SimilarityResult(2, 3, -1, -1, -1, 0.7), - new SimilarityResult(3, 4, -1, -1, -1, 0.7) + new SimilarityResult(3, 4, -1, -1, -1, 0.8) ); int batches = exporter.export(similarityPairs, 10); @@ -104,10 +157,11 @@ public void smallerThanBatchSize() { try (Transaction tx = api.beginTx()) { List allRelationships = getSimilarityRelationships(api); - assertThat(allRelationships, hasSize(3)); + assertThat(allRelationships, hasSize(4)); assertThat(allRelationships, hasItems(new SimilarityRelationship(0, 1, 0.5))); + assertThat(allRelationships, hasItems(new SimilarityRelationship(1, 2, 0.6))); assertThat(allRelationships, hasItems(new SimilarityRelationship(2, 3, 0.7))); - assertThat(allRelationships, hasItems(new SimilarityRelationship(3, 4, 0.7))); + assertThat(allRelationships, hasItems(new SimilarityRelationship(3, 4, 0.8))); } }