diff --git a/common/src/main/scala/org/neo4j/spark/service/MappingService.scala b/common/src/main/scala/org/neo4j/spark/service/MappingService.scala
index 8ff183269..ff2e7609b 100644
--- a/common/src/main/scala/org/neo4j/spark/service/MappingService.scala
+++ b/common/src/main/scala/org/neo4j/spark/service/MappingService.scala
@@ -25,8 +25,6 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
   private val dataConverter = SparkToNeo4jDataConverter()
 
   override def node(row: InternalRow, schema: StructType): java.util.Map[String, AnyRef] = {
-    Validations.validate(ValidateSchemaOptions(options, schema))
-
     val rowMap: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
     val keys: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
     val properties: java.util.Map[String, Object] = new java.util.HashMap[String, Object]
@@ -97,8 +95,6 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
   override def relationship(row: InternalRow, schema: StructType): java.util.Map[String, AnyRef] = {
     val rowMap: java.util.Map[String, AnyRef] = new java.util.HashMap[String, AnyRef]
 
-    Validations.validate(ValidateSchemaOptions(options, schema))
-
     val consumer = options.relationshipMetadata.saveStrategy match {
       case RelationshipSaveStrategy.NATIVE => nativeStrategyConsumer()
       case RelationshipSaveStrategy.KEYS => keysStrategyConsumer()
diff --git a/common/src/main/scala/org/neo4j/spark/service/SchemaService.scala b/common/src/main/scala/org/neo4j/spark/service/SchemaService.scala
index 0da1893e1..96762f27f 100644
--- a/common/src/main/scala/org/neo4j/spark/service/SchemaService.scala
+++ b/common/src/main/scala/org/neo4j/spark/service/SchemaService.scala
@@ -684,7 +684,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
     })
   }
 
-  private def createOptimizationsForNode(struct: Optional[StructType]): Unit = {
+  private def createOptimizationsForNode(struct: StructType): Unit = {
     val schemaMetadata = options.schemaMetadata.optimization
     if (schemaMetadata.nodeConstraint != ConstraintsOptimizationType.NONE
       || schemaMetadata.schemaConstraints != Set(SchemaConstraintsOptimizationType.NONE)) {
@@ -694,13 +694,12 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
         options.nodeMetadata.nodeKeys)
       }
       if (schemaMetadata.schemaConstraints.nonEmpty) {
-        val structType: StructType = struct.orElse(emptyStruct)
-        val propFromStruct: Map[String, String] = structType
+        val propFromStruct: Map[String, String] = struct
           .map(f => (f.name, f.name))
           .toMap
         val propsFromMeta: Map[String, String] = options.nodeMetadata.nodeKeys ++ options.nodeMetadata.properties
         createEntityTypeConstraint("NODE", options.nodeMetadata.labels.head,
-          propsFromMeta ++ propFromStruct, structType, schemaMetadata.schemaConstraints)
+          propsFromMeta ++ propFromStruct, struct, schemaMetadata.schemaConstraints)
       }
     } else { // TODO old behaviour, remove it in the future
       options.schemaMetadata.optimizationType match {
@@ -714,7 +713,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
     }
   }
 
