Skip to content

Commit 1417d55

Browse files
committed
#697 Add conflict resolution logic to SparkUtils.copyMetadata.
1 parent 6d89cec commit 1417d55

File tree

2 files changed

+67
-6
lines changed

2 files changed

+67
-6
lines changed

spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -243,17 +243,26 @@ object SparkUtils extends Logging {
243243
/**
244244
* Copies metadata from one schema to another as long as names and data types are the same.
245245
*
246-
* @param schemaFrom Schema to copy metadata from.
247-
* @param schemaTo Schema to copy metadata to.
248-
* @param overwrite If true, the metadata of schemaTo is not retained
246+
* @param schemaFrom Schema to copy metadata from.
247+
* @param schemaTo Schema to copy metadata to.
248+
* @param overwrite If true, the metadata of schemaTo is not retained
249+
* @param sourcePreferred If true, schemaFrom metadata is used on conflicts, schemaTo otherwise.
249250
* @return Same schema as schemaTo with metadata from schemaFrom.
250251
*/
251-
def copyMetadata(schemaFrom: StructType, schemaTo: StructType, overwrite: Boolean = false): StructType = {
252+
def copyMetadata(schemaFrom: StructType,
253+
schemaTo: StructType,
254+
overwrite: Boolean = false,
255+
sourcePreferred: Boolean = false): StructType = {
252256
def joinMetadata(from: Metadata, to: Metadata): Metadata = {
253257
val newMetadataMerged = new MetadataBuilder
254258

255-
newMetadataMerged.withMetadata(from)
256-
newMetadataMerged.withMetadata(to)
259+
if (sourcePreferred) {
260+
newMetadataMerged.withMetadata(to)
261+
newMetadataMerged.withMetadata(from)
262+
} else {
263+
newMetadataMerged.withMetadata(from)
264+
newMetadataMerged.withMetadata(to)
265+
}
257266

258267
newMetadataMerged.build()
259268
}

spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,58 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
626626
assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
627627
}
628628

629+
test("copyMetadata should retain metadata on conflicts by default") {
630+
val df1 = List(1, 2, 3).toDF("col1")
631+
val df2 = List(1, 2, 3).toDF("col1")
632+
633+
val metadata1 = new MetadataBuilder()
634+
metadata1.putString("comment", "Test")
635+
metadata1.putLong("maxLength", 100)
636+
637+
val metadata2 = new MetadataBuilder()
638+
metadata2.putLong("maxLength", 120)
639+
metadata2.putLong("newMetadata", 180)
640+
641+
val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
642+
val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))
643+
644+
val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)
645+
646+
val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata)
647+
648+
val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)
649+
650+
assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
651+
assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
652+
assert(newDf.schema.fields.head.metadata.getLong("newMetadata") == 180)
653+
}
654+
655+
test("copyMetadata should overwrite metadata on conflicts when sourcePreferred=true") {
656+
val df1 = List(1, 2, 3).toDF("col1")
657+
val df2 = List(1, 2, 3).toDF("col1")
658+
659+
val metadata1 = new MetadataBuilder()
660+
metadata1.putString("comment", "Test")
661+
metadata1.putLong("maxLength", 100)
662+
663+
val metadata2 = new MetadataBuilder()
664+
metadata2.putLong("maxLength", 120)
665+
metadata2.putLong("newMetadata", 180)
666+
667+
val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
668+
val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))
669+
670+
val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)
671+
672+
val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata, sourcePreferred = true)
673+
674+
val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)
675+
676+
assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
677+
assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 100)
678+
assert(newDf.schema.fields.head.metadata.getLong("newMetadata") == 180)
679+
}
680+
629681
test("copyMetadata should not retain original metadata when overwrite = true") {
630682
val df1 = List(1, 2, 3).toDF("col1")
631683
val df2 = List(1, 2, 3).toDF("col1")

0 commit comments

Comments
 (0)