fix: validate schema options before initializing schema #630

Merged: 4 commits, Jun 27, 2024
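
For orientation, a minimal sketch of the behaviour this PR changes (assuming a local Neo4j instance at bolt://localhost:7687; the DataFrame and option values mirror the new test added below). Because the schema options are now validated inside SchemaService.createOptimizations, a write whose node.keys references a column that the DataFrame schema does not contain fails with an IllegalArgumentException before any constraint is created, instead of surfacing later as a SparkException during the write itself.

import org.apache.spark.sql.{SaveMode, SparkSession}

object SchemaValidationFailFast {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // The DataFrame only exposes the columns `id` and `city` ...
    val cities = Seq((1, "Cherbourg en Cotentin"), (2, "London"), (3, "Malmö")).toDF("id", "city")

    // ... but `node.keys` references `newsId`, which is not part of the schema.
    // With this change the write is rejected during createOptimizations, so the
    // UNIQUE constraint on :News is never created.
    cities.write
      .format("org.neo4j.spark.DataSource")
      .mode(SaveMode.Overwrite)
      .option("url", "bolt://localhost:7687") // assumption: local Neo4j instance
      .option("labels", ":News")
      .option("node.keys", "newsId")
      .option("schema.optimization.node.keys", "UNIQUE")
      .save()
  }
}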
@@ -25,8 +25,6 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
private val dataConverter = SparkToNeo4jDataConverter()

override def node(row: InternalRow, schema: StructType): java.util.Map[String, AnyRef] = {
-Validations.validate(ValidateSchemaOptions(options, schema))
-
val rowMap: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
val keys: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
val properties: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
@@ -97,8 +95,6 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
override def relationship(row: InternalRow, schema: StructType): java.util.Map[String, AnyRef] = {
val rowMap: java.util.Map[String, AnyRef] = new java.util.HashMap[String, AnyRef]

-Validations.validate(ValidateSchemaOptions(options, schema))
-
val consumer = options.relationshipMetadata.saveStrategy match {
case RelationshipSaveStrategy.NATIVE => nativeStrategyConsumer()
case RelationshipSaveStrategy.KEYS => keysStrategyConsumer()
common/src/main/scala/org/neo4j/spark/service/SchemaService.scala (21 changes: 10 additions & 11 deletions)
@@ -684,7 +684,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
})
}

-private def createOptimizationsForNode(struct: Optional[StructType]): Unit = {
+private def createOptimizationsForNode(struct: StructType): Unit = {
val schemaMetadata = options.schemaMetadata.optimization
if (schemaMetadata.nodeConstraint != ConstraintsOptimizationType.NONE
|| schemaMetadata.schemaConstraints != Set(SchemaConstraintsOptimizationType.NONE)) {
@@ -694,13 +694,12 @@
options.nodeMetadata.nodeKeys)
}
if (schemaMetadata.schemaConstraints.nonEmpty) {
-val structType: StructType = struct.orElse(emptyStruct)
-val propFromStruct: Map[String, String] = structType
+val propFromStruct: Map[String, String] = struct
.map(f => (f.name, f.name))
.toMap
val propsFromMeta: Map[String, String] = options.nodeMetadata.nodeKeys ++ options.nodeMetadata.properties
createEntityTypeConstraint("NODE", options.nodeMetadata.labels.head,
-propsFromMeta ++ propFromStruct, structType, schemaMetadata.schemaConstraints)
+propsFromMeta ++ propFromStruct, struct, schemaMetadata.schemaConstraints)
}
} else { // TODO old behaviour, remove it in the future
options.schemaMetadata.optimizationType match {
@@ -714,7 +713,7 @@
}
}

-private def createOptimizationsForRelationship(struct: Optional[StructType]): Unit = {
+private def createOptimizationsForRelationship(struct: StructType): Unit = {
val schemaMetadata = options.schemaMetadata.optimization
if (schemaMetadata.nodeConstraint != ConstraintsOptimizationType.NONE
|| schemaMetadata.relConstraint != ConstraintsOptimizationType.NONE
@@ -735,19 +734,18 @@
if (schemaMetadata.schemaConstraints.nonEmpty) {
val sourceNodeProps: Map[String, String] = options.relationshipMetadata.source.nodeKeys ++ options.relationshipMetadata.source.properties
val targetNodeProps: Map[String, String] = options.relationshipMetadata.target.nodeKeys ++ options.relationshipMetadata.target.properties
-val baseStruct = struct.orElse(emptyStruct)
val allNodeProps: Map[String, String] = sourceNodeProps ++ targetNodeProps
-val relStruct: StructType = StructType(baseStruct.filterNot(f => allNodeProps.contains(f.name)))
+val relStruct: StructType = StructType(struct.filterNot(f => allNodeProps.contains(f.name)))
val relFromStruct: Map[String, String] = relStruct
.map(f => (f.name, f.name))
.toMap
val propsFromMeta: Map[String, String] = options.relationshipMetadata.relationshipKeys ++ options.relationshipMetadata.properties
createEntityTypeConstraint("RELATIONSHIP",options.relationshipMetadata.relationshipType,
-propsFromMeta ++ relFromStruct, baseStruct, schemaMetadata.schemaConstraints)
+propsFromMeta ++ relFromStruct, struct, schemaMetadata.schemaConstraints)
createEntityTypeConstraint("NODE", options.relationshipMetadata.source.labels.head,
-sourceNodeProps, baseStruct, schemaMetadata.schemaConstraints)
+sourceNodeProps, struct, schemaMetadata.schemaConstraints)
createEntityTypeConstraint("NODE", options.relationshipMetadata.target.labels.head,
-targetNodeProps, baseStruct, schemaMetadata.schemaConstraints)
+targetNodeProps, struct, schemaMetadata.schemaConstraints)
}
} else { // TODO old behaviour, remove it in the future
options.schemaMetadata.optimizationType match {
@@ -764,7 +762,8 @@
}
}

-def createOptimizations(struct: Optional[StructType]): Unit = {
+def createOptimizations(struct: StructType): Unit = {
+Validations.validate(ValidateSchemaOptions(options, struct))
options.query.queryType match {
case QueryType.LABELS => createOptimizationsForNode(struct)
case QueryType.RELATIONSHIP => createOptimizationsForRelationship(struct)
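
The implementation of ValidateSchemaOptions is not part of this diff. As a rough, hypothetical sketch (the object and method names and the signature are illustrative, not the connector's actual code), the check that createOptimizations now runs first amounts to verifying that every field referenced by options such as node.keys or relationship.source.node.keys exists in the DataFrame schema, producing the error text asserted in the tests below:

import org.apache.spark.sql.types.StructType

object ValidateSchemaOptionsSketch {
  // Hypothetical check: every field referenced by an option must exist in the
  // DataFrame schema, otherwise the write is rejected before touching Neo4j.
  def validate(schema: StructType, optionFields: Map[String, Seq[String]]): Unit = {
    val columns = schema.fieldNames.toSet
    val errors = optionFields.toSeq.flatMap { case (option, fields) =>
      val missing = fields.filterNot(columns.contains)
      if (missing.isEmpty) Seq.empty
      else Seq(s" - Schema is missing ${missing.mkString(", ")} from option `$option`")
    }
    if (errors.nonEmpty) {
      throw new IllegalArgumentException(
        ("Write failed due to the following errors:" +: errors)
          .mkString("", "\n", "\n\nThe option key and value might be inverted."))
    }
  }
}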
@@ -58,6 +58,6 @@ class Neo4jScan(neo4jOptions: Neo4jOptions,
optsMap.put(Neo4jOptions.STREAMING_METADATA_STORAGE, StorageType.SPARK.toString)
val newOpts = new Neo4jOptions(optsMap)
Validations.validate(ValidateReadStreaming(newOpts, jobId))
-new Neo4jMicroBatchReader(Optional.of(schema), newOpts, jobId, aggregateColumns)
+new Neo4jMicroBatchReader(schema, newOpts, jobId, aggregateColumns)
}
}
@@ -14,7 +14,7 @@ import org.neo4j.spark.util._
import java.lang
import java.util.Optional

-class Neo4jMicroBatchReader(private val optionalSchema: Optional[StructType],
+class Neo4jMicroBatchReader(private val schema: StructType,
private val neo4jOptions: Neo4jOptions,
private val jobId: String,
private val aggregateColumns: Array[AggregateFunc])
@@ -27,7 +27,7 @@ class Neo4jMicroBatchReader(private val optionalSchema: Optional[StructType],

private lazy val scriptResult = {
val schemaService = new SchemaService(neo4jOptions, driverCache)
-schemaService.createOptimizations(optionalSchema)
+schemaService.createOptimizations(schema)
val scriptResult = schemaService.execute(neo4jOptions.script)
schemaService.close()
scriptResult
@@ -108,7 +108,7 @@ class Neo4jMicroBatchReader(private val optionalSchema: Optional[StructType],

override def createReaderFactory(): PartitionReaderFactory = {
new Neo4jStreamingPartitionReaderFactory(
-neo4jOptions, optionalSchema.orElse(new StructType()), jobId, scriptResult, offsetAccumulator, aggregateColumns
+neo4jOptions, schema, jobId, scriptResult, offsetAccumulator, aggregateColumns
)
}
}
@@ -36,7 +36,7 @@ class Neo4jStreamingWriter(val queryId: String,

private lazy val scriptResult = {
val schemaService = new SchemaService(neo4jOptions, driverCache)
-schemaService.createOptimizations(Optional.of(schema))
+schemaService.createOptimizations(schema)
val scriptResult = schemaService.execute(neo4jOptions.script)
schemaService.close()
scriptResult
@@ -14,7 +14,7 @@ class Neo4jBatchWriter(jobId: String,
neo4jOptions: Neo4jOptions) extends BatchWrite{
override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory = {
val schemaService = new SchemaService(neo4jOptions, driverCache)
-schemaService.createOptimizations(Optional.of(structType))
+schemaService.createOptimizations(structType)
val scriptResult = schemaService.execute(neo4jOptions.script)
schemaService.close()

@@ -2,19 +2,19 @@ package org.neo4j.spark

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.spark.SparkException
-import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerStageCompleted}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.junit.Assert.{assertEquals, assertTrue, fail}
-import org.junit.{Assume, BeforeClass, Test}
+import org.junit.Test
import org.neo4j.driver.summary.ResultSummary
-import org.neo4j.driver.{Result, SessionConfig, Transaction, TransactionWork}
+import org.neo4j.driver.{Result, Session, SessionConfig, Transaction, TransactionWork}
import org.neo4j.spark.writer.DataWriterMetrics

import java.util.concurrent.{CountDownLatch, TimeUnit}

class DataSourceWriterNeo4jTSE extends SparkConnectorScalaBaseTSE {

-val sparkSession = SparkSession.builder().getOrCreate()
+private val sparkSession = SparkSession.builder().getOrCreate()

import sparkSession.implicits._

@@ -496,15 +496,15 @@ class DataSourceWriterNeo4jTSE extends SparkConnectorScalaBaseTSE {
.option("relationship.target.node.keys", "instrument:name")
.save()
} catch {
-case sparkException: SparkException => {
-val clientException = ExceptionUtils.getRootCause(sparkException)
+case exception: IllegalArgumentException => {
+val clientException = ExceptionUtils.getRootCause(exception)
assertTrue(clientException.getMessage.equals(
"""Write failed due to the following errors:
| - Schema is missing musician_name from option `relationship.source.node.keys`
|
|The option key and value might be inverted.""".stripMargin))
}
-case generic: Throwable => fail(s"should be thrown a ${classOf[SparkException].getName}, got ${generic.getClass} instead")
+case generic: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead")
}
}

@@ -533,16 +533,16 @@
.option("relationship.target.node.keys", "instrument_name:name")
.save()
} catch {
-case sparkException: SparkException => {
-val clientException = ExceptionUtils.getRootCause(sparkException)
+case exception: IllegalArgumentException => {
+val clientException = ExceptionUtils.getRootCause(exception)
assertTrue(clientException.getMessage.equals(
"""Write failed due to the following errors:
| - Schema is missing instrument_name from option `relationship.target.node.keys`
| - Schema is missing musician_name, another_name from option `relationship.source.node.keys`
|
|The option key and value might be inverted.""".stripMargin))
}
-case generic: Throwable => fail(s"should be thrown a ${classOf[SparkException].getName}, got ${generic.getClass} instead")
+case generic: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead")
}
}

@@ -565,15 +565,15 @@
.option("node.properties", "musician_name:name,another_name:name")
.save()
} catch {
-case sparkException: SparkException => {
-val clientException = ExceptionUtils.getRootCause(sparkException)
+case exception: IllegalArgumentException => {
+val clientException = ExceptionUtils.getRootCause(exception)
assertTrue(clientException.getMessage.equals(
"""Write failed due to the following errors:
| - Schema is missing instrument_name from option `node.properties`
|
|The option key and value might be inverted.""".stripMargin))
}
-case generic: Throwable => fail(s"should be thrown a ${classOf[SparkException].getName}, got ${generic.getClass} instead: ${generic.getMessage}")
+case generic: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead: ${generic.getMessage}")
}
}

@@ -604,6 +604,43 @@
latch.await(30, TimeUnit.SECONDS)
}

+@Test
+def `does not create constraint if schema validation fails`(): Unit = {
+val cities = Seq(
+(1, "Cherbourg en Cotentin"),
+(2, "London"),
+(3, "Malmö"),
+).toDF("id", "city")
+
+try {
+cities.write
+.format(classOf[DataSource].getName)
+.mode(SaveMode.Overwrite)
+.option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
+.option("labels", ":News")
+.option("node.keys", "newsId")
+.option("schema.optimization.node.keys", "UNIQUE")
+.save()
+} catch {
+case _:Exception => {
+}
+}
+
+var session: Session = null
+try {
+session = SparkConnectorScalaSuiteIT.driver.session()
+val result = session.run("SHOW CONSTRAINTS YIELD labelsOrTypes WHERE labelsOrTypes[0] = 'News' RETURN count(*) AS count")
+.single()
+.get("count")
+.asLong()
+assertEquals(0, result)
+} finally {
+if (session != null) {
+session.close()
+}
+}
+}

class MetricsListener(expectedMetrics: Map[String, Any], done: CountDownLatch) extends SparkListener {
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
val actualMetrics = stageCompleted