diff --git a/src/main/scala/com/databricks/spark/avro/AvroRelation.scala b/src/main/scala/com/databricks/spark/avro/AvroRelation.scala index 3e1fda20..863677c6 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroRelation.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroRelation.scala @@ -18,11 +18,9 @@ package com.databricks.spark.avro import java.io.FileNotFoundException import java.util.zip.Deflater - import scala.collection.Iterator import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer - import com.google.common.base.Objects import org.apache.avro.SchemaBuilder import org.apache.avro.file.{DataFileConstants, DataFileReader, FileReader} @@ -31,12 +29,12 @@ import org.apache.avro.mapred.{AvroOutputFormat, FsInput} import org.apache.avro.mapreduce.AvroJob import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.Job - import org.apache.spark.Logging import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} +import org.apache.avro.Schema.Type private[avro] class AvroRelation( override val paths: Array[String], @@ -130,8 +128,11 @@ private[avro] class AvroRelation( val firstRecord = records.next() val superSchema = firstRecord.getSchema // the schema of the actual record // the fields that are actually required along with their converters - val avroFieldMap = superSchema.getFields.map(f => (f.name, f)).toMap - + val avroFieldMap = superSchema.getFields.map{f => + f.getJsonProps.foreach(x => f.schema().addProp(x._1, x._2)) + (f.name, f) + }.toMap + new Iterator[Row] { private[this] val baseIterator = records private[this] var currentRecord = firstRecord diff --git a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala index c2a42ba9..07b6ed4c 100644 --- a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala +++ b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala @@ -27,6 +27,9 @@ import org.apache.avro.SchemaBuilder._ import org.apache.avro.Schema.Type._ import org.apache.spark.sql.Row import org.apache.spark.sql.types._ +import org.apache.avro.Schema.Type +import java.sql.Timestamp +import java.sql.Date /** * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice @@ -34,6 +37,13 @@ import org.apache.spark.sql.types._ */ object SchemaConverters { + val LOGICAL_TYPE = "logicalType" + val DECIMAL = "decimal" + val TIMESTAMP = "timestamp"; + val DATE = "date"; + val PRECISION = "precision" + val SCALE = "scale" + case class SchemaType(dataType: DataType, nullable: Boolean) /** @@ -42,17 +52,36 @@ object SchemaConverters { def toSqlType(avroSchema: Schema): SchemaType = { avroSchema.getType match { case INT => SchemaType(IntegerType, nullable = false) - case STRING => SchemaType(StringType, nullable = false) + case STRING => { + val logicalType = avroSchema.getJsonProp(LOGICAL_TYPE) + if(logicalType != null && logicalType.asText().equalsIgnoreCase(DECIMAL)){ + val precision = avroSchema.getJsonProp(PRECISION).asInt + val scale = avroSchema.getJsonProp(SCALE).asInt + SchemaType(DecimalType(precision,scale), nullable = false) + }else { + SchemaType(StringType, nullable = false) + } + } case BOOLEAN => SchemaType(BooleanType, nullable = false) case BYTES => SchemaType(BinaryType, nullable = false) case DOUBLE => SchemaType(DoubleType, nullable = false) case FLOAT => SchemaType(FloatType, nullable = false) - case LONG => SchemaType(LongType, nullable = false) + case LONG => { + val logicalType = avroSchema.getJsonProp(LOGICAL_TYPE) + if(logicalType != null && logicalType.asText().equalsIgnoreCase(TIMESTAMP)) { + SchemaType(TimestampType, nullable = false) + }else if(logicalType != null && logicalType.asText().equalsIgnoreCase(DATE)) { + SchemaType(TimestampType, nullable = false) + }else{ + SchemaType(LongType, nullable = false) + } + } case FIXED => SchemaType(BinaryType, nullable = false) case ENUM => SchemaType(StringType, nullable = false) case RECORD => val fields = avroSchema.getFields.map { f => + f.getJsonProps.foreach(x => f.schema().addProp(x._1, x._2)) val schemaType = toSqlType(f.schema()) StructField(f.name, schemaType.dataType, schemaType.nullable) } @@ -76,7 +105,9 @@ object SchemaConverters { // In case of a union with null, eliminate it and make a recursive call val remainingUnionTypes = avroSchema.getTypes.filterNot(_.getType == NULL) if (remainingUnionTypes.size == 1) { - toSqlType(remainingUnionTypes.get(0)).copy(nullable = true) + val remainingSchema = remainingUnionTypes.get(0) + avroSchema.getJsonProps.foreach(x => remainingSchema.addProp(x._1, x._2)) + toSqlType(remainingSchema).copy(nullable = true) } else { toSqlType(Schema.createUnion(remainingUnionTypes)).copy(nullable = true) } @@ -125,8 +156,31 @@ object SchemaConverters { private[avro] def createConverterToSQL(schema: Schema): Any => Any = { schema.getType match { // Avro strings are in Utf8, so we have to call toString on them - case STRING | ENUM => (item: Any) => if (item == null) null else item.toString - case INT | BOOLEAN | DOUBLE | FLOAT | LONG => identity + case STRING | ENUM => (item: Any) => if (item == null) { + null + }else { + val logicalType = schema.getJsonProp(LOGICAL_TYPE) + if(logicalType != null && logicalType.asText().equalsIgnoreCase(DECIMAL)){ + val precision = schema.getJsonProp(PRECISION).asInt + val scale = schema.getJsonProp(SCALE).asInt + Decimal.apply(BigDecimal.apply(item.toString()), precision, scale) + }else{ + item.toString + } + } + case LONG => (item: Any) => if (item == null) { + null + }else { + val logicalType = schema.getJsonProp(LOGICAL_TYPE) + if(logicalType != null && logicalType.asText().equalsIgnoreCase(TIMESTAMP)){ + new Timestamp(item.asInstanceOf[Long].longValue()) + }else if(logicalType != null && logicalType.asText().equalsIgnoreCase(DATE)){ + new Timestamp(item.asInstanceOf[Long].longValue()) + }else{ + item + } + } + case INT | BOOLEAN | DOUBLE | FLOAT => identity // Byte arrays are reused by avro, so we have to make a copy of them. case FIXED => (item: Any) => if (item == null) { null @@ -142,7 +196,10 @@ object SchemaConverters { javaBytes } case RECORD => - val fieldConverters = schema.getFields.map(f => createConverterToSQL(f.schema)) + val fieldConverters = schema.getFields.map{f => + f.getJsonProps.foreach(x => f.schema().addProp(x._1, x._2)) + createConverterToSQL(f.schema) + } (item: Any) => if (item == null) { null } else { @@ -173,7 +230,9 @@ object SchemaConverters { if (schema.getTypes.exists(_.getType == NULL)) { val remainingUnionTypes = schema.getTypes.filterNot(_.getType == NULL) if (remainingUnionTypes.size == 1) { - createConverterToSQL(remainingUnionTypes.get(0)) + val remainingSchema = remainingUnionTypes.get(0) + schema.getJsonProps.foreach(x => remainingSchema.addProp(x._1, x._2)) + createConverterToSQL(remainingSchema) } else { createConverterToSQL(Schema.createUnion(remainingUnionTypes)) } diff --git a/src/test/resources/users.avro b/src/test/resources/users.avro new file mode 100644 index 00000000..95050de4 Binary files /dev/null and b/src/test/resources/users.avro differ diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index df9ebd00..795aed3a 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -16,10 +16,12 @@ import org.apache.spark.SparkContext import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.types._ import org.scalatest.{BeforeAndAfterAll, FunSuite} +import java.sql.Date class AvroSuite extends FunSuite with BeforeAndAfterAll { val episodesFile = "src/test/resources/episodes.avro" val testFile = "src/test/resources/test.avro" + val userFile = "src/test/resources/users.avro" private var sqlContext: SQLContext = _ @@ -442,4 +444,29 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { assert(newDf.count == 8) } } -} \ No newline at end of file + + test("Logical Types") { + val df = sqlContext.read.avro(userFile) + + val decimals = df.select("decimal").collect().map(x => Decimal.apply(x.getDecimal(0))) + val dec1 = Decimal.apply(BigDecimal.apply("55555.555550000"), 25, 9) + val dec2 = Decimal.apply(BigDecimal.apply("8747336654.536756000"), 25, 9) + + assert(decimals.apply(0).equals(dec1)) + assert(decimals.apply(1).equals(dec2)) + + assert(df.schema.apply("decimal").dataType == DecimalType(25,9)) + + + val timestamps = df.select("timestamp").collect().map(x => x.getAs[Timestamp](0)) + val t1 = new Timestamp(1460354720000l) + val t2 = new Timestamp(1462842320000l) + + assert(timestamps.apply(0).equals(t1)) + assert(timestamps.apply(1).equals(t2)) + + assert(df.schema.apply("timestamp").dataType == TimestampType) + + + } +}