diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index 6392a74..c2a71c1 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -28,7 +28,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.sources.{PrunedScan, BaseRelation, InsertableRelation, TableScan}
 import org.apache.spark.sql.types._
-import com.databricks.spark.csv.readers.{BulkCsvReader, LineCsvReader}
+import com.databricks.spark.csv.readers.{BulkCsvReader, LineCsvReader, BulkReader, LineReader}
 import com.databricks.spark.csv.util._
 
 case class CsvRelation protected[spark] (
@@ -208,19 +208,24 @@ case class CsvRelation protected[spark] (
     }
   }
 
+  protected def getLineReader(): LineReader = {
+    val escapeVal = if (escape == null) '\\' else escape.charValue()
+    val commentChar: Char = if (comment == null) '\0' else comment
+    val quoteChar: Char = if (quote == null) '\0' else quote
+
+    new LineCsvReader(
+      fieldSep = delimiter,
+      quote = quoteChar,
+      escape = escapeVal,
+      commentMarker = commentChar)
+  }
+
   private def inferSchema(): StructType = {
     if (this.userSchema != null) {
       userSchema
     } else {
       val firstRow = if (ParserLibs.isUnivocityLib(parserLib)) {
-        val escapeVal = if (escape == null) '\\' else escape.charValue()
-        val commentChar: Char = if (comment == null) '\0' else comment
-        val quoteChar: Char = if (quote == null) '\0' else quote
-        new LineCsvReader(
-          fieldSep = delimiter,
-          quote = quoteChar,
-          escape = escapeVal,
-          commentMarker = commentChar).parseLine(firstLine)
+        getLineReader().parseLine(firstLine)
       } else {
         val csvFormat = defaultCsvFormat
           .withDelimiter(delimiter)
@@ -261,6 +266,18 @@ case class CsvRelation protected[spark] (
     }
   }
 
+  protected def getBulkReader(
+      header: Seq[String],
+      iter: Iterator[String], split: Int): BulkReader = {
+    val escapeVal = if (escape == null) '\\' else escape.charValue()
+    val commentChar: Char = if (comment == null) '\0' else comment
+    val quoteChar: Char = if (quote == null) '\0' else quote
+
+    new BulkCsvReader(iter, split,
+      headers = header, fieldSep = delimiter,
+      quote = quoteChar, escape = escapeVal, commentMarker = commentChar)
+  }
+
   private def univocityParseCSV(
       file: RDD[String],
       header: Seq[String]): RDD[Array[String]] = {
@@ -269,13 +286,7 @@ case class CsvRelation protected[spark] (
     val dataLines = if (useHeader) file.filter(_ != filterLine) else file
     val rows = dataLines.mapPartitionsWithIndex({
       case (split, iter) => {
-        val escapeVal = if (escape == null) '\\' else escape.charValue()
-        val commentChar: Char = if (comment == null) '\0' else comment
-        val quoteChar: Char = if (quote == null) '\0' else quote
-
-        new BulkCsvReader(iter, split,
-          headers = header, fieldSep = delimiter,
-          quote = quoteChar, escape = escapeVal, commentMarker = commentChar)
+        getBulkReader(header, iter, split)
       }
     }, true)
 
diff --git a/src/main/scala/com/databricks/spark/csv/readers/readers.scala b/src/main/scala/com/databricks/spark/csv/readers/readers.scala
index 9fb1212..1c81815 100644
--- a/src/main/scala/com/databricks/spark/csv/readers/readers.scala
+++ b/src/main/scala/com/databricks/spark/csv/readers/readers.scala
@@ -22,6 +22,15 @@ import java.io.StringReader
 
 import com.univocity.parsers.csv._
 
+trait BulkReader extends Iterator[Array[String]] {
+  protected def reader(iter: Iterator[String]) = new StringIteratorReader(iter)
+}
+
+trait LineReader {
+  protected def reader(line: String) = new StringReader(line)
+  def parseLine(line: String): Array[String]
+}
+
 /**
  * Read and parse CSV-like input
  * @param fieldSep the delimiter used to separate fields in a line
@@ -97,14 +106,15 @@ private[csv] class LineCsvReader(
     ignoreTrailingSpace,
     null,
     inputBufSize,
-    maxCols) {
+    maxCols)
+  with LineReader{
   /**
    * parse a line
   * @param line a String with no newline at the end
   * @return array of strings where each string is a field in the CSV record
   */
  def parseLine(line: String): Array[String] = {
-    parser.beginParsing(new StringReader(line))
+    parser.beginParsing(reader(line))
     val parsed = parser.parseNext()
     parser.stopParsing()
     parsed
@@ -148,10 +158,9 @@ private[csv] class BulkCsvReader(
     headers,
     inputBufSize,
     maxCols)
-  with Iterator[Array[String]] {
+  with BulkReader {
 
-  private val reader = new StringIteratorReader(iter)
-  parser.beginParsing(reader)
+  parser.beginParsing(reader(iter))
   private var nextRecord = parser.parseNext()
 
   /**
@@ -178,7 +187,7 @@
  * parsed and needs the newlines to be present
  * @param iter iterator over RDD[String]
  */
-private class StringIteratorReader(val iter: Iterator[String]) extends java.io.Reader {
+private[readers] class StringIteratorReader(val iter: Iterator[String]) extends java.io.Reader {
 
   private var next: Long = 0
   private var length: Long = 0  // length of input so far
diff --git a/src/main/scala/com/databricks/spark/csv/util/TextFile.scala b/src/main/scala/com/databricks/spark/csv/util/TextFile.scala
index 3b8d6c6..2dde4d3 100644
--- a/src/main/scala/com/databricks/spark/csv/util/TextFile.scala
+++ b/src/main/scala/com/databricks/spark/csv/util/TextFile.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 
-private[csv] object TextFile {
+object TextFile {
   val DEFAULT_CHARSET = Charset.forName("UTF-8")
 
   def withCharset(context: SparkContext, location: String, charset: String): RDD[String] = {
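
The net effect of the patch is to make the parsing layer pluggable: `CsvRelation` now obtains its readers through the overridable `getLineReader()` and `getBulkReader(...)` hooks, which are typed against the new `LineReader` and `BulkReader` traits rather than the concrete univocity-backed classes, and `TextFile` loses its `private[csv]` modifier so code outside the package can reuse it. As a rough sketch of how an alternative parser could implement the new traits (the `TsvLineReader`/`TsvBulkReader` names and the naive tab-splitting logic are illustrative assumptions, not part of this change):

```scala
import com.databricks.spark.csv.readers.{BulkReader, LineReader}

// Hypothetical example: a minimal tab-separated reader pair built on the new traits.
// It ignores quoting, escaping and comment markers entirely; a real reader would not.
class TsvLineReader extends LineReader {
  // parseLine is the only abstract member of LineReader
  override def parseLine(line: String): Array[String] = line.split('\t')
}

class TsvBulkReader(iter: Iterator[String]) extends BulkReader {
  // BulkReader extends Iterator[Array[String]], so hasNext/next are all that is required
  private val rows = iter.map(_.split('\t'))
  override def hasNext: Boolean = rows.hasNext
  override def next(): Array[String] = rows.next()
}
```

The `protected` hooks suggest the intended extension path is subclassing `CsvRelation` (within the `spark` package, given its `protected[spark]` constructor) and overriding `getLineReader()`/`getBulkReader(...)` to return such readers; schema inference and `univocityParseCSV` then flow through unchanged, and the now-public `TextFile.withCharset` gives that code access to the same charset-aware line loading.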