From ebd269b651d730492c66c4ebccbb218298ec27ba Mon Sep 17 00:00:00 2001 From: Peter Rudenko Date: Mon, 8 Jun 2015 04:28:41 +0300 Subject: [PATCH 1/4] Configurable null values --- .../com/databricks/spark/csv/CsvParser.scala | 9 ++++++++- .../databricks/spark/csv/CsvRelation.scala | 5 +++-- src/test/resources/missing-values.csv | 5 +++++ .../databricks/spark/csv/CsvFastSuite.scala | 20 +++++++++++++++++++ 4 files changed, 36 insertions(+), 3 deletions(-) create mode 100644 src/test/resources/missing-values.csv diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala index 7d71195..d49576b 100644 --- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala @@ -34,6 +34,7 @@ class CsvParser { private var ignoreLeadingWhiteSpace: Boolean = false private var ignoreTrailingWhiteSpace: Boolean = false private var parserLib: String = ParserLibs.DEFAULT + private var nullValues: Seq[String] = Seq("") def withUseHeader(flag: Boolean): CsvParser = { @@ -81,6 +82,11 @@ class CsvParser { this } + def withNullValues(nullValues: Seq[String]): CsvParser = { + this.nullValues = nullValues + this + } + /** Returns a Schema RDD for the given CSV path. */ @throws[RuntimeException] def csvFile(sqlContext: SQLContext, path: String): DataFrame = { @@ -94,7 +100,8 @@ class CsvParser { parserLib, ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace, - schema)(sqlContext) + schema, + nullValues)(sqlContext) sqlContext.baseRelationToDataFrame(relation) } diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala index 2c9f30a..6d482cc 100755 --- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala @@ -41,7 +41,8 @@ case class CsvRelation protected[spark] ( parserLib: String, ignoreLeadingWhiteSpace: Boolean, ignoreTrailingWhiteSpace: Boolean, - userSchema: StructType = null)(@transient val sqlContext: SQLContext) + userSchema: StructType = null, + nullValues: Seq[String] = Seq(""))(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan with InsertableRelation { private val logger = LoggerFactory.getLogger(CsvRelation.getClass) @@ -63,7 +64,7 @@ case class CsvRelation protected[spark] ( // By making this a lazy val we keep the RDD around, amortizing the cost of locating splits. 
def buildScan = { - val baseRDD = sqlContext.sparkContext.textFile(location) + val baseRDD = sqlContext.sparkContext.textFile(location).map(line => line.replaceAll(nullValues.mkString("|"), "")) val fieldNames = schema.fieldNames diff --git a/src/test/resources/missing-values.csv b/src/test/resources/missing-values.csv new file mode 100644 index 0000000..71c953b --- /dev/null +++ b/src/test/resources/missing-values.csv @@ -0,0 +1,5 @@ +year,make,model,comment,blank +"2012","Tesla","S","No comment", +1997,Ford,E350,"Go get one now they are going fast", +2015,Chevy,Volt +NA,NULL,"T","Comment" diff --git a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala index e2bab82..3b6f440 100644 --- a/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala @@ -32,6 +32,7 @@ class CsvFastSuite extends FunSuite { val carsAltFile = "src/test/resources/cars-alternative.csv" val emptyFile = "src/test/resources/empty.csv" val escapeFile = "src/test/resources/escape.csv" + val carsWithNAs = "src/test/resources/missing-values.csv" val tempEmptyDir = "target/test/empty2/" val numCars = 3 @@ -93,6 +94,25 @@ class CsvFastSuite extends FunSuite { assert(results.size === numCars - 1) } + test("DSL test for handling NULL values") { + val results = new CsvParser() + .withUseHeader(true) + .withParserLib("univocity") + .withNullValues(Seq("NULL", "NA")) + .csvFile(TestSQLContext, carsWithNAs) + .collect() + + assert(results.size === numCars + 1) + + val results2 = new CsvParser() + .withUseHeader(true) + .withNullValues(Seq("NULL", "NA", "NaN")) + .csvFile(TestSQLContext, carsWithNAs) + .collect() + + assert(results2.size === numCars + 1) + } + test("DSL test for FAILFAST parsing mode") { val parser = new CsvParser() .withParseMode("FAILFAST") From d89dcb3b91ba5c9de92a2d1fcd4b3676db9afbf5 Mon Sep 17 00:00:00 2001 From: Peter Rudenko Date: Mon, 8 Jun 2015 18:09:04 +0300 Subject: [PATCH 2/4] Fix style --- src/main/scala/com/databricks/spark/csv/CsvRelation.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala index 6d482cc..ed76e0e 100755 --- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala @@ -64,7 +64,8 @@ case class CsvRelation protected[spark] ( // By making this a lazy val we keep the RDD around, amortizing the cost of locating splits. def buildScan = { - val baseRDD = sqlContext.sparkContext.textFile(location).map(line => line.replaceAll(nullValues.mkString("|"), "")) + val baseRDD = sqlContext.sparkContext.textFile(location). 
+ map(line => line.replaceAll(nullValues.mkString("|"), "")) val fieldNames = schema.fieldNames From 9612b1603d1183193c00e2e562d9c65d38fe34ef Mon Sep 17 00:00:00 2001 From: Peter Rudenko Date: Fri, 12 Jun 2015 21:42:32 +0300 Subject: [PATCH 3/4] Null value for types other than String --- .../databricks/spark/csv/util/TypeCast.scala | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala index 62c7b17..b881e5d 100644 --- a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala +++ b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala @@ -33,21 +33,25 @@ object TypeCast { * @param castType SparkSQL type */ private[csv] def castTo(datum: String, castType: DataType): Any = { - castType match { - case _: ByteType => datum.toByte - case _: ShortType => datum.toShort - case _: IntegerType => datum.toInt - case _: LongType => datum.toLong - case _: FloatType => datum.toFloat - case _: DoubleType => datum.toDouble - case _: BooleanType => datum.toBoolean - case _: DecimalType => new BigDecimal(datum.replaceAll(",", "")) - // TODO(hossein): would be good to support other common timestamp formats - case _: TimestampType => Timestamp.valueOf(datum) - // TODO(hossein): would be good to support other common date formats - case _: DateType => Date.valueOf(datum) - case _: StringType => datum - case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + if (datum.isEmpty && castType != StringType) { + null + } else { + castType match { + case _: ByteType => datum.toByte + case _: ShortType => datum.toShort + case _: IntegerType => datum.toInt + case _: LongType => datum.toLong + case _: FloatType => datum.toFloat + case _: DoubleType => datum.toDouble + case _: BooleanType => datum.toBoolean + case _: DecimalType => new BigDecimal(datum.replaceAll(",", "")) + // TODO(hossein): would be good to support other common timestamp formats + case _: TimestampType => Timestamp.valueOf(datum) + // TODO(hossein): would be good to support other common date formats + case _: DateType => Date.valueOf(datum) + case _: StringType => datum + case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + } } } From 18753a9b4473bdda109a11e064f9145707d881a6 Mon Sep 17 00:00:00 2001 From: Peter Rudenko Date: Thu, 18 Jun 2015 23:57:30 +0300 Subject: [PATCH 4/4] Replace null markers for the whole token. 
To avoid substring replacements such as MONTANA -> MONTA, which the old line-level regex produced when a null marker ("NA") occurred inside a legitimate token. ---
 .../scala/com/databricks/spark/csv/CsvParser.scala | 2 +-
 .../scala/com/databricks/spark/csv/CsvRelation.scala | 11 ++++++-----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
index d49576b..da2c599 100644
--- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -101,7 +101,7 @@ class CsvParser {
       ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace,
       schema,
-      nullValues)(sqlContext)
+      nullValues.toSet)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }

diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index ed76e0e..84242f6 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -42,7 +42,7 @@ case class CsvRelation protected[spark] (
     ignoreLeadingWhiteSpace: Boolean,
     ignoreTrailingWhiteSpace: Boolean,
     userSchema: StructType = null,
-    nullValues: Seq[String] = Seq(""))(@transient val sqlContext: SQLContext)
+    nullValues: Set[String] = Set(""))(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with InsertableRelation {

   private val logger = LoggerFactory.getLogger(CsvRelation.getClass)
@@ -64,8 +64,7 @@ case class CsvRelation protected[spark] (
   // By making this a lazy val we keep the RDD around, amortizing the cost of locating splits.
   def buildScan = {
-    val baseRDD = sqlContext.sparkContext.textFile(location).
-      map(line => line.replaceAll(nullValues.mkString("|"), ""))
+    val baseRDD = sqlContext.sparkContext.textFile(location)

     val fieldNames = schema.fieldNames

@@ -155,7 +154,8 @@ case class CsvRelation protected[spark] (
         try {
           index = 0
           while (index < schemaFields.length) {
-            rowArray(index) = TypeCast.castTo(tokens(index), schemaFields(index).dataType)
+            val token = if (nullValues.contains(tokens(index))) "" else tokens(index)
+            rowArray(index) = TypeCast.castTo(token, schemaFields(index).dataType)
             index = index + 1
           }
           Some(Row.fromSeq(rowArray))
@@ -197,7 +197,8 @@ case class CsvRelation protected[spark] (
           throw new RuntimeException(s"Malformed line in FAILFAST mode: $line")
         } else {
           while (index < schemaFields.length) {
-            rowArray(index) = TypeCast.castTo(tokens.get(index), schemaFields(index).dataType)
+            val token = if (nullValues.contains(tokens.get(index))) "" else tokens.get(index)
+            rowArray(index) = TypeCast.castTo(token, schemaFields(index).dataType)
             index = index + 1
           }
           Some(Row.fromSeq(rowArray))
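
Taken together, the four patches let a caller declare which CSV tokens decode to SQL nulls. Below is a minimal driver sketch against the API introduced above; the NullValuesDemo wrapper, the local SparkContext setup, and the show() call are illustrative scaffolding rather than part of the patches, while the fixture path is borrowed from the test suite.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import com.databricks.spark.csv.CsvParser

    object NullValuesDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("null-values-demo"))
        val sqlContext = new SQLContext(sc)

        // Tokens equal to "NULL" or "NA" are blanked before TypeCast.castTo,
        // which (after PATCH 3/4) returns null for an empty datum on any
        // non-string column.
        val cars = new CsvParser()
          .withUseHeader(true)
          .withNullValues(Seq("NULL", "NA"))
          .csvFile(sqlContext, "src/test/resources/missing-values.csv")

        // PATCH 4/4 compares each whole token against the marker set, so a
        // value such as "MONTANA" is no longer clipped to "MONTA" as it was
        // under the original line-level replaceAll.
        cars.show()
        sc.stop()
      }
    }

Matching whole tokens instead of regex-replacing the raw line also keeps quoting and embedded delimiters intact, because the substitution now runs after the parser has split the record into fields.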