Skip to content

Commit

Permalink
review comment fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Mahesh Kumar Behera committed Aug 24, 2020
1 parent 3a6d96b commit d16e9e1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class HiveAcidDataSourceV2Reader
txn: HiveAcidTxn => {
import scala.collection.JavaConversions._
val reader = new TableReader(sparkSession, txn, hiveAcidMetadata)
val hiveReader = reader.getReader(schema.fieldNames,
val hiveReader = reader.getPartitionsV2(schema.fieldNames,
pushedFilterArray, new SparkAcidConf(sparkSession, options.toMap))
factories.addAll(hiveReader)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,27 @@ import org.apache.hadoop.mapred.{InputFormat, OutputFormat}
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
* Represents metadata for hive acid table and exposes API to perform operations on top of it
* @param database - name of the database
* @param identifier - table identifier
* @param hiveConf hiveConf - hive conf
* @param sparkSession - spark session object
* @param fullyQualifiedTableName - the fully qualified hive acid table name
*/
class HiveAcidMetadata(database : Option[String],
identifier : String,
hiveConf: HiveConf,
caseSensitiveAnalysis : Boolean = false) extends Logging {
class HiveAcidMetadata(sparkSession: SparkSession,
fullyQualifiedTableName: String) extends Logging {

// hive conf
private val hiveConf: HiveConf = HiveConverter.getHiveConf(sparkSession.sparkContext)

// a hive representation of the table
val hTable: metadata.Table = {
val hive: Hive = Hive.get(hiveConf)
val table = sparkSession.sessionState.sqlParser.parseTableIdentifier(fullyQualifiedTableName)
val hTable = hive.getTable(
database match {
table.database match {
case Some(database) => database
case None => HiveAcidMetadata.DEFAULT_DATABASE
}, identifier)
}, table.identifier)
Hive.closeCurrent()
hTable
}
Expand Down Expand Up @@ -133,7 +134,7 @@ class HiveAcidMetadata(database : Option[String],
}

private def getColName(field: StructField): String = {
HiveAcidMetadata.getColName(caseSensitiveAnalysis, field)
HiveAcidMetadata.getColName(sparkSession, field)
}
}

Expand All @@ -155,26 +156,20 @@ object HiveAcidMetadata {

def fromSparkSession(sparkSession: SparkSession,
fullyQualifiedTableName: String): HiveAcidMetadata = {
val logicalPlan = sparkSession.sessionState.sqlParser.parseTableIdentifier(fullyQualifiedTableName)
new HiveAcidMetadata(logicalPlan.database,
logicalPlan.table,
HiveConverter.getHiveConf(sparkSession.sparkContext),
sparkSession.sessionState.conf.caseSensitiveAnalysis)
}

def fromTableName(database : Option[String], table : String, hiveConf : HiveConf): HiveAcidMetadata = {
new HiveAcidMetadata(database, table, hiveConf)
new HiveAcidMetadata(
sparkSession,
fullyQualifiedTableName)
}

def getColName(caseSensitiveAnalysis : Boolean, field: StructField): String = {
if (caseSensitiveAnalysis) {
def getColName(sparkSession: SparkSession, field: StructField): String = {
if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
field.name
} else {
field.name.toLowerCase(Locale.ROOT)
}
}

def getColNames(sparkSession: SparkSession, schema: StructType): Seq[String] = {
schema.map(getColName(sparkSession.sessionState.conf.caseSensitiveAnalysis, _))
schema.map(getColName(sparkSession, _))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ private[hiveacid] class TableReader(sparkSession: SparkSession,
partitions)
}

def getReader(requiredColumns: Array[String],
filters: Array[Filter],
readConf: SparkAcidConf): java.util.List[InputPartition[ColumnarBatch]] = {
def getPartitionsV2(requiredColumns: Array[String],
filters: Array[Filter],
readConf: SparkAcidConf): java.util.List[InputPartition[ColumnarBatch]] = {
val reader = getTableReader(requiredColumns, filters, readConf)
if (hiveAcidMetadata.isPartitioned) {
logDebug("getReader for Partitioned table")
Expand Down

0 comments on commit d16e9e1

Please sign in to comment.