-  private def createOptimizationsForRelationship(struct: Optional[StructType]): Unit = {
+  private def createOptimizationsForRelationship(struct: StructType): Unit = {
     val schemaMetadata = options.schemaMetadata.optimization
     if (schemaMetadata.nodeConstraint != ConstraintsOptimizationType.NONE
       || schemaMetadata.relConstraint != ConstraintsOptimizationType.NONE
@@ -735,19 +734,18 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
       if (schemaMetadata.schemaConstraints.nonEmpty) {
         val sourceNodeProps: Map[String, String] = options.relationshipMetadata.source.nodeKeys ++ options.relationshipMetadata.source.properties
         val targetNodeProps: Map[String, String] = options.relationshipMetadata.target.nodeKeys ++ options.relationshipMetadata.target.properties
-        val baseStruct = struct.orElse(emptyStruct)
         val allNodeProps: Map[String, String] = sourceNodeProps ++ targetNodeProps
-        val relStruct: StructType = StructType(baseStruct.filterNot(f => allNodeProps.contains(f.name)))
+        val relStruct: StructType = StructType(struct.filterNot(f => allNodeProps.contains(f.name)))
         val relFromStruct: Map[String, String] = relStruct
           .map(f => (f.name, f.name))
           .toMap
         val propsFromMeta: Map[String, String] = options.relationshipMetadata.relationshipKeys ++ options.relationshipMetadata.properties
         createEntityTypeConstraint("RELATIONSHIP",options.relationshipMetadata.relationshipType,
-          propsFromMeta ++ relFromStruct, baseStruct, schemaMetadata.schemaConstraints)
+          propsFromMeta ++ relFromStruct, struct, schemaMetadata.schemaConstraints)
         createEntityTypeConstraint("NODE", options.relationshipMetadata.source.labels.head,
-          sourceNodeProps, baseStruct, schemaMetadata.schemaConstraints)
+          sourceNodeProps, struct, schemaMetadata.schemaConstraints)
         createEntityTypeConstraint("NODE", options.relationshipMetadata.target.labels.head,
-          targetNodeProps, baseStruct, schemaMetadata.schemaConstraints)
+          targetNodeProps, struct, schemaMetadata.schemaConstraints)
       }
     } else { // TODO old behaviour, remove it in the future
       options.schemaMetadata.optimizationType match {
@@ -764,7 +762,8 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
     }
   }
 
