Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 9c91b16

Browse files
authored
Weighted PageRank (#718)
* pr bm
* [WIP] Weighted PageRank
* call weighted from the procedure - also adapt integration tests
* Cleaner way of calling weighted vs unweighted
* Implemented for Huge graphs
* docs for weighted PR
* update description
* mention weightProperty in docs
* configurable default
* add default to docs
* making PageRank and WeightedPageRank similar to make it easier to abstract
* move PageRank stuff into its own package
* abstract ComputeStep creation
* ComputeStep as interface
* abstract on ComputeStepFactory
* no longer need this class
* common base class
* reuse code for Huge PageRank as well
* calculate the weighted degree centrality up front and then reuse it
* updating pr benchmarks!
* weight benchmark
* make pr benchmarks compile + fix link
* exclude negative weights on Weighted PageRank
* don't need to compute this for non weighted PageRank
* better named interface
* add refactoring to be done
* introduce DegreeComputer so we have the same level of abstraction in PageRankVariant
* merge left over
* pr bm
* inline all the things in ComputeStep
* much faster with the class as RelationshipConsumer
* update weighted to use consumer too
* run all
* try with all variations of iterations
* add weighted to LDBC
1 parent 455f008 commit 9c91b16

35 files changed

+1809
-422
lines changed

algo/src/main/java/org/neo4j/graphalgo/PageRankProc.java

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
import org.neo4j.graphalgo.core.utils.TerminationFlag;
2929
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
3030
import org.neo4j.graphalgo.core.write.Exporter;
31-
import org.neo4j.graphalgo.impl.PageRankResult;
31+
import org.neo4j.graphalgo.impl.pagerank.PageRankResult;
3232
import org.neo4j.graphalgo.impl.Algorithm;
33-
import org.neo4j.graphalgo.impl.PageRankAlgorithm;
33+
import org.neo4j.graphalgo.impl.pagerank.PageRankAlgorithm;
3434
import org.neo4j.graphalgo.results.PageRankScore;
3535
import org.neo4j.graphdb.Direction;
3636
import org.neo4j.graphdb.Node;
@@ -58,6 +58,8 @@ public final class PageRankProc {
5858
public static final Integer DEFAULT_ITERATIONS = 20;
5959
public static final String DEFAULT_SCORE_PROPERTY = "pagerank";
6060

61+
public static final String CONFIG_WEIGHT_KEY = "weightProperty";
62+
6163
@Context
6264
public GraphDatabaseAPI api;
6365

@@ -69,7 +71,7 @@ public final class PageRankProc {
6971

7072
@Procedure(value = "algo.pageRank", mode = Mode.WRITE)
7173
@Description("CALL algo.pageRank(label:String, relationship:String, " +
72-
"{iterations:5, dampingFactor:0.85, write: true, writeProperty:'pagerank', concurrency:4}) " +
74+
"{iterations:5, dampingFactor:0.85, weightProperty: null, write: true, writeProperty:'pagerank', concurrency:4}) " +
7375
"YIELD nodes, iterations, loadMillis, computeMillis, writeMillis, dampingFactor, write, writeProperty" +
7476
" - calculates page rank and potentially writes back")
7577
public Stream<PageRankScore.Stats> pageRank(
@@ -79,17 +81,19 @@ public Stream<PageRankScore.Stats> pageRank(
7981

8082
ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
8183

84+
final String weightPropertyKey = configuration.getString(CONFIG_WEIGHT_KEY, null);
85+
8286
PageRankScore.Stats.Builder statsBuilder = new PageRankScore.Stats.Builder();
8387
AllocationTracker tracker = AllocationTracker.create();
84-
final Graph graph = load(label, relationship, tracker, configuration.getGraphImpl(), statsBuilder, configuration);
88+
final Graph graph = load(label, relationship, tracker, configuration.getGraphImpl(), statsBuilder, configuration, weightPropertyKey);
8589

8690
if(graph.nodeCount() == 0) {
8791
graph.release();
8892
return Stream.of(statsBuilder.build());
8993
}
9094

9195
TerminationFlag terminationFlag = TerminationFlag.wrap(transaction);
92-
PageRankResult scores = evaluate(graph, tracker, terminationFlag, configuration, statsBuilder);
96+
PageRankResult scores = evaluate(graph, tracker, terminationFlag, configuration, statsBuilder, weightPropertyKey);
9397

9498
log.info("PageRank: overall memory usage: %s", tracker.getUsageString());
9599

@@ -100,7 +104,7 @@ public Stream<PageRankScore.Stats> pageRank(
100104

101105
@Procedure(value = "algo.pageRank.stream", mode = Mode.READ)
102106
@Description("CALL algo.pageRank.stream(label:String, relationship:String, " +
103-
"{iterations:20, dampingFactor:0.85, concurrency:4}) " +
107+
"{iterations:20, dampingFactor:0.85, weightProperty: null, concurrency:4}) " +
104108
"YIELD node, score - calculates page rank and streams results")
105109
public Stream<PageRankScore> pageRankStream(
106110
@Name(value = "label", defaultValue = "") String label,
@@ -109,17 +113,19 @@ public Stream<PageRankScore> pageRankStream(
109113

110114
ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
111115

116+
final String weightPropertyKey = configuration.getString(CONFIG_WEIGHT_KEY, null);
117+
112118
PageRankScore.Stats.Builder statsBuilder = new PageRankScore.Stats.Builder();
113119
AllocationTracker tracker = AllocationTracker.create();
114-
final Graph graph = load(label, relationship, tracker, configuration.getGraphImpl(), statsBuilder, configuration);
120+
final Graph graph = load(label, relationship, tracker, configuration.getGraphImpl(), statsBuilder, configuration, weightPropertyKey);
115121

116122
if(graph.nodeCount() == 0) {
117123
graph.release();
118124
return Stream.empty();
119125
}
120126

121127
TerminationFlag terminationFlag = TerminationFlag.wrap(transaction);
122-
PageRankResult scores = evaluate(graph, tracker, terminationFlag, configuration, statsBuilder);
128+
PageRankResult scores = evaluate(graph, tracker, terminationFlag, configuration, statsBuilder, weightPropertyKey);
123129

124130
log.info("PageRank: overall memory usage: %s", tracker.getUsageString());
125131

@@ -152,11 +158,13 @@ private Graph load(
152158
String relationship,
153159
AllocationTracker tracker,
154160
Class<? extends GraphFactory> graphFactory,
155-
PageRankScore.Stats.Builder statsBuilder, ProcedureConfiguration configuration) {
161+
PageRankScore.Stats.Builder statsBuilder,
162+
ProcedureConfiguration configuration,
163+
String weightPropertyKey) {
156164
GraphLoader graphLoader = new GraphLoader(api, Pools.DEFAULT)
157165
.init(log, label, relationship, configuration)
158166
.withAllocationTracker(tracker)
159-
.withoutRelationshipWeights();
167+
.withOptionalRelationshipWeightsFromProperty(weightPropertyKey, configuration.getWeightPropertyDefaultValue(0.0));
160168

161169
Direction direction = configuration.getDirection(Direction.OUTGOING);
162170
if (direction == Direction.BOTH) {
@@ -178,7 +186,8 @@ private PageRankResult evaluate(
178186
AllocationTracker tracker,
179187
TerminationFlag terminationFlag,
180188
ProcedureConfiguration configuration,
181-
PageRankScore.Stats.Builder statsBuilder) {
189+
PageRankScore.Stats.Builder statsBuilder,
190+
String weightPropertyKey) {
182191

183192
double dampingFactor = configuration.get(CONFIG_DAMPING, DEFAULT_DAMPING);
184193
int iterations = configuration.getIterations(DEFAULT_ITERATIONS);
@@ -189,14 +198,29 @@ private PageRankResult evaluate(
189198

190199
List<Node> sourceNodes = configuration.get("sourceNodes", new ArrayList<>());
191200
LongStream sourceNodeIds = sourceNodes.stream().mapToLong(Node::getId);
192-
PageRankAlgorithm prAlgo = PageRankAlgorithm.of(
193-
tracker,
194-
graph,
195-
dampingFactor,
196-
sourceNodeIds,
197-
Pools.DEFAULT,
198-
concurrency,
199-
batchSize);
201+
202+
PageRankAlgorithm prAlgo;
203+
if(weightPropertyKey != null) {
204+
prAlgo = PageRankAlgorithm.weightedOf(
205+
tracker,
206+
graph,
207+
dampingFactor,
208+
sourceNodeIds,
209+
Pools.DEFAULT,
210+
concurrency,
211+
batchSize);
212+
} else {
213+
prAlgo = PageRankAlgorithm.of(
214+
tracker,
215+
graph,
216+
dampingFactor,
217+
sourceNodeIds,
218+
Pools.DEFAULT,
219+
concurrency,
220+
batchSize);
221+
}
222+
223+
200224
Algorithm<?> algo = prAlgo
201225
.algorithm()
202226
.withLog(log)

algo/src/main/java/org/neo4j/graphalgo/impl/WeightedDegreeCentrality.java

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
import org.neo4j.graphalgo.api.Graph;
44
import org.neo4j.graphalgo.api.WeightedRelationshipConsumer;
55
import org.neo4j.graphalgo.core.utils.ParallelUtil;
6+
import org.neo4j.graphalgo.core.utils.Pools;
7+
import org.neo4j.graphalgo.impl.pagerank.HugeComputeStep;
68
import org.neo4j.graphdb.Direction;
79

810
import java.util.ArrayList;
11+
import java.util.List;
912
import java.util.concurrent.ExecutorService;
1013
import java.util.concurrent.Future;
1114
import java.util.concurrent.atomic.AtomicInteger;
@@ -19,30 +22,35 @@ public class WeightedDegreeCentrality extends Algorithm<WeightedDegreeCentrality
1922
private final ExecutorService executor;
2023
private final int concurrency;
2124
private volatile AtomicInteger nodeQueue = new AtomicInteger();
22-
private int[] degrees;
25+
private double[] degrees;
2326

2427
public WeightedDegreeCentrality(
2528
Graph graph,
2629
ExecutorService executor,
2730
int concurrency,
2831
Direction direction
2932
) {
33+
if (concurrency <= 0) {
34+
concurrency = Pools.DEFAULT_QUEUE_SIZE;
35+
}
3036

3137
this.graph = graph;
3238
this.executor = executor;
3339
this.concurrency = concurrency;
3440
nodeCount = Math.toIntExact(graph.nodeCount());
3541
this.direction = direction;
36-
degrees = new int[nodeCount];
42+
degrees = new double[nodeCount];
3743
}
3844

3945
public WeightedDegreeCentrality compute() {
4046
nodeQueue.set(0);
41-
final ArrayList<Future<?>> futures = new ArrayList<>();
47+
48+
List<DegreeTask> tasks = new ArrayList<>();
4249
for (int i = 0; i < concurrency; i++) {
43-
futures.add(executor.submit(new DegreeTask()));
50+
tasks.add(new DegreeTask());
4451
}
45-
ParallelUtil.awaitTermination(futures);
52+
ParallelUtil.runWithConcurrency(concurrency, tasks, executor);
53+
4654
return this;
4755
}
4856

@@ -66,11 +74,11 @@ public void run() {
6674
return;
6775
}
6876

69-
int[] weightedDegree = new int[1];
77+
double[] weightedDegree = new double[1];
7078
graph.forEachRelationship(nodeId, direction, (sourceNodeId, targetNodeId, relationId, weight) -> {
71-
double v = graph.weightOf(targetNodeId, sourceNodeId);
72-
System.out.println(sourceNodeId + ", " + targetNodeId + " -> " + weight + ", " + v);
73-
weightedDegree[0] += weight;
79+
if(weight > 0) {
80+
weightedDegree[0] += weight;
81+
}
7482
return true;
7583
});
7684

@@ -80,7 +88,7 @@ public void run() {
8088
}
8189
}
8290

83-
public int[] degrees() {
91+
public double[] degrees() {
8492
return degrees;
8593
}
8694

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
package org.neo4j.graphalgo.impl.pagerank;
2+
3+
import org.neo4j.graphalgo.api.Degrees;
4+
import org.neo4j.graphalgo.api.RelationshipIterator;
5+
import org.neo4j.graphalgo.api.RelationshipWeights;
6+
import org.neo4j.graphdb.Direction;
7+
8+
import java.util.Arrays;
9+
import java.util.stream.IntStream;
10+
11+
import static org.neo4j.graphalgo.core.utils.ArrayUtil.binaryLookup;
12+
13+
public abstract class BaseComputeStep implements ComputeStep {
14+
private static final int S_INIT = 0;
15+
private static final int S_CALC = 1;
16+
private static final int S_SYNC = 2;
17+
18+
private int state;
19+
20+
int[] starts;
21+
private int[] lengths;
22+
private int[] sourceNodeIds;
23+
final RelationshipIterator relationshipIterator;
24+
final Degrees degrees;
25+
26+
private final double alpha;
27+
private final double dampingFactor;
28+
29+
private double[] pageRank;
30+
double[] deltas;
31+
int[][] nextScores;
32+
private int[][] prevScores;
33+
34+
private final int partitionSize;
35+
final int startNode;
36+
final int endNode;
37+
38+
BaseComputeStep(
39+
double dampingFactor,
40+
int[] sourceNodeIds,
41+
RelationshipIterator relationshipIterator,
42+
Degrees degrees,
43+
int partitionSize,
44+
int startNode) {
45+
this.dampingFactor = dampingFactor;
46+
this.alpha = 1.0 - dampingFactor;
47+
this.sourceNodeIds = sourceNodeIds;
48+
this.relationshipIterator = relationshipIterator;
49+
this.degrees = degrees;
50+
this.partitionSize = partitionSize;
51+
this.startNode = startNode;
52+
this.endNode = startNode + partitionSize;
53+
state = S_INIT;
54+
}
55+
56+
public void setStarts(int starts[], int[] lengths) {
57+
this.starts = starts;
58+
this.lengths = lengths;
59+
}
60+
61+
@Override
62+
public void run() {
63+
if (state == S_CALC) {
64+
singleIteration();
65+
state = S_SYNC;
66+
} else if (state == S_SYNC) {
67+
synchronizeScores(combineScores());
68+
state = S_CALC;
69+
} else if (state == S_INIT) {
70+
initialize();
71+
state = S_CALC;
72+
}
73+
}
74+
75+
private void initialize() {
76+
this.nextScores = new int[starts.length][];
77+
Arrays.setAll(nextScores, i -> new int[lengths[i]]);
78+
79+
double[] partitionRank = new double[partitionSize];
80+
81+
if(sourceNodeIds.length == 0) {
82+
Arrays.fill(partitionRank, alpha);
83+
} else {
84+
Arrays.fill(partitionRank,0);
85+
86+
int[] partitionSourceNodeIds = IntStream.of(sourceNodeIds)
87+
.filter(sourceNodeId -> sourceNodeId >= startNode && sourceNodeId < endNode)
88+
.toArray();
89+
90+
for (int sourceNodeId : partitionSourceNodeIds) {
91+
partitionRank[sourceNodeId - this.startNode] = alpha;
92+
}
93+
}
94+
95+
96+
this.pageRank = partitionRank;
97+
this.deltas = Arrays.copyOf(partitionRank, partitionSize);
98+
}
99+
100+
abstract void singleIteration();
101+
102+
public void prepareNextIteration(int[][] prevScores) {
103+
this.prevScores = prevScores;
104+
}
105+
106+
private int[] combineScores() {
107+
assert prevScores != null;
108+
assert prevScores.length >= 1;
109+
int[][] prevScores = this.prevScores;
110+
111+
int length = prevScores.length;
112+
int[] allScores = prevScores[0];
113+
for (int i = 1; i < length; i++) {
114+
int[] scores = prevScores[i];
115+
for (int j = 0; j < scores.length; j++) {
116+
allScores[j] += scores[j];
117+
scores[j] = 0;
118+
}
119+
}
120+
121+
return allScores;
122+
}
123+
124+
private void synchronizeScores(int[] allScores) {
125+
double dampingFactor = this.dampingFactor;
126+
double[] pageRank = this.pageRank;
127+
128+
int length = allScores.length;
129+
for (int i = 0; i < length; i++) {
130+
int sum = allScores[i];
131+
132+
double delta = dampingFactor * (sum / 100_000.0);
133+
pageRank[i] += delta;
134+
deltas[i] = delta;
135+
allScores[i] = 0;
136+
}
137+
}
138+
139+
@Override
140+
public int[][] nextScores() {
141+
return nextScores;
142+
}
143+
144+
@Override
145+
public double[] pageRank() {
146+
return pageRank;
147+
}
148+
149+
@Override
150+
public int[] starts() {
151+
return starts;
152+
}
153+
154+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package org.neo4j.graphalgo.impl.pagerank;
2+
3+
/**
 * A runnable unit of PageRank work that exposes its score buffers, its rank
 * array and its start offsets to the code that drives the iterations.
 */
public interface ComputeStep extends Runnable {

    /**
     * Supplies the start and length arrays for this step.
     *
     * @param startArray  the start offsets, later returned by {@link #starts()}
     * @param lengthArray the lengths belonging to the entries of {@code startArray}
     */
    void setStarts(int[] startArray, int[] lengthArray);

    /**
     * Hands over the score arrays this step should consume on a later run.
     *
     * @param score the score arrays to consume
     */
    void prepareNextIteration(int[][] score);

    /** @return the start offsets of this step */
    int[] starts();

    /** @return the score arrays produced by this step */
    int[][] nextScores();

    /** @return the rank values held by this step */
    double[] pageRank();
}

0 commit comments

Comments
 (0)