From 14f2991439b4104c66955401988593b4a62e9eb1 Mon Sep 17 00:00:00 2001 From: Julien Date: Tue, 21 Jul 2020 16:02:21 +0200 Subject: [PATCH] Add image index to each row when reading images --- .../astrolabsoftware/sparkfits/FitsHduImage.scala | 8 +++++--- .../sparkfits/FitsRecordReader.scala | 15 +++++++++++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/main/scala/com/astrolabsoftware/sparkfits/FitsHduImage.scala b/src/main/scala/com/astrolabsoftware/sparkfits/FitsHduImage.scala index 7bef416..da1824c 100644 --- a/src/main/scala/com/astrolabsoftware/sparkfits/FitsHduImage.scala +++ b/src/main/scala/com/astrolabsoftware/sparkfits/FitsHduImage.scala @@ -121,15 +121,17 @@ object FitsHduImage { * Build a list of one StructField from header information. * The list of StructField is then used to build the DataFrame schema. * - * @return (List[StructField]) List of StructField with column name = Image, + * @return (List[StructField]) List of StructField with column names [Image, ImgIndex], * data type, and whether the data is nullable. * */ override def listOfStruct : List[StructField] = { // Get the list of StructField. val lStruct = List.newBuilder[StructField] - val tmp = ReadMyType("Image", elementType(0), true) - lStruct += tmp.copy(tmp.name, ArrayType(tmp.dataType)) + val img = ReadMyType("Image", elementType(0), true) + val index = ReadMyType("ImgIndex", "K", true) + lStruct += img.copy(img.name, ArrayType(img.dataType)) + lStruct += index.copy(index.name, index.dataType) lStruct.result } diff --git a/src/main/scala/com/astrolabsoftware/sparkfits/FitsRecordReader.scala b/src/main/scala/com/astrolabsoftware/sparkfits/FitsRecordReader.scala index 77f7594..675bb7c 100644 --- a/src/main/scala/com/astrolabsoftware/sparkfits/FitsRecordReader.scala +++ b/src/main/scala/com/astrolabsoftware/sparkfits/FitsRecordReader.scala @@ -72,6 +72,7 @@ class FitsRecordReader extends RecordReader[LongWritable, Seq[Row]] { private var nrowsLong : Long = 0L private var rowSizeInt : Int = 0 private var rowSizeLong : Long = 0L + private var nrowsPerImage : Long = 0L private var startstop: FitsLib.FitsBlockBoundaries = FitsLib.FitsBlockBoundaries() private var notValid : Boolean = false @@ -181,6 +182,9 @@ class FitsRecordReader extends RecordReader[LongWritable, Seq[Row]] { log.warn(s"Use option('mode', 'PERMISSIVE') if you want to discard all empty HDUs.") } + // Total number of rows per image + nrowsPerImage = keyValues("NAXIS2").toInt + // Get the number of rows and the size (B) of one row. // this is dependent on the HDU type nrowsLong = fits.hdu.getNRows(keyValues) @@ -370,13 +374,20 @@ class FitsRecordReader extends RecordReader[LongWritable, Seq[Row]] { // Read a record of length `0 to recordLength - 1` fits.data.readFully(recordValueBytes, 0, recordLength) + val imgPosition = (((currentPosition + recordLength)/rowSizeLong - 1)/nrowsPerImage).toLong // Convert each row // 1 task: 32 MB @ 2s val tmp = Seq.newBuilder[Row] for (i <- 0 to recordLength / rowSizeLong.toInt - 1) { - tmp += Row.fromSeq(fits.getRow( + val myrow = fits.getRow( recordValueBytes.slice( - rowSizeInt*i, rowSizeInt*(i+1)))) + rowSizeInt*i, rowSizeInt*(i+1) + ) + ) + val data = if (fits.hduType == "IMAGE") { + myrow :+ imgPosition + } else myrow + tmp += Row.fromSeq(data) } recordValue = tmp.result