Skip to content
This repository was archived by the owner on Dec 20, 2018. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 2 additions & 76 deletions src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ private[avro] class AvroOutputWriter(
recordName: String,
recordNamespace: String) extends OutputWriter {

private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace)
private lazy val converter =
SchemaConverters.createConverterToAvro(schema, recordName, recordNamespace)

/**
* Overrides the couple of methods responsible for generating the output streams / files so
Expand All @@ -72,79 +73,4 @@ private[avro] class AvroOutputWriter(
}

override def close(): Unit = recordWriter.close(context)

/**
 * Builds a converter function for the given Spark SQL data type. The returned function is
 * applied to each value when writing Avro records out to disk. Every converter maps a `null`
 * input to `null`.
 *
 * Conversions performed:
 *  - `BinaryType`: the byte array is wrapped in a `ByteBuffer`.
 *  - Byte/short/int/long/float/double/string/boolean: passed through unchanged.
 *  - `DecimalType`: `toString` of the value.
 *  - `TimestampType` / `DateType`: the result of `getTime` on the value.
 *  - `ArrayType`: an `Array[Any]` holding the converted elements.
 *  - `MapType` with string keys: a `java.util.HashMap` holding the converted values.
 *  - `StructType`: an Avro `Record` whose schema is obtained from
 *    `SchemaConverters.convertStructToAvro`.
 *
 * A data type that matches none of the cases above results in a `scala.MatchError` when the
 * converter is constructed.
 *
 * @param dataType the Spark SQL type of the values to convert
 * @param structName record name used when `dataType` is (or contains) a struct
 * @param recordNamespace namespace used for generated Avro record schemas
 * @return a function converting a single value of `dataType`
 */
private def createConverterToAvro(
    dataType: DataType,
    structName: String,
    recordNamespace: String): (Any) => Any = {
  dataType match {
    case BinaryType => (item: Any) => item match {
      case null => null
      case bytes: Array[Byte] => ByteBuffer.wrap(bytes)
    }
    case ByteType | ShortType | IntegerType | LongType |
         FloatType | DoubleType | StringType | BooleanType => identity
    case _: DecimalType => (item: Any) => if (item == null) null else item.toString
    case TimestampType => (item: Any) =>
      if (item == null) null else item.asInstanceOf[Timestamp].getTime
    case DateType => (item: Any) =>
      if (item == null) null else item.asInstanceOf[Date].getTime
    case ArrayType(elementType, _) =>
      // Elements reuse the enclosing structName / recordNamespace.
      val elementConverter = createConverterToAvro(elementType, structName, recordNamespace)
      (item: Any) => {
        if (item == null) {
          null
        } else {
          val sourceArray = item.asInstanceOf[Seq[Any]]
          val sourceArraySize = sourceArray.size
          val targetArray = new Array[Any](sourceArraySize)
          var idx = 0
          while (idx < sourceArraySize) {
            targetArray(idx) = elementConverter(sourceArray(idx))
            idx += 1
          }
          targetArray
        }
      }
    case MapType(StringType, valueType, _) =>
      val valueConverter = createConverterToAvro(valueType, structName, recordNamespace)
      (item: Any) => {
        if (item == null) {
          null
        } else {
          val javaMap = new HashMap[String, Any]()
          item.asInstanceOf[Map[String, Any]].foreach { case (key, value) =>
            javaMap.put(key, valueConverter(value))
          }
          javaMap
        }
      }
    case structType: StructType =>
      val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
      val schema: Schema = SchemaConverters.convertStructToAvro(
        structType, builder, recordNamespace)
      // Resolved once here rather than on every record: the names and per-field converters
      // depend only on the struct type, not on the row being converted.
      val fieldNames = structType.fieldNames
      val fieldConverters = structType.fields.map(field =>
        createConverterToAvro(field.dataType, field.name, recordNamespace))
      val fieldCount = fieldConverters.length
      (item: Any) => {
        if (item == null) {
          null
        } else {
          val record = new Record(schema)
          val rowIterator = item.asInstanceOf[Row].toSeq.iterator
          var idx = 0
          while (idx < fieldCount) {
            record.put(fieldNames(idx), fieldConverters(idx)(rowIterator.next()))
            idx += 1
          }
          record
        }
      }
  }
}
}
138 changes: 118 additions & 20 deletions src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@
package com.databricks.spark.avro

