Skip to content
This repository has been archived by the owner on Apr 22, 2020. It is now read-only.

Commit

Permalink
parameterise similarity test
Browse files Browse the repository at this point in the history
  • Loading branch information
mneedham committed Feb 12, 2019
1 parent 87df54f commit 17d4502
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ public int export(Stream<SimilarityResult> similarityPairs, long batchSize) {
.stream();

int queueSize = dssResult.getSetCount();

if(queueSize == 0) {
return 0;
}

log.info("ParallelSimilarityExporter: Relationships to be created: %d, Partitions found: %d", numberOfRelationships[0], queueSize);

ArrayBlockingQueue<List<SimilarityResult>> outQueue = new ArrayBlockingQueue<>(queueSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class SequentialSimilarityExporter extends StatementApi implements Simila

public SequentialSimilarityExporter(GraphDatabaseAPI api,
Log log, String relationshipType,
String propertyName) {
String propertyName, int nodeCount) {
super(api);
this.log = log;
propertyId = getOrCreatePropertyId(propertyName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Stream<SimilaritySummaryResult> writeAndAggregateResults(Stream<SimilarityResult

} else {
try (ProgressTimer timer = builder.timeWrite()) {
SequentialSimilarityExporter similarityExporter = new SequentialSimilarityExporter(api, log, writeRelationshipType, writeProperty);
SequentialSimilarityExporter similarityExporter = new SequentialSimilarityExporter(api, log, writeRelationshipType, writeProperty, length);
similarityExporter.export(stream.peek(recorder), writeBatchSize);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
package org.neo4j.graphalgo.similarity;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.neo4j.graphdb.Transaction;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.logging.Log;
import org.neo4j.logging.NullLog;
import org.neo4j.test.rule.ImpermanentDatabaseRule;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
Expand All @@ -17,19 +26,63 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;

@RunWith(Parameterized.class)
public class SimilarityExporterTest {
@Rule
public final ImpermanentDatabaseRule DB = new ImpermanentDatabaseRule();

private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
private static final MethodType CTOR_METHOD = MethodType.methodType(
void.class,
GraphDatabaseAPI.class,
Log.class,
String.class,
String.class,
int.class);

private static final String RELATIONSHIP_TYPE = "SIMILAR";
private static final String PROPERTY_NAME = "score";
private SimilarityExporter exporter;
private GraphDatabaseAPI api;
private Class<? extends SimilarityExporter> similarityExporterFactory;

@Parameterized.Parameters(name = "{1}")
public static Collection<Object[]> data() {
return Arrays.asList(
new Object[]{SequentialSimilarityExporter.class, "Sequential"},
new Object[]{ParallelSimilarityExporter.class, "Parallel"}
);
}

@Test
public void createNothing() {
GraphDatabaseAPI api = DB.getGraphDatabaseAPI();
createNodes(api, 2);
@Before
public void setup() {
api = DB.getGraphDatabaseAPI();
}

public SimilarityExporterTest(Class<? extends SimilarityExporter> similarityExporterFactory,
String ignoreParamOnlyForTestNaming) throws Throwable {

this.similarityExporterFactory = similarityExporterFactory;
}

SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME);
public SimilarityExporter load(Class<? extends SimilarityExporter> factoryType, int nodeCount) throws Throwable {
final MethodHandle constructor = findConstructor(factoryType);
return (SimilarityExporter) constructor.invoke(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME, nodeCount);
}

private MethodHandle findConstructor(Class<?> factoryType) {
try {
return LOOKUP.findConstructor(factoryType, CTOR_METHOD);
} catch (NoSuchMethodException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}

@Test
public void createNothing() throws Throwable {
int nodeCount = 2;
createNodes(api, nodeCount);
exporter = load(similarityExporterFactory, nodeCount);

Stream<SimilarityResult> similarityPairs = Stream.empty();

Expand All @@ -43,11 +96,10 @@ public void createNothing() {
}

@Test
public void createOneRelationship() {
GraphDatabaseAPI api = DB.getGraphDatabaseAPI();
createNodes(api, 2);

SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME);
public void createOneRelationship() throws Throwable {
int nodeCount = 2;
createNodes(api, nodeCount);
exporter = load(similarityExporterFactory, nodeCount);

Stream<SimilarityResult> similarityPairs = Stream.of(new SimilarityResult(0, 1, -1, -1, -1, 0.5));

Expand All @@ -62,11 +114,12 @@ public void createOneRelationship() {
}

@Test
public void multipleBatches() {
GraphDatabaseAPI api = DB.getGraphDatabaseAPI();
createNodes(api, 4);
public void multipleBatches() throws Throwable {
int nodeCount = 4;
createNodes(api, nodeCount);
exporter = load(similarityExporterFactory, nodeCount);

SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME);
SimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME, 4);

Stream<SimilarityResult> similarityPairs = Stream.of(
new SimilarityResult(0, 1, -1, -1, -1, 0.5),
Expand All @@ -86,16 +139,16 @@ public void multipleBatches() {
}

@Test
public void smallerThanBatchSize() {
GraphDatabaseAPI api = DB.getGraphDatabaseAPI();
createNodes(api, 5);

SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME);
public void smallerThanBatchSize() throws Throwable {
int nodeCount = 5;
createNodes(api, nodeCount);
exporter = load(similarityExporterFactory, nodeCount);

Stream<SimilarityResult> similarityPairs = Stream.of(
new SimilarityResult(0, 1, -1, -1, -1, 0.5),
new SimilarityResult(1, 2, -1, -1, -1, 0.6),
new SimilarityResult(2, 3, -1, -1, -1, 0.7),
new SimilarityResult(3, 4, -1, -1, -1, 0.7)
new SimilarityResult(3, 4, -1, -1, -1, 0.8)
);

int batches = exporter.export(similarityPairs, 10);
Expand All @@ -104,10 +157,11 @@ public void smallerThanBatchSize() {
try (Transaction tx = api.beginTx()) {
List<SimilarityRelationship> allRelationships = getSimilarityRelationships(api);

assertThat(allRelationships, hasSize(3));
assertThat(allRelationships, hasSize(4));
assertThat(allRelationships, hasItems(new SimilarityRelationship(0, 1, 0.5)));
assertThat(allRelationships, hasItems(new SimilarityRelationship(1, 2, 0.6)));
assertThat(allRelationships, hasItems(new SimilarityRelationship(2, 3, 0.7)));
assertThat(allRelationships, hasItems(new SimilarityRelationship(3, 4, 0.7)));
assertThat(allRelationships, hasItems(new SimilarityRelationship(3, 4, 0.8)));
}
}

Expand Down

0 comments on commit 17d4502

Please sign in to comment.