Fix how resolvers are handled during driver construction (#626)
ali-ince authored Jun 26, 2024
1 parent 18385f4 commit 29b77b4
Showing 11 changed files with 279 additions and 203 deletions.
4 changes: 1 addition & 3 deletions common/src/main/scala/org/neo4j/spark/util/DriverCache.scala
@@ -15,9 +15,7 @@ class DriverCache(private val options: Neo4jDriverOptions, private val jobId: St
   def getOrCreate(): Driver = {
     this.synchronized {
       jobIdCache.add(jobId)
-      cache.computeIfAbsent(options, new function.Function[Neo4jDriverOptions, Driver] {
-        override def apply(t: Neo4jDriverOptions): Driver = GraphDatabase.driver(t.url, t.toNeo4jAuth, t.toDriverConfig)
-      })
+      cache.computeIfAbsent(options, (t: Neo4jDriverOptions) => t.createDriver())
     }
   }

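Worth noting what the one-line replacement buys: the old anonymous Function handed the raw `url` option, which may be a comma-separated list, straight to GraphDatabase.driver, while the new createDriver() (added in Neo4jOptions.scala below) first splits that list into a primary URI plus resolver addresses. The caching shape itself is unchanged; here is a minimal sketch of that shape with stand-in types (Options and FakeDriver are hypothetical, not the connector's classes):

import java.util.concurrent.ConcurrentHashMap

object DriverCacheSketch {
  // Stand-ins for Neo4jDriverOptions and Driver, just to show the pattern.
  final case class Options(url: String)
  final class FakeDriver(val url: String)

  private val cache = new ConcurrentHashMap[Options, FakeDriver]()

  // computeIfAbsent invokes the factory at most once per distinct key, so every
  // job sharing the same connection options reuses a single driver instance.
  def getOrCreate(options: Options): FakeDriver =
    cache.computeIfAbsent(options, (o: Options) => new FakeDriver(o.url))
}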
36 changes: 25 additions & 11 deletions common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala
@@ -15,21 +15,24 @@ import java.util.concurrent.TimeUnit
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 import org.apache.spark.internal.Logging
+import org.jetbrains.annotations.TestOnly


 class Neo4jOptions(private val options: java.util.Map[String, String]) extends Serializable with Logging {

   import Neo4jOptions._
   import QueryType._

   def asMap() = new util.HashMap[String, String](options)
   private def parameters: util.Map[String, String] = {
     val sparkOptions = SparkSession.getActiveSession
-      .map { _.conf
-        .getAll
-        .filterKeys(k => k.startsWith("neo4j."))
-        .map { elem => (elem._1.substring("neo4.".length + 1), elem._2) }
-        .toMap
+      .map {
+        _.conf
+          .getAll
+          .filterKeys(k => k.startsWith("neo4j."))
+          .map { elem => (elem._1.substring("neo4j.".length), elem._2) }
+          .toMap
       }
       .getOrElse(Map.empty)
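For context, the substring change is a readability fix, not a behavior fix: "neo4.".length + 1 and "neo4j.".length both equal 6, so the same prefix was already being stripped. A quick illustration:

// Both forms drop six characters; only the second states the intent directly.
val key = "neo4j.url"
key.substring("neo4.".length + 1) // "url" (5 + 1 = 6 characters dropped)
key.substring("neo4j.".length)    // "url" (6 characters dropped, self-documenting)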

@@ -184,9 +187,9 @@ class Neo4jOptions(private val options: java.util.Map[String, String]) extends S
     val retries = getParameter(TRANSACTION_RETRIES, DEFAULT_TRANSACTION_RETRIES.toString).toInt
     val failOnTransactionCodes = getParameter(TRANSACTION_CODES_FAIL, DEFAULT_EMPTY)
       .split(",")
-        .map(_.trim)
-        .filter(_.nonEmpty)
-        .toSet
+      .map(_.trim)
+      .filter(_.nonEmpty)
+      .toSet
     val batchSize = getParameter(BATCH_SIZE, DEFAULT_BATCH_SIZE.toString).toInt
     val retryTimeout = getParameter(TRANSACTION_RETRY_TIMEOUT, DEFAULT_TRANSACTION_RETRY_TIMEOUT.toString).toInt
     Neo4jTransactionMetadata(retries, failOnTransactionCodes, batchSize, retryTimeout)
@@ -269,16 +272,19 @@ case class Neo4jApocConfig(procedureConfigMap: Map[String, AnyRef])
 case class Neo4jSchemaOptimizations(nodeConstraint: ConstraintsOptimizationType.Value,
                                     relConstraint: ConstraintsOptimizationType.Value,
                                     schemaConstraints: Set[SchemaConstraintsOptimizationType.Value])
+
 case class Neo4jSchemaMetadata(flattenLimit: Int,
                                strategy: SchemaStrategy.Value,
                                optimizationType: OptimizationType.Value,
                                optimization: Neo4jSchemaOptimizations,
                                mapGroupDuplicateKeys: Boolean)
+
 case class Neo4jTransactionMetadata(retries: Int, failOnTransactionCodes: Set[String], batchSize: Int, retryTimeout: Long)
+
 case class Neo4jNodeMetadata(labels: Seq[String], nodeKeys: Map[String, String], properties: Map[String, String]) {
   def includesProperty(name: String): Boolean = nodeKeys.contains(name) || properties.contains(name)
 }

 case class Neo4jRelationshipMetadata(
   source: Neo4jNodeMetadata,
   target: Neo4jNodeMetadata,
@@ -290,6 +296,7 @@ case class Neo4jRelationshipMetadata(
   saveStrategy: RelationshipSaveStrategy.Value,
   relationshipKeys: Map[String, String]
 )
+
 case class Neo4jQueryMetadata(query: String, queryCount: String)

 case class Neo4jGdsMetadata(parameters: util.Map[String, Any])
@@ -329,7 +336,12 @@ case class Neo4jDriverOptions(
   connectionTimeout: Int
 ) extends Serializable {

+  def createDriver(): Driver = {
+    val (url, _) = connectionUrls
+    GraphDatabase.driver(url, toNeo4jAuth, toDriverConfig)
+  }
+
-  def toDriverConfig: Config = {
+  private def toDriverConfig: Config = {
     val builder = Config.builder()
       .withUserAgent(s"neo4j-${Neo4jUtil.connectorEnv}-connector/${Neo4jUtil.connectorVersion}")
       .withLogging(Logging.slf4j())
@@ -370,18 +382,20 @@ case class Neo4jDriverOptions(
   }

-  // public only for testing purposes
+  @TestOnly
   def connectionUrls: (URI, Set[ServerAddress]) = {
     val urls = url.split(",").toList
-    val extraUrls = urls
+    val resolved = urls
       .drop(1)
       .map(_.trim)
       .map(URI.create)
       .map(uri => ServerAddress.of(uri.getHost, if (uri.getPort > -1) uri.getPort else 7687))
       .toSet
-    (URI.create(urls.head.trim), extraUrls)
+    (URI.create(urls.head.trim), resolved)
   }

-  // public only for testing purposes
+  @TestOnly
   def toNeo4jAuth: AuthToken = {
     auth match {
       case "basic" => AuthTokens.basic(username, password)
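This is the heart of the fix: the first entry of the comma-separated `url` option becomes the URI the driver connects with, and the remaining entries are parsed into ServerAddress resolver entries, defaulting to port 7687. The diff does not show how toDriverConfig wires the resolved set in, but a minimal sketch of the driver's resolver hook looks like this (a hand-rolled illustration, not the connector's actual code; the host names are made up, and withResolver only affects routed `neo4j://` connections):

import java.net.URI
import org.neo4j.driver.{AuthTokens, Config, GraphDatabase}
import org.neo4j.driver.net.{ServerAddress, ServerAddressResolver}
import scala.collection.JavaConverters._

object ResolverSketch extends App {
  // Hypothetical multi-host value for the connector's `url` option.
  val urls = "neo4j://seed:7687, bolt://replica-1:7687, bolt://replica-2:7688".split(",").toList

  // Everything after the first entry becomes a resolver address (default port 7687).
  val resolved: java.util.Set[ServerAddress] = urls.tail
    .map(_.trim)
    .map(u => URI.create(u))
    .map(u => ServerAddress.of(u.getHost, if (u.getPort > -1) u.getPort else 7687))
    .toSet.asJava

  // withResolver replaces the driver's initial address lookup with the fixed
  // set parsed above.
  val config = Config.builder()
    .withResolver(new ServerAddressResolver {
      override def resolve(address: ServerAddress): java.util.Set[ServerAddress] = resolved
    })
    .build()

  val driver = GraphDatabase.driver(URI.create(urls.head.trim), AuthTokens.none(), config)
  driver.close()
}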
AuthenticationTest.scala
@@ -2,35 +2,32 @@ package org.neo4j.spark.service

 import org.junit.Test
 import org.junit.runner.RunWith
-import org.mockito.ArgumentCaptor
-import org.neo4j.driver.AuthToken
-import org.neo4j.driver.Config
-import org.neo4j.driver.GraphDatabase
-import org.neo4j.spark.util.DriverCache
-import org.neo4j.spark.util.Neo4jOptions
+import org.mockito.ArgumentMatchers
+import org.mockito.ArgumentMatchers._
+import org.mockito.Mockito.times
+import org.neo4j.driver.{AuthTokens, Config, GraphDatabase}
+import org.neo4j.spark.util.{DriverCache, Neo4jOptions}
 import org.powermock.api.mockito.PowerMockito
 import org.powermock.core.classloader.annotations.PrepareForTest
 import org.powermock.modules.junit4.PowerMockRunner
 import org.testcontainers.shaded.com.google.common.io.BaseEncoding

+import java.net.URI
 import java.util
-import org.junit.Assert.assertEquals
-import org.mockito.ArgumentMatchers._
-import org.mockito.Mockito.times

 @PrepareForTest(Array(classOf[GraphDatabase]))
 @RunWith(classOf[PowerMockRunner])
 class AuthenticationTest {

   @Test
   def testLdapConnectionToken(): Unit = {
+    val token = BaseEncoding.base64.encode("user:password".getBytes)
     val options = new util.HashMap[String, String]
     options.put("url", "bolt://localhost:7687")
     options.put("authentication.type", "custom")
-    options.put("authentication.custom.credentials", BaseEncoding.base64.encode("user:password".getBytes))
+    options.put("authentication.custom.credentials", token)
     options.put("labels", "Person")

-    val argumentCaptor = ArgumentCaptor.forClass(classOf[AuthToken])
     val neo4jOptions = new Neo4jOptions(options)
     val neo4jDriverOptions = neo4jOptions.connection
     val driverCache = new DriverCache(neo4jDriverOptions, "jobId")
@@ -40,19 +37,17 @@ class AuthenticationTest {
     driverCache.getOrCreate()

     PowerMockito.verifyStatic(classOf[GraphDatabase], times(1))
-    GraphDatabase.driver(anyString, argumentCaptor.capture, any(classOf[Config]))
-
-    assertEquals(neo4jDriverOptions.toNeo4jAuth, argumentCaptor.getValue)
+    GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.custom("", token, "", "")), any(classOf[Config]))
   }

   @Test
   def testBearerAuthToken(): Unit = {
+    val token = BaseEncoding.base64.encode("user:password".getBytes)
     val options = new util.HashMap[String, String]
     options.put("url", "bolt://localhost:7687")
     options.put("authentication.type", "bearer")
-    options.put("authentication.bearer.token", BaseEncoding.base64.encode("user:password".getBytes))
+    options.put("authentication.bearer.token", token)

-    val argumentCaptor = ArgumentCaptor.forClass(classOf[AuthToken])
     val neo4jOptions = new Neo4jOptions(options)
     val neo4jDriverOptions = neo4jOptions.connection
     val driverCache = new DriverCache(neo4jDriverOptions, "jobId")
@@ -62,8 +57,6 @@ class AuthenticationTest {
     driverCache.getOrCreate()

     PowerMockito.verifyStatic(classOf[GraphDatabase], times(1))
-    GraphDatabase.driver(anyString, argumentCaptor.capture, any(classOf[Config]))
-
-    assertEquals(neo4jDriverOptions.toNeo4jAuth, argumentCaptor.getValue)
+    GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.bearer(token)), any())
   }
 }
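Replacing the ArgumentCaptor + assertEquals pattern with ArgumentMatchers.eq works because the driver's auth tokens compare by value, which is exactly what these rewritten verifications rely on. A small standalone illustration (the base64 literal is arbitrary):

import org.neo4j.driver.AuthTokens

object AuthTokenEqualitySketch extends App {
  // AuthTokens.custom(principal, credentials, realm, scheme): the connector maps
  // the `authentication.custom.*` options onto these positional arguments.
  val a = AuthTokens.custom("", "dXNlcjpwYXNzd29yZA==", "", "")
  val b = AuthTokens.custom("", "dXNlcjpwYXNzd29yZA==", "", "")

  // Tokens built from identical inputs are equal, so a matcher like
  // ArgumentMatchers.eq(expectedToken) can verify the construction call directly.
  assert(a == b)
  assert(AuthTokens.bearer("dXNlcjpwYXNzd29yZA==") == AuthTokens.bearer("dXNlcjpwYXNzd29yZA=="))
}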
51 changes: 51 additions & 0 deletions common/src/test/scala/org/neo4j/spark/util/Neo4jOptionsIT.scala
@@ -0,0 +1,51 @@
package org.neo4j.spark.util

import org.junit.Assert.{assertEquals, assertNotNull}
import org.junit.{Ignore, Test}
import org.neo4j.spark.SparkConnectorScalaSuiteIT
import org.neo4j.spark.SparkConnectorScalaSuiteIT.server

class Neo4jOptionsIT extends SparkConnectorScalaSuiteIT {

@Test
def shouldConstructDriver(): Unit = {
val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
options.put(Neo4jOptions.URL, server.getBoltUrl)
options.put(Neo4jOptions.AUTH_TYPE, "none")

val neo4jOptions = new Neo4jOptions(options)

use(neo4jOptions.connection.createDriver()) { driver =>
assertNotNull(driver)

use(driver.session()) { session =>
assertEquals(1, session.run("RETURN 1").single().get(0).asInt())
}
}
}

@Test
@Ignore("This requires a fix on driver, ignoring until it is implemented")
def shouldConstructDriverWithResolver(): Unit = {
val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
options.put(Neo4jOptions.URL, s"neo4j://localhost.localdomain:8888, bolt://localhost.localdomain:9999, ${server.getBoltUrl}")
options.put(Neo4jOptions.AUTH_TYPE, "none")

val neo4jOptions = new Neo4jOptions(options)

use(neo4jOptions.connection.createDriver()) { driver =>
assertNotNull(driver)

use(driver.session()) { session =>
assertEquals(1, session.run("RETURN 1").single().get(0).asInt())
}
}
}

  def use[A <: AutoCloseable, B](resource: A)(code: A => B): B =
    try
      code(resource)
    finally
      resource.close()

}
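A side note on the `use` helper above: it is a hand-rolled loan pattern, presumably because the connector still cross-builds for Scala 2.12, where scala.util.Using is unavailable. On 2.13+ the equivalent would be (a sketch mirroring the test body, with the options passed in as a parameter):

import org.neo4j.spark.util.Neo4jOptions
import scala.util.Using

object UsingSketch {
  // Using.resource closes the resource after the body runs, even when it throws.
  def runProbe(neo4jOptions: Neo4jOptions): Int =
    Using.resource(neo4jOptions.connection.createDriver()) { driver =>
      Using.resource(driver.session()) { session =>
        session.run("RETURN 1").single().get(0).asInt()
      }
    }
}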
Neo4jOptionsTest.scala
@@ -195,7 +195,7 @@ class Neo4jOptionsTest {
   @Test
   def testUrls(): Unit = {
     val options: java.util.Map[String, String] = new java.util.HashMap[String, String]()
-    options.put(Neo4jOptions.URL, "neo4j://localhost,neo4j://foo.bar,neo4j://foo.bar.baz:7783")
+    options.put(Neo4jOptions.URL, "neo4j://localhost, neo4j://foo.bar:7687, neo4j://foo.bar.baz:7783")

     val neo4jOptions: Neo4jOptions = new Neo4jOptions(options)
     val (baseUrl, resolvers) = neo4jOptions.connection.connectionUrls
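Given connectionUrls' trim-and-drop logic, the updated option value should split as follows (the test's assertions themselves are collapsed in this view, so this is an inferred expectation, not quoted code):

// baseUrl   -> URI("neo4j://localhost")
// resolvers -> ServerAddress.of("foo.bar", 7687), ServerAddress.of("foo.bar.baz", 7783)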
4 changes: 2 additions & 2 deletions pom.xml
@@ -89,7 +89,7 @@
     <profile>
       <id>neo4j-4.4</id>
       <properties>
-        <neo4j.version>4.4.19</neo4j.version>
+        <neo4j.version>4.4.34</neo4j.version>
         <neo4j.experimental>false</neo4j.experimental>
       </properties>
     </profile>
@@ -99,7 +99,7 @@
         <activeByDefault>true</activeByDefault>
       </activation>
       <properties>
-        <neo4j.version>5.13.0</neo4j.version>
+        <neo4j.version>5.20.0</neo4j.version>
         <neo4j.experimental>false</neo4j.experimental>
       </properties>
     </profile>
DataSourceSchemaWriterTSE.scala
@@ -9,11 +9,12 @@ import org.neo4j.spark.util.{ConstraintsOptimizationType, Neo4jOptions, SchemaCo
 import java.sql.{Date, Timestamp}
 import java.time.{LocalDate, LocalDateTime}
 import scala.collection.JavaConverters.{iterableAsScalaIterableConverter, mapAsScalaMapConverter}
+import scala.math.Ordering.Implicits.infixOrderingOps

 object DataSourceSchemaWriterTSE {
   @BeforeClass
   def checkNeo4jVersion() {
-    Assume.assumeTrue(TestUtil.neo4jVersionAsDouble() >= 5.13)
+    Assume.assumeTrue(TestUtil.neo4jVersion() >= Versions.NEO4J_5_13)
   }
 }
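Worth spelling out why neo4jVersionAsDouble had to go: floating-point comparison orders 5.9 above 5.13, which is backwards for versions. A component-wise Ordering, made infix-friendly by the newly imported infixOrderingOps, gets this right. A minimal sketch (Version here is a hypothetical stand-in for the connector's Versions helper):

import scala.math.Ordering.Implicits.infixOrderingOps

object VersionOrderingSketch extends App {
  // Hypothetical version type with a component-wise Ordering.
  final case class Version(major: Int, minor: Int)
  object Version {
    implicit val ordering: Ordering[Version] = Ordering.by(v => (v.major, v.minor))
  }

  assert(5.9 >= 5.13)                        // true for doubles, wrong for versions
  assert(!(Version(5, 9) >= Version(5, 13))) // component-wise comparison is correct
}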

@@ -47,10 +48,10 @@ class DataSourceSchemaWriterTSE extends SparkConnectorScalaBaseTSE {
       .option(Neo4jOptions.SCHEMA_OPTIMIZATION, schemaOptimization)
       .save()
     val count: Long = SparkConnectorScalaSuiteIT.session().run(
-      """
-        |MATCH (n:NodeWithSchema)
-        |RETURN count(n)
-        |""".stripMargin)
+        """
+          |MATCH (n:NodeWithSchema)
+          |RETURN count(n)
+          |""".stripMargin)
       .single()
       .get(0)
       .asLong()
@@ -137,7 +138,7 @@ class DataSourceSchemaWriterTSE extends SparkConnectorScalaBaseTSE {
Map("name" -> "spark_NODE-TYPE-CONSTRAINT-NodeWithSchema-stringArray", "type" -> "NODE_PROPERTY_TYPE", "entityType" -> "NODE", "labelsOrTypes" -> Seq("NodeWithSchema"), "properties" -> Seq("stringArray"), "propertyType" -> "LIST<STRING NOT NULL>"),
Map("name" -> "spark_NODE-TYPE-CONSTRAINT-NodeWithSchema-timestamp", "type" -> "NODE_PROPERTY_TYPE", "entityType" -> "NODE", "labelsOrTypes" -> Seq("NodeWithSchema"), "properties" -> Seq("timestamp"), "propertyType" -> "LOCAL DATETIME"),
Map("name" -> "spark_NODE-TYPE-CONSTRAINT-NodeWithSchema-timestampArray", "type" -> "NODE_PROPERTY_TYPE", "entityType" -> "NODE", "labelsOrTypes" -> Seq("NodeWithSchema"), "properties" -> Seq("timestampArray"), "propertyType" -> "LIST<LOCAL DATETIME NOT NULL>"),
Map("name" -> "spark_NODE_KEY-CONSTRAINT_NodeWithSchema_int-string", "propertyType" -> null, "properties" -> Seq("int", "string"), "labelsOrTypes" -> Seq("NodeWithSchema"), "entityType" -> "NODE", "type" -> "NODE_KEY")
Map("name" -> "spark_NODE_KEY-CONSTRAINT_NodeWithSchema_int-string", "propertyType" -> null, "properties" -> Seq("int", "string"), "labelsOrTypes" -> Seq("NodeWithSchema"), "entityType" -> "NODE", "type" -> "NODE_KEY")
)

val keys = Seq("name", "type", "entityType", "labelsOrTypes", "properties", "propertyType")