import java.nio.ByteBuffer
import java.sql.{Date, Timestamp}
import java.util.HashMap

import scala.collection.JavaConverters._

import org.apache.avro.generic.GenericData.Fixed
import org.apache.avro.Schema.Type._
import org.apache.avro.SchemaBuilder._
import org.apache.avro.generic.GenericData.{Fixed, Record}
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.SchemaBuilder._
import org.apache.avro.Schema.Type._

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

import scala.collection.JavaConverters._
import scala.collection.immutable.Map

/**
* This object contains methods that are used to convert sparkSQL schemas to avro schemas and vice
* versa.
Expand Down Expand Up @@ -113,16 +116,20 @@ object SchemaConverters {
def convertStructToAvro[T](
structType: StructType,
schemaBuilder: RecordBuilder[T],
recordNamespace: String): T = {
recordNamespace: String,
structName: String = "",
schemaMap: collection.mutable.Map[String, Object] =
collection.mutable.Map[String, Object]()): T = {
val fieldsAssembler: FieldAssembler[T] = schemaBuilder.fields()
structType.fields.foreach { field =>
val newField = fieldsAssembler.name(field.name).`type`()

if (field.nullable) {
convertFieldTypeToAvro(field.dataType, newField.nullable(), field.name, recordNamespace)
convertFieldTypeToAvro(field.dataType, newField.nullable(), field.name, recordNamespace,
schemaMap)
.noDefault
} else {
convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace)
convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace, schemaMap)
.noDefault
}
}
Expand Down Expand Up @@ -315,7 +322,7 @@ object SchemaConverters {
dataType: DataType,
schemaBuilder: BaseTypeBuilder[T],
structName: String,
recordNamespace: String): T = {
recordNamespace: String, schemaMap: collection.mutable.Map[String, Object]): T = {
dataType match {
case ByteType => schemaBuilder.intType()
case ShortType => schemaBuilder.intType()
Expand All @@ -332,19 +339,21 @@ object SchemaConverters {

case ArrayType(elementType, _) =>
val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace)
val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace,
schemaMap)
schemaBuilder.array().items(elementSchema)

case MapType(StringType, valueType, _) =>
val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull)
val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace)
val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace,
schemaMap)
schemaBuilder.map().values(valueSchema)

case structType: StructType =>
convertStructToAvro(
structType,
schemaBuilder.record(structName).namespace(recordNamespace),
recordNamespace)
recordNamespace, structName, schemaMap)

case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.")
}
Expand All @@ -359,7 +368,8 @@ object SchemaConverters {
dataType: DataType,
newFieldBuilder: BaseFieldTypeBuilder[T],
structName: String,
recordNamespace: String): FieldDefault[T, _] = {
recordNamespace: String,
schemaMap: collection.mutable.Map[String, Object]): FieldDefault[T, _] = {
dataType match {
case ByteType => newFieldBuilder.intType()
case ShortType => newFieldBuilder.intType()
Expand All @@ -376,19 +386,28 @@ object SchemaConverters {

case ArrayType(elementType, _) =>
val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace)
val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace,
schemaMap)
newFieldBuilder.array().items(elementSchema)

case MapType(StringType, valueType, _) =>
val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull)
val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace)
val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace,
schemaMap)
newFieldBuilder.map().values(valueSchema)

case structType: StructType =>
convertStructToAvro(
structType,
newFieldBuilder.record(structName).namespace(recordNamespace),
recordNamespace)
val schemaKey = s"$recordNamespace.$structName"
if (schemaMap.contains(schemaKey)) {
val schema = schemaMap.get(schemaKey).get
schema.asInstanceOf[RecordDefault[T]]
} else {
val schema : RecordDefault[T] = SchemaConverters.convertStructToAvro(
structType, newFieldBuilder.record(structName).namespace(recordNamespace),
recordNamespace, structName, schemaMap)
schemaMap.put(schemaKey, schema)
schema
}

case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.")
}
Expand All @@ -401,4 +420,83 @@ object SchemaConverters {
SchemaBuilder.builder()
}
}

/**
 * Builds a converter function for the given Spark SQL data type. The returned function is
 * applied to each value when writing Avro records out to disk. Every converter maps a `null`
 * input to `null`.
 *
 * Conversions performed:
 *  - `BinaryType`: the byte array is wrapped in a `ByteBuffer`.
 *  - Byte/short/int/long/float/double/string/boolean: passed through unchanged.
 *  - `DecimalType`: `toString` of the value.
 *  - `TimestampType` / `DateType`: the result of `getTime` on the value.
 *  - `ArrayType`: an `Array[Any]` holding the converted elements.
 *  - `MapType` with string keys: a `java.util.HashMap` holding the converted values.
 *  - `StructType`: an Avro `Record` whose schema is obtained from `convertStructToAvro`.
 *
 * @param dataType the Spark SQL type of the values to convert
 * @param structName record name used when `dataType` is (or contains) a struct
 * @param recordNamespace namespace used for generated Avro record schemas
 * @param schemaMap mutable map handed on to `convertStructToAvro` and to every recursive call;
 *                  a fresh empty map is used when the argument is omitted
 * @return a function converting a single value of `dataType`
 * @throws IncompatibleSchemaException if `dataType` is not one of the supported types
 */
