fix: validate schema options before initializing schema (#630)
fbiville authored Jun 27, 2024
1 parent 29b77b4 commit 186ed05
Showing 7 changed files with 66 additions and 34 deletions.
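In short: the Validations.validate(ValidateSchemaOptions(...)) call moves out of the per-row write path (Neo4jWriteMappingStrategy.node and .relationship) into SchemaService.createOptimizations, so invalid options now fail once, up front, before any constraint is created; as a consequence the schema threaded through the readers and writers changes from Optional[StructType] to a plain StructType. Below is a minimal, self-contained sketch of the resulting fail-fast ordering; validateSchemaOptions and createOptimizations here are simplified stand-ins for the connector's real methods, and only the call order is taken from this diff:

import org.apache.spark.sql.types.{StringType, StructField, StructType}

object FailFastSketch {
  // Stand-in for ValidateSchemaOptions: every configured node key must be a schema field.
  def validateSchemaOptions(nodeKeys: Map[String, String], schema: StructType): Unit = {
    val missing = nodeKeys.keys.filterNot(schema.fieldNames.contains)
    if (missing.nonEmpty) {
      throw new IllegalArgumentException(
        s"Schema is missing ${missing.mkString(", ")} from option `node.keys`")
    }
  }

  // Stand-in for SchemaService.createOptimizations: validate first, create constraints after.
  def createOptimizations(nodeKeys: Map[String, String], schema: StructType): Unit = {
    validateSchemaOptions(nodeKeys, schema) // throws before any constraint DDL runs
    // ... constraint creation would happen here ...
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("id", StringType), StructField("city", StringType)))
    createOptimizations(Map("newsId" -> "newsId"), schema) // fails fast: newsId is not in the schema
  }
}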
@@ -25,8 +25,6 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
private val dataConverter = SparkToNeo4jDataConverter()

override def node(row: InternalRow, schema: StructType): java.util.Map[String, AnyRef] = {
-    Validations.validate(ValidateSchemaOptions(options, schema))
-
val rowMap: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
val keys: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
val properties: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
@@ -97,8 +95,6 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
override def relationship(row: InternalRow, schema: StructType): java.util.Map[String, AnyRef] = {
val rowMap: java.util.Map[String, AnyRef] = new java.util.HashMap[String, AnyRef]

-    Validations.validate(ValidateSchemaOptions(options, schema))
-
val consumer = options.relationshipMetadata.saveStrategy match {
case RelationshipSaveStrategy.NATIVE => nativeStrategyConsumer()
case RelationshipSaveStrategy.KEYS => keysStrategyConsumer()
common/src/main/scala/org/neo4j/spark/service/SchemaService.scala (21 changes: 10 additions & 11 deletions)
@@ -684,7 +684,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
})
}

-  private def createOptimizationsForNode(struct: Optional[StructType]): Unit = {
+  private def createOptimizationsForNode(struct: StructType): Unit = {
val schemaMetadata = options.schemaMetadata.optimization
if (schemaMetadata.nodeConstraint != ConstraintsOptimizationType.NONE
|| schemaMetadata.schemaConstraints != Set(SchemaConstraintsOptimizationType.NONE)) {
@@ -694,13 +694,12 @@
options.nodeMetadata.nodeKeys)
}
if (schemaMetadata.schemaConstraints.nonEmpty) {
-      val structType: StructType = struct.orElse(emptyStruct)
-      val propFromStruct: Map[String, String] = structType
+      val propFromStruct: Map[String, String] = struct
.map(f => (f.name, f.name))
.toMap
val propsFromMeta: Map[String, String] = options.nodeMetadata.nodeKeys ++ options.nodeMetadata.properties
createEntityTypeConstraint("NODE", options.nodeMetadata.labels.head,
-        propsFromMeta ++ propFromStruct, structType, schemaMetadata.schemaConstraints)
+        propsFromMeta ++ propFromStruct, struct, schemaMetadata.schemaConstraints)
}
} else { // TODO old behaviour, remove it in the future
options.schemaMetadata.optimizationType match {
@@ -714,7 +713,7 @@
}
}

-  private def createOptimizationsForRelationship(struct: Optional[StructType]): Unit = {
+  private def createOptimizationsForRelationship(struct: StructType): Unit = {
val schemaMetadata = options.schemaMetadata.optimization
if (schemaMetadata.nodeConstraint != ConstraintsOptimizationType.NONE
|| schemaMetadata.relConstraint != ConstraintsOptimizationType.NONE
@@ -735,19 +734,18 @@
if (schemaMetadata.schemaConstraints.nonEmpty) {
val sourceNodeProps: Map[String, String] = options.relationshipMetadata.source.nodeKeys ++ options.relationshipMetadata.source.properties
val targetNodeProps: Map[String, String] = options.relationshipMetadata.target.nodeKeys ++ options.relationshipMetadata.target.properties
-      val baseStruct = struct.orElse(emptyStruct)
val allNodeProps: Map[String, String] = sourceNodeProps ++ targetNodeProps
-      val relStruct: StructType = StructType(baseStruct.filterNot(f => allNodeProps.contains(f.name)))
+      val relStruct: StructType = StructType(struct.filterNot(f => allNodeProps.contains(f.name)))
val relFromStruct: Map[String, String] = relStruct
.map(f => (f.name, f.name))
.toMap
val propsFromMeta: Map[String, String] = options.relationshipMetadata.relationshipKeys ++ options.relationshipMetadata.properties
createEntityTypeConstraint("RELATIONSHIP",options.relationshipMetadata.relationshipType,
-        propsFromMeta ++ relFromStruct, baseStruct, schemaMetadata.schemaConstraints)
+        propsFromMeta ++ relFromStruct, struct, schemaMetadata.schemaConstraints)
createEntityTypeConstraint("NODE", options.relationshipMetadata.source.labels.head,
-        sourceNodeProps, baseStruct, schemaMetadata.schemaConstraints)
+        sourceNodeProps, struct, schemaMetadata.schemaConstraints)
createEntityTypeConstraint("NODE", options.relationshipMetadata.target.labels.head,
-        targetNodeProps, baseStruct, schemaMetadata.schemaConstraints)
+        targetNodeProps, struct, schemaMetadata.schemaConstraints)
}
} else { // TODO old behaviour, remove it in the future
options.schemaMetadata.optimizationType match {
@@ -764,7 +762,8 @@
}
}

-  def createOptimizations(struct: Optional[StructType]): Unit = {
+  def createOptimizations(struct: StructType): Unit = {
+    Validations.validate(ValidateSchemaOptions(options, struct))
options.query.queryType match {
case QueryType.LABELS => createOptimizationsForNode(struct)
case QueryType.RELATIONSHIP => createOptimizationsForRelationship(struct)
@@ -58,6 +58,6 @@ class Neo4jScan(neo4jOptions: Neo4jOptions,
optsMap.put(Neo4jOptions.STREAMING_METADATA_STORAGE, StorageType.SPARK.toString)
val newOpts = new Neo4jOptions(optsMap)
Validations.validate(ValidateReadStreaming(newOpts, jobId))
-    new Neo4jMicroBatchReader(Optional.of(schema), newOpts, jobId, aggregateColumns)
+    new Neo4jMicroBatchReader(schema, newOpts, jobId, aggregateColumns)
}
}
@@ -14,7 +14,7 @@ import org.neo4j.spark.util._
import java.lang
import java.util.Optional

-class Neo4jMicroBatchReader(private val optionalSchema: Optional[StructType],
+class Neo4jMicroBatchReader(private val schema: StructType,
private val neo4jOptions: Neo4jOptions,
private val jobId: String,
private val aggregateColumns: Array[AggregateFunc])
@@ -27,7 +27,7 @@ class Neo4jMicroBatchReader(private val optionalSchema: Optional[StructType],

private lazy val scriptResult = {
val schemaService = new SchemaService(neo4jOptions, driverCache)
-    schemaService.createOptimizations(optionalSchema)
+    schemaService.createOptimizations(schema)
val scriptResult = schemaService.execute(neo4jOptions.script)
schemaService.close()
scriptResult
@@ -108,7 +108,7 @@ class Neo4jMicroBatchReader(private val optionalSchema: Optional[StructType],

override def createReaderFactory(): PartitionReaderFactory = {
new Neo4jStreamingPartitionReaderFactory(
-      neo4jOptions, optionalSchema.orElse(new StructType()), jobId, scriptResult, offsetAccumulator, aggregateColumns
+      neo4jOptions, schema, jobId, scriptResult, offsetAccumulator, aggregateColumns
)
}
}
@@ -36,7 +36,7 @@ class Neo4jStreamingWriter(val queryId: String,

private lazy val scriptResult = {
val schemaService = new SchemaService(neo4jOptions, driverCache)
-    schemaService.createOptimizations(Optional.of(schema))
+    schemaService.createOptimizations(schema)
val scriptResult = schemaService.execute(neo4jOptions.script)
schemaService.close()
scriptResult
@@ -14,7 +14,7 @@ class Neo4jBatchWriter(jobId: String,
neo4jOptions: Neo4jOptions) extends BatchWrite{
override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory = {
val schemaService = new SchemaService(neo4jOptions, driverCache)
-    schemaService.createOptimizations(Optional.of(structType))
+    schemaService.createOptimizations(structType)
val scriptResult = schemaService.execute(neo4jOptions.script)
schemaService.close()

@@ -2,19 +2,19 @@ package org.neo4j.spark

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.spark.SparkException
-import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerStageCompleted}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.junit.Assert.{assertEquals, assertTrue, fail}
-import org.junit.{Assume, BeforeClass, Test}
+import org.junit.Test
import org.neo4j.driver.summary.ResultSummary
-import org.neo4j.driver.{Result, SessionConfig, Transaction, TransactionWork}
+import org.neo4j.driver.{Result, Session, SessionConfig, Transaction, TransactionWork}
import org.neo4j.spark.writer.DataWriterMetrics

import java.util.concurrent.{CountDownLatch, TimeUnit}

class DataSourceWriterNeo4jTSE extends SparkConnectorScalaBaseTSE {

-  val sparkSession = SparkSession.builder().getOrCreate()
+  private val sparkSession = SparkSession.builder().getOrCreate()

import sparkSession.implicits._

@@ -496,15 +496,15 @@ class DataSourceWriterNeo4jTSE extends SparkConnectorScalaBaseTSE {
.option("relationship.target.node.keys", "instrument:name")
.save()
} catch {
-      case sparkException: SparkException => {
-        val clientException = ExceptionUtils.getRootCause(sparkException)
+      case exception: IllegalArgumentException => {
+        val clientException = ExceptionUtils.getRootCause(exception)
assertTrue(clientException.getMessage.equals(
"""Write failed due to the following errors:
| - Schema is missing musician_name from option `relationship.source.node.keys`
|
|The option key and value might be inverted.""".stripMargin))
}
-      case generic: Throwable => fail(s"should be thrown a ${classOf[SparkException].getName}, got ${generic.getClass} instead")
+      case generic: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead")
}
}

@@ -533,16 +533,16 @@
.option("relationship.target.node.keys", "instrument_name:name")
.save()
} catch {
-      case sparkException: SparkException => {
-        val clientException = ExceptionUtils.getRootCause(sparkException)
+      case exception: IllegalArgumentException => {
+        val clientException = ExceptionUtils.getRootCause(exception)
assertTrue(clientException.getMessage.equals(
"""Write failed due to the following errors:
| - Schema is missing instrument_name from option `relationship.target.node.keys`
| - Schema is missing musician_name, another_name from option `relationship.source.node.keys`
|
|The option key and value might be inverted.""".stripMargin))
}
-      case generic: Throwable => fail(s"should be thrown a ${classOf[SparkException].getName}, got ${generic.getClass} instead")
+      case generic: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead")
}
}

@@ -565,15 +565,15 @@
.option("node.properties", "musician_name:name,another_name:name")
.save()
} catch {
-      case sparkException: SparkException => {
-        val clientException = ExceptionUtils.getRootCause(sparkException)
+      case exception: IllegalArgumentException => {
+        val clientException = ExceptionUtils.getRootCause(exception)
assertTrue(clientException.getMessage.equals(
"""Write failed due to the following errors:
| - Schema is missing instrument_name from option `node.properties`
|
|The option key and value might be inverted.""".stripMargin))
}
-      case generic: Throwable => fail(s"should be thrown a ${classOf[SparkException].getName}, got ${generic.getClass} instead: ${generic.getMessage}")
+      case generic: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead: ${generic.getMessage}")
}
}

@@ -604,6 +604,43 @@
latch.await(30, TimeUnit.SECONDS)
}

+  @Test
+  def `does not create constraint if schema validation fails`(): Unit = {
+    val cities = Seq(
+      (1, "Cherbourg en Cotentin"),
+      (2, "London"),
+      (3, "Malmö"),
+    ).toDF("id", "city")
+
+    try {
+      cities.write
+        .format(classOf[DataSource].getName)
+        .mode(SaveMode.Overwrite)
+        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
+        .option("labels", ":News")
+        .option("node.keys", "newsId")
+        .option("schema.optimization.node.keys", "UNIQUE")
+        .save()
+    } catch {
+      case _: Exception => {
+      }
+    }
+
+    var session: Session = null
+    try {
+      session = SparkConnectorScalaSuiteIT.driver.session()
+      val result = session.run("SHOW CONSTRAINTS YIELD labelsOrTypes WHERE labelsOrTypes[0] = 'News' RETURN count(*) AS count")
+        .single()
+        .get("count")
+        .asLong()
+      assertEquals(0, result)
+    } finally {
+      if (session != null) {
+        session.close()
+      }
+    }
+  }
+
class MetricsListener(expectedMetrics: Map[String, Any], done: CountDownLatch) extends SparkListener {
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
val actualMetrics = stageCompleted
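From the caller's perspective, the visible effect is when validation fires: a write whose options do not match the DataFrame schema now throws an IllegalArgumentException before any constraint or node is created, instead of surfacing later as a SparkException from the write tasks. A hypothetical reproduction follows; the SparkSession setup, URL, and DataFrame are placeholders, not taken from the commit:

import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq((1, "London"), (2, "Malmö")).toDF("id", "city")

df.write
  .format("org.neo4j.spark.DataSource")
  .mode(SaveMode.Overwrite)
  .option("url", "bolt://localhost:7687") // placeholder
  .option("labels", ":News")
  .option("node.keys", "newsId") // not a column of df: validation fails
  .option("schema.optimization.node.keys", "UNIQUE")
  .save() // throws IllegalArgumentException before any constraint is created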
