Clustering algorithm based on HashingTF LSH method #25
base: develop
Changes from all commits
38027bd
794ba69
f8b432d
f1fb2fa
2eb53a9
00218a9
808a7ea
@@ -0,0 +1,170 @@
package net.sansa_stack.ml.spark.clustering.algorithms

import java.io._
import java.nio.file._
import java.nio.file.attribute.BasicFileAttributes

import org.apache.jena.graph.Triple
import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{ DataFrame, Dataset, Row, SparkSession }
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import org.graphframes.GraphFrame

/*
 * Clustering of RDF entities: groups subjects whose predicate sets are similar,
 * using MinHash locality-sensitive hashing over hashed TF-IDF vectors of the
 * predicate names, with connected components forming the final clusters.
 *
 * @triplesType - List of RDF subjects and their predicates.
 * @return - clusters of similar subjects.
 */

Review comment: Give a bit of a better description, i.e. the name of the clustering algorithm and what it does.

class LocalitySensitiveHashing(spark: SparkSession, nTriplesRDD: RDD[Triple], dir_path: String) extends Serializable {

  def run() = {
    val parsedTriples = getParsedTriples()
    val extractedEntity = getOnlyPredicates(parsedTriples)
    val featuredData_Df: DataFrame = vectoriseText(extractedEntity)
    val (model: MinHashLSHModel, transformedData_Df: DataFrame) = minHashLSH(featuredData_Df)
    val dataset: Dataset[_] = approxSimilarityJoin(model, transformedData_Df)
    clusterFormation(dataset, featuredData_Df)
  }

  def getParsedTriples(): RDD[(String, String, Object)] = {
    // Extracting last part of Triples
    return nTriplesRDD.distinct()
      .map(f => {
        val s = f.getSubject.getLocalName
        val p = f.getPredicate.getLocalName

        if (f.getObject.isURI()) {
          val o = f.getObject.getLocalName
          (s, p, o)
        } else {
          val o = f.getObject.getLiteralValue
          (s, p, o)
        }
      })
  }

  def getOnlyPredicates(parsedTriples: RDD[(String, String, Object)]): RDD[(String, String)] = {
    return parsedTriples.map(f => {
      val key = f._1 + "" // Subject is the key

Review comment: Why are you adding …

      val value = f._2 + "" // Predicates are the values
      (key, value.replace(",", " ").stripSuffix(" "))
    }).reduceByKey(_ + " " + _)
  }
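
A minimal sketch of what this grouping step produces, using hypothetical toy data and assuming an existing SparkSession named spark: each subject ends up paired with a space-separated profile of its predicate local names.

  // Toy (subject, predicate) pairs; in the class these come from getParsedTriples().
  val pairs = spark.sparkContext.parallelize(Seq(
    ("Alice", "knows"), ("Alice", "worksAt"), ("Bob", "knows")))

  // The same reduceByKey pattern as above builds one "predicate profile" per subject:
  //   ("Alice", "knows worksAt"), ("Bob", "knows")
  val profiles = pairs.reduceByKey(_ + " " + _)
  profiles.collect().foreach(println)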

  def removeStopwords(tokenizedData_Df: DataFrame): DataFrame = {
    val remover = new StopWordsRemover().setInputCol("words").setOutputCol("filtered_words")
    val removed_df = remover.transform(tokenizedData_Df)
    return removed_df
  }

  def vectoriseText(entities: RDD[(String, String)]): DataFrame = {
    val entityProfile_Df = spark.createDataFrame(entities).toDF("entities", "attributes")
    val tokenizer = new Tokenizer().setInputCol("attributes").setOutputCol("words")
    val tokenizedData_Df = tokenizer.transform(entityProfile_Df)
    val cleanData_Df = removeStopwords(tokenizedData_Df).distinct
    val cleandfrdd = cleanData_Df.select("filtered_words").distinct.rdd
    val vocab_size = calculateVocabsize(cleandfrdd)
    val hashingTf = new HashingTF().setInputCol("filtered_words")
      .setOutputCol("raw_Features").setNumFeatures(Math.round(0.90 * vocab_size).toInt)
    // MinHashLSH cannot handle vectors without any non-zero entry, so those rows are filtered out.
    val isNonZeroVector = udf({ v: Vector => v.numNonzeros > 0 }, DataTypes.BooleanType)
    val featuredData_hashedDf = hashingTf.transform(cleanData_Df).filter(isNonZeroVector(col("raw_Features")))
    val idf = new IDF().setInputCol("raw_Features").setOutputCol("features")
    val idfModel = idf.fit(featuredData_hashedDf)
    val rescaledHashedData = idfModel.transform(featuredData_hashedDf)
      .filter(isNonZeroVector(col("features")))

    return rescaledHashedData
  }

  def calculateVocabsize(cleandfrdd: RDD[Row]): Int = {
    val vocab = cleandfrdd.map(_.mkString).reduce(_ + ", " + _).split(", ").toSet
    return (vocab.size)
  }
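
The reduce above concatenates the whole vocabulary into one large string before splitting it again. A possible alternative sketch (not part of the patch), assuming the filtered_words column is the first column of each Row and holds a Seq[String], counts distinct tokens directly:

  // Count distinct tokens without building one large concatenated string.
  def vocabSize(cleandfrdd: RDD[Row]): Int =
    cleandfrdd.flatMap(_.getSeq[String](0)).distinct().count().toInt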

  def minHashLSH(featuredData_Df: DataFrame): (MinHashLSHModel, DataFrame) = {
    val mh = new MinHashLSH().setNumHashTables(3).setInputCol("features").setOutputCol("HashedValues")

Review comment: Is this number …

    val model = mh.fit(featuredData_Df)
    val transformedData_Df = model.transform(featuredData_Df)
    return (model, transformedData_Df)
  }

  def approxSimilarityJoin(model: MinHashLSHModel, transformedData_Df: DataFrame): Dataset[_] = {
    val threshold = 0.40

Review comment: How and who defined this threshold, i.e. to be …

    return model.approxSimilarityJoin(transformedData_Df, transformedData_Df, threshold)
  }
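
Both the number of hash tables (3) and the distance threshold (0.40) are hard-coded above, which the review comments question. A minimal sketch of how they could be exposed as parameters with the current values as defaults; the names minHashLSHWith, approxSimilarityJoinWith, numHashTables and distanceThreshold are hypothetical, not part of the patch:

  // Sketch: the same two steps with the hyperparameters passed in rather than hard-coded.
  def minHashLSHWith(featuredData_Df: DataFrame, numHashTables: Int = 3): (MinHashLSHModel, DataFrame) = {
    val mh = new MinHashLSH()
      .setNumHashTables(numHashTables) // more tables lower the chance of missing similar pairs, at extra cost
      .setInputCol("features")
      .setOutputCol("HashedValues")
    val model = mh.fit(featuredData_Df)
    (model, model.transform(featuredData_Df))
  }

  def approxSimilarityJoinWith(model: MinHashLSHModel, transformedData_Df: DataFrame,
                               distanceThreshold: Double = 0.40): Dataset[_] = {
    // distanceThreshold bounds the Jaccard distance of returned pairs (0 = identical profiles).
    model.approxSimilarityJoin(transformedData_Df, transformedData_Df, distanceThreshold)
  }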

  def clusterFormation(dataset: Dataset[_], featuredData_Df: DataFrame) = {
    val featuredData = featuredData_Df.drop("attributes", "words", "filtered_words")

    val refined_entities_dataset = dataset
      .filter("datasetA.entities != datasetB.entities")
      .select(col("datasetA.entities").alias("src"), col("datasetB.entities").alias("dst"))

    import spark.implicits._
    val c_1 = refined_entities_dataset.select("src")
    val c_2 = refined_entities_dataset.select("dst")
    val vertexDF = c_1.union(c_2).distinct().toDF("id")

    val g = GraphFrame(vertexDF, refined_entities_dataset)
    g.persist()
    spark.sparkContext.setCheckpointDir(dir_path)

    // Connected components are the generated clusters.
    val connected_components = g.connectedComponents.run()

    // Removing the GraphFrames checkpoint directory.
    val file_path = Paths.get(dir_path)

Review comment: Is this going to work in the cluster? Or is the dir_path of the checkpoint always considered to be held on the driver? We should use the file system configuration as soon as it needs to be distributed across the cluster.

Review comment: Why would somebody change the checkpoint dir locally? Also, the whole parameter isn't documented in the constructor. In my opinion, that should be configured during Spark setup, or maybe during spark-submit if you really need it, but not in the code.

    removePathFiles(file_path)

    val connected_components_ = connected_components.withColumnRenamed("component", "prediction")
      .withColumnRenamed("id", "entities")
    clusterQuality(connected_components_, featuredData)
  }

  /*
   * Removing the GraphFrames checkpoint directory.
   */
  def removePathFiles(root: Path): Unit = {
    if (Files.exists(root)) {
      Files.walkFileTree(root, new SimpleFileVisitor[Path] {
        override def visitFile(file: Path, attrs: BasicFileAttributes): FileVisitResult = {
          Files.delete(file)
          FileVisitResult.CONTINUE
        }

        override def postVisitDirectory(dir: Path, exc: IOException): FileVisitResult = {
          Files.delete(dir)
          FileVisitResult.CONTINUE
        }
      })
    }
  }
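
The java.nio deletion above only touches the local file system of the JVM that runs it. A hedged alternative sketch, in case the checkpoint directory lives on HDFS or another Hadoop-supported file system, using the Hadoop FileSystem API (the method name removeCheckpointDir is hypothetical):

  // Delete the checkpoint directory through the Hadoop FileSystem API, so the same
  // call works for local paths and for HDFS paths when running on a cluster.
  // The import is aliased to avoid clashing with java.nio.file.Path.
  import org.apache.hadoop.fs.{ Path => HadoopPath }

  def removeCheckpointDir(dirPath: String): Unit = {
    val path = new HadoopPath(dirPath)
    val fs = path.getFileSystem(spark.sparkContext.hadoopConfiguration)
    if (fs.exists(path)) {
      fs.delete(path, true) // recursive delete
    }
  }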

  /*
   * Calculating the Silhouette score, which indicates how good the clusters are.
   * Silhouette values range over [-1, 1]; values closer to 1 indicate better clusters.
   */
  def clusterQuality(connectedComponents: DataFrame, featuredData: DataFrame): Double = {
    val silhouetteInput = connectedComponents.join(featuredData, "entities")
    val evaluator = new ClusteringEvaluator().setPredictionCol("prediction")
      .setFeaturesCol("features").setMetricName("silhouette")
    evaluator.evaluate(silhouetteInput)
  }

}

object LocalitySensitiveHashing {
  def apply(spark: SparkSession, nTriplesRDD: RDD[Triple], dir_path: String): LocalitySensitiveHashing = new LocalitySensitiveHashing(spark, nTriplesRDD, dir_path)
}
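
A minimal usage sketch, assuming a SparkSession and an RDD[Triple] have already been created elsewhere (for example via SANSA's RDF I/O layer); the checkpoint path is only an example and must point to a writable location:

  import org.apache.jena.graph.Triple
  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.SparkSession

  def clusterEntities(spark: SparkSession, triples: RDD[Triple]): Unit = {
    // dir_path is used as the GraphFrames checkpoint directory for connected components.
    val checkpointDir = "/tmp/lsh-checkpoint"
    LocalitySensitiveHashing(spark, triples, checkpointDir).run()
  }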
Review comment: Great! Thanks for adding the graphframes dependency. The project seems to build now. Consider using the latest version, i.e. 0.8.0, and also format it :), i.e. align it with the other dependency lists.
Review comment: You have to add a Maven repo; nobody wants to add local JARs to the project classpath manually nowadays. It is also mentioned on the GraphFrames Maven artifact page. That means we have to add …