def createConverterToAvro(
    dataType: DataType,
    structName: String,
    recordNamespace: String,
    schemaMap: collection.mutable.Map[String, Object] =
      collection.mutable.Map[String, Object]()): (Any) => Any = {
  dataType match {
    case BinaryType => (item: Any) => item match {
      case null => null
      case bytes: Array[Byte] => ByteBuffer.wrap(bytes)
    }
    case ByteType | ShortType | IntegerType | LongType |
         FloatType | DoubleType | StringType | BooleanType => identity
    case _: DecimalType => (item: Any) => if (item == null) null else item.toString
    case TimestampType => (item: Any) =>
      if (item == null) null else item.asInstanceOf[Timestamp].getTime
    case DateType => (item: Any) =>
      if (item == null) null else item.asInstanceOf[Date].getTime
    case ArrayType(elementType, _) =>
      // Elements reuse the enclosing structName / recordNamespace.
      val elementConverter = createConverterToAvro(elementType, structName, recordNamespace,
        schemaMap)
      (item: Any) => {
        if (item == null) {
          null
        } else {
          val sourceArray = item.asInstanceOf[Seq[Any]]
          val sourceArraySize = sourceArray.size
          val targetArray = new Array[Any](sourceArraySize)
          var idx = 0
          while (idx < sourceArraySize) {
            targetArray(idx) = elementConverter(sourceArray(idx))
            idx += 1
          }
          targetArray
        }
      }
    case MapType(StringType, valueType, _) =>
      val valueConverter = createConverterToAvro(valueType, structName, recordNamespace,
        schemaMap)
      (item: Any) => {
        if (item == null) {
          null
        } else {
          val javaMap = new HashMap[String, Any]()
          item.asInstanceOf[Map[String, Any]].foreach { case (key, value) =>
            javaMap.put(key, valueConverter(value))
          }
          javaMap
        }
      }
    case structType: StructType =>
      val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
      val schema: Schema = SchemaConverters.convertStructToAvro(
        structType, builder, recordNamespace, structName, schemaMap)
      // Resolved once here rather than on every record: the names and per-field converters
      // depend only on the struct type, not on the row being converted.
      val fieldNames = structType.fieldNames
      val fieldConverters = structType.fields.map(field =>
        createConverterToAvro(field.dataType, field.name, recordNamespace, schemaMap))
      val fieldCount = fieldConverters.length
      (item: Any) => {
        if (item == null) {
          null
        } else {
          val record = new Record(schema)
          val rowIterator = item.asInstanceOf[Row].toSeq.iterator
          var idx = 0
          while (idx < fieldCount) {
            record.put(fieldNames(idx), fieldConverters(idx)(rowIterator.next()))
            idx += 1
          }
          record
        }
      }
    // Fail with the same exception the schema converters in this object use, instead of an
    // opaque scala.MatchError, when the type has no Avro conversion.
    case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.")
  }
}
}