From 52297657214be525ff4c179f21ba8219f73fb9ce Mon Sep 17 00:00:00 2001 From: Mark Needham Date: Wed, 10 Apr 2019 11:29:27 +0100 Subject: [PATCH] Degree cutoff skip values (#880) * consider skip values when checking degree cut off * d has no links so it gets filtered out by the degree cut off * d has no links so it gets filtered out by the degree cut off * typo --- .../graphalgo/similarity/SimilarityInput.java | 22 ++++++++++ .../graphalgo/similarity/SimilarityProc.java | 44 +------------------ .../graphalgo/similarity/WeightedInput.java | 30 +++++++++++++ .../similarity/WeightedInputTest.java | 39 ++++++++++++++++ .../graphalgo/core/utils/Intersections.java | 2 +- .../graphalgo/algo/similarity/CosineTest.java | 3 -- .../algo/similarity/EuclideanTest.java | 3 -- 7 files changed, 94 insertions(+), 49 deletions(-) diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityInput.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityInput.java index 9cf9ea942..fcd551c12 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityInput.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityInput.java @@ -64,4 +64,26 @@ static int[] indexesFor(long[] inputIds, ProcedureConfiguration configuration, S } } + + static List extractValues(Object rawValues) { + if (rawValues == null) { + return Collections.emptyList(); + } + + List valueList = new ArrayList<>(); + if (rawValues instanceof long[]) { + long[] values = (long[]) rawValues; + for (long value : values) { + valueList.add(value); + } + } else if (rawValues instanceof double[]) { + double[] values = (double[]) rawValues; + for (double value : values) { + valueList.add(value); + } + } else { + valueList = (List) rawValues; + } + return valueList; + } } diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java index 580bc15f8..119c060dc 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java @@ -134,7 +134,7 @@ CategoricalInput[] prepareCategories(List> data, long degree CategoricalInput[] ids = new CategoricalInput[data.size()]; int idx = 0; for (Map row : data) { - List targetIds = extractValues(row.get("categories")); + List targetIds = SimilarityInput.extractValues(row.get("categories")); int size = targetIds.size(); if (size > degreeCutoff) { long[] targets = new long[size]; @@ -156,7 +156,7 @@ WeightedInput[] prepareWeights(Object rawData, ProcedureConfiguration configurat return prepareSparseWeights(api, (String) rawData, skipValue, configuration); } else { List> data = (List>) rawData; - return preparseDenseWeights(data, getDegreeCutoff(configuration), skipValue); + return WeightedInput.prepareDenseWeights(data, getDegreeCutoff(configuration), skipValue); } } @@ -164,24 +164,6 @@ Double readSkipValue(ProcedureConfiguration configuration) { return configuration.get("skipValue", Double.NaN); } - private WeightedInput[] preparseDenseWeights(List> data, long degreeCutoff, Double skipValue) { - WeightedInput[] inputs = new WeightedInput[data.size()]; - int idx = 0; - for (Map row : data) { - - List weightList = extractValues(row.get("weights")); - - int size = weightList.size(); - if (size > degreeCutoff) { - double[] weights = Weights.buildWeights(weightList); - inputs[idx++] = skipValue == null ? WeightedInput.dense((Long) row.get("item"), weights) : WeightedInput.dense((Long) row.get("item"), weights, skipValue); - } - } - if (idx != inputs.length) inputs = Arrays.copyOf(inputs, idx); - Arrays.sort(inputs); - return inputs; - } - private WeightedInput[] prepareSparseWeights(GraphDatabaseAPI api, String query, Double skipValue, ProcedureConfiguration configuration) throws Exception { Map params = configuration.getParams(); Long degreeCutoff = getDegreeCutoff(configuration); @@ -230,28 +212,6 @@ private WeightedInput[] prepareSparseWeights(GraphDatabaseAPI api, String query, return inputs; } - private List extractValues(Object rawValues) { - if (rawValues == null) { - return Collections.emptyList(); - } - - List valueList = new ArrayList<>(); - if (rawValues instanceof long[]) { - long[] values = (long[]) rawValues; - for (long value : values) { - valueList.add(value); - } - } else if (rawValues instanceof double[]) { - double[] values = (double[]) rawValues; - for (double value : values) { - valueList.add(value); - } - } else { - valueList = (List) rawValues; - } - return valueList; - } - int getTopK(ProcedureConfiguration configuration) { return configuration.getInt("topK", 0); } diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java index b37e23aa9..7badcf4eb 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java @@ -20,6 +20,10 @@ import org.neo4j.graphalgo.core.utils.Intersections; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + class WeightedInput implements Comparable, SimilarityInput { private final long id; private int itemCount; @@ -62,6 +66,32 @@ public static WeightedInput dense(long id, double[] weights) { return new WeightedInput(id, weights); } + static WeightedInput[] prepareDenseWeights(List> data, long degreeCutoff, Double skipValue) { + WeightedInput[] inputs = new WeightedInput[data.size()]; + int idx = 0; + + boolean skipAnything = skipValue != null; + boolean skipNan = skipAnything && Double.isNaN(skipValue); + + for (Map row : data) { + List weightList = SimilarityInput.extractValues(row.get("weights")); + + long weightsSize = skipAnything ? skipSize(skipValue, skipNan, weightList) : weightList.size(); + + if (weightsSize > degreeCutoff) { + double[] weights = Weights.buildWeights(weightList); + inputs[idx++] = skipValue == null ? dense((Long) row.get("item"), weights) : dense((Long) row.get("item"), weights, skipValue); + } + } + if (idx != inputs.length) inputs = Arrays.copyOf(inputs, idx); + Arrays.sort(inputs); + return inputs; + } + + private static long skipSize(Double skipValue, boolean skipNan, List weightList) { + return weightList.stream().filter(value -> !Intersections.shouldSkip(value.doubleValue(), skipValue, skipNan)).count(); + } + public int compareTo(WeightedInput o) { return Long.compare(id, o.id); } diff --git a/algo/src/test/java/org/neo4j/graphalgo/similarity/WeightedInputTest.java b/algo/src/test/java/org/neo4j/graphalgo/similarity/WeightedInputTest.java index 0dbf27165..5b46b8ef4 100644 --- a/algo/src/test/java/org/neo4j/graphalgo/similarity/WeightedInputTest.java +++ b/algo/src/test/java/org/neo4j/graphalgo/similarity/WeightedInputTest.java @@ -19,11 +19,50 @@ package org.neo4j.graphalgo.similarity; import org.junit.Test; +import org.neo4j.helpers.collection.MapUtil; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; import static junit.framework.TestCase.assertEquals; import static junit.framework.TestCase.assertNull; public class WeightedInputTest { + @Test + public void degreeCutoffBasedOnSkipValue() { + List> data = new ArrayList<>(); + data.add(MapUtil.map("item", 1L,"weights", Arrays.asList(2.0, 3.0, 4.0))); + data.add(MapUtil.map("item", 2L,"weights", Arrays.asList(2.0, 3.0, Double.NaN))); + + WeightedInput[] weightedInputs = WeightedInput.prepareDenseWeights(data, 2L, Double.NaN); + + assertEquals(1, weightedInputs.length); + } + + @Test + public void degreeCutoffWithoutSkipValue() { + List> data = new ArrayList<>(); + data.add(MapUtil.map("item", 1L,"weights", Arrays.asList(2.0, 3.0, 4.0))); + data.add(MapUtil.map("item", 2L,"weights", Arrays.asList(2.0, 3.0, Double.NaN))); + + WeightedInput[] weightedInputs = WeightedInput.prepareDenseWeights(data, 2L, null); + + assertEquals(2, weightedInputs.length); + } + + @Test + public void degreeCutoffWithNumericSkipValue() { + List> data = new ArrayList<>(); + data.add(MapUtil.map("item", 1L,"weights", Arrays.asList(2.0, 3.0, 4.0))); + data.add(MapUtil.map("item", 2L,"weights", Arrays.asList(2.0, 3.0, 5.0))); + + WeightedInput[] weightedInputs = WeightedInput.prepareDenseWeights(data, 2L, 5.0); + + assertEquals(1, weightedInputs.length); + } + @Test public void pearsonNoCompression() { double[] weights1 = new double[]{1, 2, 3, 4, 4, 4, 4, 5, 6}; diff --git a/core/src/main/java/org/neo4j/graphalgo/core/utils/Intersections.java b/core/src/main/java/org/neo4j/graphalgo/core/utils/Intersections.java index abf3e1621..a6b3e526d 100644 --- a/core/src/main/java/org/neo4j/graphalgo/core/utils/Intersections.java +++ b/core/src/main/java/org/neo4j/graphalgo/core/utils/Intersections.java @@ -269,7 +269,7 @@ public static double pearsonSkip(double[] vector1, double[] vector2, int len, do return Double.isNaN(result) ? 0 : result; } - private static boolean shouldSkip(double weight, double skipValue, boolean skipNan) { + public static boolean shouldSkip(double weight, double skipValue, boolean skipNan) { return weight == skipValue || (skipNan && Double.isNaN(weight)); } diff --git a/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/CosineTest.java b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/CosineTest.java index 4b658aced..ea9e244fe 100644 --- a/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/CosineTest.java +++ b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/CosineTest.java @@ -238,10 +238,7 @@ public void cosineSkipStreamTest() { assertTrue(results.hasNext()); assert01Skip(results.next()); assert02Skip(results.next()); - assert03Skip(results.next()); assert12Skip(results.next()); - assert13Skip(results.next()); - assert23Skip(results.next()); assertFalse(results.hasNext()); } diff --git a/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/EuclideanTest.java b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/EuclideanTest.java index f46a403fa..d3959d856 100644 --- a/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/EuclideanTest.java +++ b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/EuclideanTest.java @@ -259,10 +259,7 @@ public void eucideanSkipStreamTest() { assertTrue(results.hasNext()); assert01Skip(results.next()); assert02Skip(results.next()); - assert03Skip(results.next()); assert12Skip(results.next()); - assert13Skip(results.next()); - assert23Skip(results.next()); assertFalse(results.hasNext()); }