28
28
import org .neo4j .graphalgo .core .utils .TerminationFlag ;
29
29
import org .neo4j .graphalgo .core .utils .paged .AllocationTracker ;
30
30
import org .neo4j .graphalgo .core .write .Exporter ;
31
- import org .neo4j .graphalgo .impl .PageRankResult ;
31
+ import org .neo4j .graphalgo .impl .pagerank . PageRankResult ;
32
32
import org .neo4j .graphalgo .impl .Algorithm ;
33
- import org .neo4j .graphalgo .impl .PageRankAlgorithm ;
33
+ import org .neo4j .graphalgo .impl .pagerank . PageRankAlgorithm ;
34
34
import org .neo4j .graphalgo .results .PageRankScore ;
35
35
import org .neo4j .graphdb .Direction ;
36
36
import org .neo4j .graphdb .Node ;
@@ -58,6 +58,8 @@ public final class PageRankProc {
58
58
public static final Integer DEFAULT_ITERATIONS = 20 ;
59
59
public static final String DEFAULT_SCORE_PROPERTY = "pagerank" ;
60
60
61
+ public static final String CONFIG_WEIGHT_KEY = "weightProperty" ;
62
+
61
63
@ Context
62
64
public GraphDatabaseAPI api ;
63
65
@@ -69,7 +71,7 @@ public final class PageRankProc {
69
71
70
72
@ Procedure (value = "algo.pageRank" , mode = Mode .WRITE )
71
73
@ Description ("CALL algo.pageRank(label:String, relationship:String, " +
72
- "{iterations:5, dampingFactor:0.85, write: true, writeProperty:'pagerank', concurrency:4}) " +
74
+ "{iterations:5, dampingFactor:0.85, weightProperty: null, write: true, writeProperty:'pagerank', concurrency:4}) " +
73
75
"YIELD nodes, iterations, loadMillis, computeMillis, writeMillis, dampingFactor, write, writeProperty" +
74
76
" - calculates page rank and potentially writes back" )
75
77
public Stream <PageRankScore .Stats > pageRank (
@@ -79,17 +81,19 @@ public Stream<PageRankScore.Stats> pageRank(
79
81
80
82
ProcedureConfiguration configuration = ProcedureConfiguration .create (config );
81
83
84
+ final String weightPropertyKey = configuration .getString (CONFIG_WEIGHT_KEY , null );
85
+
82
86
PageRankScore .Stats .Builder statsBuilder = new PageRankScore .Stats .Builder ();
83
87
AllocationTracker tracker = AllocationTracker .create ();
84
- final Graph graph = load (label , relationship , tracker , configuration .getGraphImpl (), statsBuilder , configuration );
88
+ final Graph graph = load (label , relationship , tracker , configuration .getGraphImpl (), statsBuilder , configuration , weightPropertyKey );
85
89
86
90
if (graph .nodeCount () == 0 ) {
87
91
graph .release ();
88
92
return Stream .of (statsBuilder .build ());
89
93
}
90
94
91
95
TerminationFlag terminationFlag = TerminationFlag .wrap (transaction );
92
- PageRankResult scores = evaluate (graph , tracker , terminationFlag , configuration , statsBuilder );
96
+ PageRankResult scores = evaluate (graph , tracker , terminationFlag , configuration , statsBuilder , weightPropertyKey );
93
97
94
98
log .info ("PageRank: overall memory usage: %s" , tracker .getUsageString ());
95
99
@@ -100,7 +104,7 @@ public Stream<PageRankScore.Stats> pageRank(
100
104
101
105
@ Procedure (value = "algo.pageRank.stream" , mode = Mode .READ )
102
106
@ Description ("CALL algo.pageRank.stream(label:String, relationship:String, " +
103
- "{iterations:20, dampingFactor:0.85, concurrency:4}) " +
107
+ "{iterations:20, dampingFactor:0.85, weightProperty: null, concurrency:4}) " +
104
108
"YIELD node, score - calculates page rank and streams results" )
105
109
public Stream <PageRankScore > pageRankStream (
106
110
@ Name (value = "label" , defaultValue = "" ) String label ,
@@ -109,17 +113,19 @@ public Stream<PageRankScore> pageRankStream(
109
113
110
114
ProcedureConfiguration configuration = ProcedureConfiguration .create (config );
111
115
116
+ final String weightPropertyKey = configuration .getString (CONFIG_WEIGHT_KEY , null );
117
+
112
118
PageRankScore .Stats .Builder statsBuilder = new PageRankScore .Stats .Builder ();
113
119
AllocationTracker tracker = AllocationTracker .create ();
114
- final Graph graph = load (label , relationship , tracker , configuration .getGraphImpl (), statsBuilder , configuration );
120
+ final Graph graph = load (label , relationship , tracker , configuration .getGraphImpl (), statsBuilder , configuration , weightPropertyKey );
115
121
116
122
if (graph .nodeCount () == 0 ) {
117
123
graph .release ();
118
124
return Stream .empty ();
119
125
}
120
126
121
127
TerminationFlag terminationFlag = TerminationFlag .wrap (transaction );
122
- PageRankResult scores = evaluate (graph , tracker , terminationFlag , configuration , statsBuilder );
128
+ PageRankResult scores = evaluate (graph , tracker , terminationFlag , configuration , statsBuilder , weightPropertyKey );
123
129
124
130
log .info ("PageRank: overall memory usage: %s" , tracker .getUsageString ());
125
131
@@ -152,11 +158,13 @@ private Graph load(
152
158
String relationship ,
153
159
AllocationTracker tracker ,
154
160
Class <? extends GraphFactory > graphFactory ,
155
- PageRankScore .Stats .Builder statsBuilder , ProcedureConfiguration configuration ) {
161
+ PageRankScore .Stats .Builder statsBuilder ,
162
+ ProcedureConfiguration configuration ,
163
+ String weightPropertyKey ) {
156
164
GraphLoader graphLoader = new GraphLoader (api , Pools .DEFAULT )
157
165
.init (log , label , relationship , configuration )
158
166
.withAllocationTracker (tracker )
159
- .withoutRelationshipWeights ( );
167
+ .withOptionalRelationshipWeightsFromProperty ( weightPropertyKey , configuration . getWeightPropertyDefaultValue ( 0.0 ) );
160
168
161
169
Direction direction = configuration .getDirection (Direction .OUTGOING );
162
170
if (direction == Direction .BOTH ) {
@@ -178,7 +186,8 @@ private PageRankResult evaluate(
178
186
AllocationTracker tracker ,
179
187
TerminationFlag terminationFlag ,
180
188
ProcedureConfiguration configuration ,
181
- PageRankScore .Stats .Builder statsBuilder ) {
189
+ PageRankScore .Stats .Builder statsBuilder ,
190
+ String weightPropertyKey ) {
182
191
183
192
double dampingFactor = configuration .get (CONFIG_DAMPING , DEFAULT_DAMPING );
184
193
int iterations = configuration .getIterations (DEFAULT_ITERATIONS );
@@ -189,14 +198,29 @@ private PageRankResult evaluate(
189
198
190
199
List <Node > sourceNodes = configuration .get ("sourceNodes" , new ArrayList <>());
191
200
LongStream sourceNodeIds = sourceNodes .stream ().mapToLong (Node ::getId );
192
- PageRankAlgorithm prAlgo = PageRankAlgorithm .of (
193
- tracker ,
194
- graph ,
195
- dampingFactor ,
196
- sourceNodeIds ,
197
- Pools .DEFAULT ,
198
- concurrency ,
199
- batchSize );
201
+
202
+ PageRankAlgorithm prAlgo ;
203
+ if (weightPropertyKey != null ) {
204
+ prAlgo = PageRankAlgorithm .weightedOf (
205
+ tracker ,
206
+ graph ,
207
+ dampingFactor ,
208
+ sourceNodeIds ,
209
+ Pools .DEFAULT ,
210
+ concurrency ,
211
+ batchSize );
212
+ } else {
213
+ prAlgo = PageRankAlgorithm .of (
214
+ tracker ,
215
+ graph ,
216
+ dampingFactor ,
217
+ sourceNodeIds ,
218
+ Pools .DEFAULT ,
219
+ concurrency ,
220
+ batchSize );
221
+ }
222
+
223
+
200
224
Algorithm <?> algo = prAlgo
201
225
.algorithm ()
202
226
.withLog (log )
0 commit comments