-  def createOptimizations(struct: Optional[StructType]): Unit = {
+  def createOptimizations(struct: StructType): Unit = {
+    Validations.validate(ValidateSchemaOptions(options, struct))
     options.query.queryType match {
       case QueryType.LABELS => createOptimizationsForNode(struct)
       case QueryType.RELATIONSHIP => createOptimizationsForRelationship(struct)
diff --git a/spark-3/src/main/scala/org/neo4j/spark/reader/Neo4jScan.scala b/spark-3/src/main/scala/org/neo4j/spark/reader/Neo4jScan.scala
index 2410f74ab..f1dab53cf 100644
--- a/spark-3/src/main/scala/org/neo4j/spark/reader/Neo4jScan.scala
+++ b/spark-3/src/main/scala/org/neo4j/spark/reader/Neo4jScan.scala
@@ -58,6 +58,6 @@ class Neo4jScan(neo4jOptions: Neo4jOptions,
     optsMap.put(Neo4jOptions.STREAMING_METADATA_STORAGE, StorageType.SPARK.toString)
     val newOpts = new Neo4jOptions(optsMap)
     Validations.validate(ValidateReadStreaming(newOpts, jobId))
-    new Neo4jMicroBatchReader(Optional.of(schema), newOpts, jobId, aggregateColumns)
+    new Neo4jMicroBatchReader(schema, newOpts, jobId, aggregateColumns)
   }
 }
diff --git a/spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jMicroBatchReader.scala b/spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jMicroBatchReader.scala
index 4981f9a18..ae99d3dcf 100644
--- a/spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jMicroBatchReader.scala
+++ b/spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jMicroBatchReader.scala
@@ -14,7 +14,7 @@ import org.neo4j.spark.util._
 import java.lang
 import java.util.Optional
 
-class Neo4jMicroBatchReader(private val optionalSchema: Optional[StructType],
+class Neo4jMicroBatchReader(private val schema: StructType,
                             private val neo4jOptions: Neo4jOptions,
                             private val jobId: String,
                             private val aggregateColumns: Array[AggregateFunc])
@@ -27,7 +27,7 @@ class Neo4jMicroBatchReader(private val optionalSchema: Optional[StructType],
 
   private lazy val scriptResult = {
     val schemaService = new SchemaService(neo4jOptions, driverCache)
-    schemaService.createOptimizations(optionalSchema)
+    schemaService.createOptimizations(schema)
     val scriptResult = schemaService.execute(neo4jOptions.script)
     schemaService.close()
     scriptResult
@@ -108,7 +108,7 @@ class Neo4jMicroBatchReader(private val optionalSchema: Optional[StructType],
 
   override def createReaderFactory(): PartitionReaderFactory = {
     new Neo4jStreamingPartitionReaderFactory(
-      neo4jOptions, optionalSchema.orElse(new StructType()), jobId, scriptResult, offsetAccumulator, aggregateColumns
+      neo4jOptions, schema, jobId, scriptResult, offsetAccumulator, aggregateColumns
     )
   }
 }
diff --git a/spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jStreamingWriter.scala b/spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jStreamingWriter.scala
index 8d72338c0..37f40abd3 100644
--- a/spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jStreamingWriter.scala
+++ b/spark-3/src/main/scala/org/neo4j/spark/streaming/Neo4jStreamingWriter.scala
@@ -36,7 +36,7 @@ class Neo4jStreamingWriter(val queryId: String,
 
   private lazy val scriptResult = {
     val schemaService = new SchemaService(neo4jOptions, driverCache)
-    schemaService.createOptimizations(Optional.of(schema))
+    schemaService.createOptimizations(schema)
     val scriptResult = schemaService.execute(neo4jOptions.script)
     schemaService.close()
     scriptResult
diff --git a/spark-3/src/main/scala/org/neo4j/spark/writer/Neo4jBatchWriter.scala b/spark-3/src/main/scala/org/neo4j/spark/writer/Neo4jBatchWriter.scala
index 3d6edc9c4..fce0fa332 100644
--- a/spark-3/src/main/scala/org/neo4j/spark/writer/Neo4jBatchWriter.scala
+++ b/spark-3/src/main/scala/org/neo4j/spark/writer/Neo4jBatchWriter.scala
@@ -14,7 +14,7 @@ class Neo4jBatchWriter(jobId: String,
                        neo4jOptions: Neo4jOptions) extends BatchWrite{
   override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory = {
     val schemaService = new SchemaService(neo4jOptions, driverCache)
-    schemaService.createOptimizations(Optional.of(structType))
+    schemaService.createOptimizations(structType)
     val scriptResult = schemaService.execute(neo4jOptions.script)
     schemaService.close()
 
diff --git a/spark-3/src/test/scala/org/neo4j/spark/DataSourceWriterNeo4jTSE.scala b/spark-3/src/test/scala/org/neo4j/spark/DataSourceWriterNeo4jTSE.scala
index 3d61101ca..cd3cd47fd 100644
--- a/spark-3/src/test/scala/org/neo4j/spark/DataSourceWriterNeo4jTSE.scala
+++ b/spark-3/src/test/scala/org/neo4j/spark/DataSourceWriterNeo4jTSE.scala
@@ -2,19 +2,19 @@ package org.neo4j.spark
 
 import org.apache.commons.lang3.exception.ExceptionUtils
 import org.apache.spark.SparkException
-import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerStageCompleted}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
 import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
 import org.junit.Assert.{assertEquals, assertTrue, fail}
-import org.junit.{Assume, BeforeClass, Test}
+import org.junit.Test
 import org.neo4j.driver.summary.ResultSummary
-import org.neo4j.driver.{Result, SessionConfig, Transaction, TransactionWork}
+import org.neo4j.driver.{Result, Session, SessionConfig, Transaction, TransactionWork}
 import org.neo4j.spark.writer.DataWriterMetrics
 
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 
 class DataSourceWriterNeo4jTSE extends SparkConnectorScalaBaseTSE {
 
-  val sparkSession = SparkSession.builder().getOrCreate()
+  private val sparkSession = SparkSession.builder().getOrCreate()
 
   import sparkSession.implicits._
 
@@ -496,15 +496,15 @@ class DataSourceWriterNeo4jTSE extends SparkConnectorScalaBaseTSE {
         .option("relationship.target.node.keys", "instrument:name")
         .save()
     } catch {
-      case sparkException: SparkException => {
-        val clientException = ExceptionUtils.getRootCause(sparkException)
+      case exception: IllegalArgumentException => {
+        val clientException = ExceptionUtils.getRootCause(exception)
         assertTrue(clientException.getMessage.equals(
           """Write failed due to the following errors:
             | - Schema is missing musician_name from option `relationship.source.node.keys`
             |
             |The option key and value might be inverted.""".stripMargin))
       }
-      case generic: Throwable => fail(s"should be thrown a ${classOf[SparkException].getName}, got ${generic.getClass} instead")
+      case generic: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead")
     }
   }
 
@@ -533,8 +533,8 @@ class DataSourceWriterNeo4jTSE extends SparkConnectorScalaBaseTSE {
         .option("relationship.target.node.keys", "instrument_name:name")
         .save()
     } catch {
-      case sparkException: SparkException => {
-        val clientException = ExceptionUtils.getRootCause(sparkException)
+      case exception: IllegalArgumentException => {
+        val clientException = ExceptionUtils.getRootCause(exception)
         assertTrue(clientException.getMessage.equals(
           """Write failed due to the following errors:
             | - Schema is missing instrument_name from option `relationship.target.node.keys`
@@ -542,7 +542,7 @@ class DataSourceWriterNeo4jTSE extends SparkConnectorScalaBaseTSE {
             |
             |The option key and value might be inverted.""".stripMargin))
       }
-      case generic: Throwable => fail(s"should be thrown a ${classOf[SparkException].getName}, got ${generic.getClass} instead")
+      case generic: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead")
     }
   }
 
@@ -565,15 +565,15 @@ class DataSourceWriterNeo4jTSE extends SparkConnectorScalaBaseTSE {
         .option("node.properties", "musician_name:name,another_name:name")
         .save()
     } catch {
-      case sparkException: SparkException => {
-        val clientException = ExceptionUtils.getRootCause(sparkException)
+      case exception: IllegalArgumentException => {
+        val clientException = ExceptionUtils.getRootCause(exception)
         assertTrue(clientException.getMessage.equals(
           """Write failed due to the following errors:
             | - Schema is missing instrument_name from option `node.properties`
             |
             |The option key and value might be inverted.""".stripMargin))
       }
-      case generic: Throwable => fail(s"should be thrown a ${classOf[SparkException].getName}, got ${generic.getClass} instead: ${generic.getMessage}")
+      case generic: Throwable => fail(s"should be thrown a ${classOf[IllegalArgumentException].getName}, got ${generic.getClass} instead: ${generic.getMessage}")
     }
   }
 
@@ -604,6 +604,43 @@ class DataSourceWriterNeo4jTSE extends SparkConnectorScalaBaseTSE {
     latch.await(30, TimeUnit.SECONDS)
   }
 
+  @Test
+  def `does not create constraint if schema validation fails`(): Unit = {
+    val cities = Seq(
+      (1, "Cherbourg en Cotentin"),
+      (2, "London"),
+      (3, "Malmö"),
+    ).toDF("id", "city")
+
+    try {
+      cities.write
+        .format(classOf[DataSource].getName)
+        .mode(SaveMode.Overwrite)
+        .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
+        .option("labels", ":News")
+        .option("node.keys", "newsId")
+        .option("schema.optimization.node.keys", "UNIQUE")
+        .save()
+    } catch {
+      case _: Exception => {
+      }
+    }
+
+    var session: Session = null
+    try {
+      session = SparkConnectorScalaSuiteIT.driver.session()
+      val result = session.run("SHOW CONSTRAINTS YIELD labelsOrTypes WHERE labelsOrTypes[0] = 'News' RETURN count(*) AS count")
+        .single()
+        .get("count")
+        .asLong()
+      assertEquals(0, result)
+    } finally {
+      if (session != null) {
+        session.close()
+      }
+    }
+  }
+
   class MetricsListener(expectedMetrics: Map[String, Any], done: CountDownLatch) extends SparkListener {
     override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
       val actualMetrics = stageCompleted