diff --git a/src/main/scala/com/databricks/spark/avro/DefaultSource.scala b/src/main/scala/com/databricks/spark/avro/DefaultSource.scala index 4ebcec85..8b5aa4c9 100644 --- a/src/main/scala/com/databricks/spark/avro/DefaultSource.scala +++ b/src/main/scala/com/databricks/spark/avro/DefaultSource.scala @@ -23,6 +23,7 @@ import java.util.zip.Deflater import scala.util.control.NonFatal import com.databricks.spark.avro.DefaultSource.{AvroSchema, IgnoreFilesWithoutExtensionProperty, SerializableConfiguration} +import com.databricks.spark.avro.generic.SparkGenericDatumReader import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.avro.{Schema, SchemaBuilder} @@ -178,10 +179,8 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister { val reader = { val in = new FsInput(new Path(new URI(file.filePath)), conf) try { - val datumReader = userProvidedSchema match { - case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema) - case _ => new GenericDatumReader[GenericRecord]() - } + val datumReader = new SparkGenericDatumReader() + userProvidedSchema.foreach(datumReader.setSchema) DataFileReader.openReader(in, datumReader) } catch { case NonFatal(e) => @@ -210,9 +209,11 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister { private val encoderForDataColumns = RowEncoder(requiredSchema) private[this] var completed = false + private var record: GenericRecord = _ override def hasNext: Boolean = { if (completed) { + record = null false } else { val r = reader.hasNext && !reader.pastSync(stop) @@ -228,7 +229,9 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister { if (reader.pastSync(stop)) { throw new NoSuchElementException("next on empty iterator") } - val record = reader.next() + + // record is reused by avro, so we copy its content with rowConverter + record = reader.next(record) val safeDataRow = 
rowConverter(record).asInstanceOf[GenericRow] // The safeDataRow is reused, we must do a copy diff --git a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala index 4fbdccdc..005a2e23 100644 --- a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala +++ b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala @@ -143,14 +143,12 @@ object SchemaConverters { targetSqlType: DataType): AnyRef => AnyRef = { def createConverter(avroSchema: Schema, - sqlType: DataType, path: List[String]): AnyRef => AnyRef = { + sparkSqlType: DataType, path: List[String]): AnyRef => AnyRef = { val avroType = avroSchema.getType - (sqlType, avroType) match { - // Avro strings are in Utf8, so we have to call toString on them - case (StringType, STRING) | (StringType, ENUM) => - (item: AnyRef) => item.toString - // Byte arrays are reused by avro, so we have to make a copy of them. - case (IntegerType, INT) | (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) | + (sparkSqlType, avroType) match { + case (StringType, ENUM) => (item: AnyRef) => item.toString + case (StringType, STRING) | (IntegerType, INT) | + (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) | (FloatType, FLOAT) | (LongType, LONG) => identity case (TimestampType, LONG) => @@ -160,7 +158,8 @@ object SchemaConverters { case (BinaryType, FIXED) => (item: AnyRef) => item.asInstanceOf[GenericFixed].bytes().clone() case (BinaryType, BYTES) => - (item: AnyRef) => + // Byte arrays are reused by avro, so we have to make a copy of them. 
+ (item: AnyRef) => val byteBuffer = item.asInstanceOf[ByteBuffer] val bytes = new Array[Byte](byteBuffer.remaining) byteBuffer.get(bytes) diff --git a/src/main/scala/com/databricks/spark/avro/generic/SparkGenericDatumReader.scala b/src/main/scala/com/databricks/spark/avro/generic/SparkGenericDatumReader.scala new file mode 100644 index 00000000..9e0fb9c2 --- /dev/null +++ b/src/main/scala/com/databricks/spark/avro/generic/SparkGenericDatumReader.scala @@ -0,0 +1,15 @@ +package com.databricks.spark.avro.generic + +import org.apache.avro.Schema +import org.apache.avro.generic.{GenericDatumReader, GenericRecord} + + /** + * A generic datum reader that decodes Avro strings as java.lang.String instead of Avro's Utf8 + */ +class SparkGenericDatumReader extends GenericDatumReader[GenericRecord]{ + + override def findStringClass( + + schema: Schema): Class[_] = classOf[String] + +}