diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala index e3152e40d0..1d5c1ab621 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala @@ -19,6 +19,7 @@ package org.apache.sedona.sql import org.apache.sedona.sql.UDF.RasterUdafCatalog +import org.apache.sedona.sql.utils.GeoToolsCoverageAvailability.{gridClassName, isGeoToolsAvailable} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.sedona_sql.UDT.RasterUdtRegistratorWrapper import org.apache.spark.sql.{SparkSession, functions} @@ -26,19 +27,6 @@ import org.slf4j.{Logger, LoggerFactory} object RasterRegistrator { val logger: Logger = LoggerFactory.getLogger(getClass) - private val gridClassName = "org.geotools.coverage.grid.GridCoverage2D" - - // Helper method to check if GridCoverage2D is available - private def isGeoToolsAvailable: Boolean = { - try { - Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader) - true - } catch { - case _: ClassNotFoundException => - logger.warn("Geotools was not found on the classpath. Raster operations will not be available.") - false - } - } def registerAll(sparkSession: SparkSession): Unit = { if (isGeoToolsAvailable) { diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/utils/GeoToolsCoverageAvailability.scala b/spark/common/src/main/scala/org/apache/sedona/sql/utils/GeoToolsCoverageAvailability.scala new file mode 100644 index 0000000000..1d197c2c32 --- /dev/null +++ b/spark/common/src/main/scala/org/apache/sedona/sql/utils/GeoToolsCoverageAvailability.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.utils + +import org.apache.sedona.sql.RasterRegistrator.logger + +/** + * A helper object to check if GeoTools GridCoverage2D is available on the classpath. + */ +object GeoToolsCoverageAvailability { + val gridClassName = "org.geotools.coverage.grid.GridCoverage2D" + + lazy val isGeoToolsAvailable: Boolean = { + try { + Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader) + true + } catch { + case _: ClassNotFoundException => + logger.warn("Geotools was not found on the classpath. Raster operations will not be available.") + false + } + } +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala new file mode 100644 index 0000000000..2d3349d4ad --- /dev/null +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferrableRasterTypes.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.sedona_sql.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.sedona_sql.UDT.RasterUDT +import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.{RasterEnhancer, RasterInputExpressionEnhancer} +import org.apache.spark.sql.types.{ArrayType, DataTypes, UserDefinedType} + +import scala.reflect.runtime.universe.{Type, typeOf} +import org.geotools.coverage.grid.GridCoverage2D + +object InferrableRasterTypes { + implicit val gridCoverage2DInstance: InferrableType[GridCoverage2D] = + new InferrableType[GridCoverage2D] {} + implicit val gridCoverage2DArrayInstance: InferrableType[Array[GridCoverage2D]] = + new InferrableType[Array[GridCoverage2D]] {} + + def isRasterType(t: Type): Boolean = t =:= typeOf[GridCoverage2D] + def isRasterArrayType(t: Type): Boolean = t =:= typeOf[Array[GridCoverage2D]] + + val rasterUDT: UserDefinedType[_] = RasterUDT + val rasterUDTArray: ArrayType = DataTypes.createArrayType(RasterUDT) + + def rasterExtractor(expr: Expression)(input: InternalRow): Any = expr.toRaster(input) + + def rasterSerializer(output: Any): Any = + if (output != null) { + output.asInstanceOf[GridCoverage2D].serialize + } else { + null + } + + def rasterArraySerializer(output: Any): Any = + if (output != null) { + val rasters = output.asInstanceOf[Array[GridCoverage2D]] + val serialized = rasters.map { raster => + val serialized = raster.serialize + raster.dispose(true) + serialized + } + ArrayData.toArrayData(serialized) + } else { + null + } +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala index 28096c4bcc..6b9f89c451 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala @@ -22,13 +22,11 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT} +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, DataTypes, DoubleType, IntegerType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.locationtech.jts.geom.Geometry import org.apache.spark.sql.sedona_sql.expressions.implicits._ -import org.apache.spark.sql.sedona_sql.expressions.raster.implicits._ -import org.geotools.coverage.grid.GridCoverage2D import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` import scala.reflect.runtime.universe.TypeTag @@ -75,14 +73,10 @@ abstract class InferredExpression(fSeq: InferrableFunction *) // This is a compile time type shield for the types we are able to infer. Anything // other than these types will cause a compilation error. This is the Scala // 2 way of making a union type. -sealed class InferrableType[T: TypeTag] +class InferrableType[T: TypeTag] object InferrableType { implicit val geometryInstance: InferrableType[Geometry] = new InferrableType[Geometry] {} - implicit val gridCoverage2DInstance: InferrableType[GridCoverage2D] = - new InferrableType[GridCoverage2D] {} - implicit val gridCoverage2DArrayInstance: InferrableType[Array[GridCoverage2D]] = - new InferrableType[Array[GridCoverage2D]] {} implicit val geometryArrayInstance: InferrableType[Array[Geometry]] = new InferrableType[Array[Geometry]] {} implicit val javaDoubleInstance: InferrableType[java.lang.Double] = @@ -127,8 +121,8 @@ object InferredTypes { expr => input => expr.toGeometry(input) } else if (t =:= typeOf[Array[Geometry]]) { expr => input => expr.toGeometryArray(input) - } else if (t =:= typeOf[GridCoverage2D]) { - expr => input => expr.toRaster(input) + } else if (InferredRasterExpression.isRasterType(t)) { + InferredRasterExpression.rasterExtractor } else if (t =:= typeOf[Array[Double]]) { expr => input => expr.eval(input).asInstanceOf[ArrayData].toDoubleArray() } else if (t =:= typeOf[String]) { @@ -156,14 +150,8 @@ object InferredTypes { } else { null } - } else if (t =:= typeOf[GridCoverage2D]) { - output => { - if (output != null) { - output.asInstanceOf[GridCoverage2D].serialize - } else { - null - } - } + } else if (InferredRasterExpression.isRasterType(t)) { + InferredRasterExpression.rasterSerializer } else if (t =:= typeOf[String]) { output => if (output != null) { @@ -194,19 +182,8 @@ object InferredTypes { } else { null } - } else if (t =:= typeOf[Array[GridCoverage2D]]) { - output => - if (output != null) { - val rasters = output.asInstanceOf[Array[GridCoverage2D]] - val serialized = rasters.map { raster => - val serialized = raster.serialize - raster.dispose(true) - serialized - } - ArrayData.toArrayData(serialized) - } else { - null - } + } else if (InferredRasterExpression.isRasterArrayType(t)) { + InferredRasterExpression.rasterArraySerializer } else if (t =:= typeOf[Option[Boolean]]) { output => if (output != null) { @@ -224,10 +201,10 @@ object InferredTypes { GeometryUDT } else if (t =:= typeOf[Array[Geometry]] || t =:= typeOf[java.util.List[Geometry]]) { DataTypes.createArrayType(GeometryUDT) - } else if (t =:= typeOf[GridCoverage2D]) { - RasterUDT - } else if (t =:= typeOf[Array[GridCoverage2D]]) { - DataTypes.createArrayType(RasterUDT) + } else if (InferredRasterExpression.isRasterType(t)) { + InferredRasterExpression.rasterUDT + } else if (InferredRasterExpression.isRasterArrayType(t)) { + InferredRasterExpression.rasterUDTArray } else if (t =:= typeOf[java.lang.Double]) { DoubleType } else if (t =:= typeOf[java.lang.Integer]) { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredRasterExpression.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredRasterExpression.scala new file mode 100644 index 0000000000..9c6875a65e --- /dev/null +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredRasterExpression.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.sedona_sql.expressions + +import org.apache.sedona.sql.utils.GeoToolsCoverageAvailability.isGeoToolsAvailable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types.{ArrayType, UserDefinedType} + +import scala.reflect.runtime.universe.{Type, typeOf} + +object InferredRasterExpression { + def isRasterType(t: Type): Boolean = + isGeoToolsAvailable && InferrableRasterTypes.isRasterType(t) + + def isRasterArrayType(t: Type): Boolean = + isGeoToolsAvailable && InferrableRasterTypes.isRasterArrayType(t) + + def rasterUDT: UserDefinedType[_] = if (isGeoToolsAvailable) { + InferrableRasterTypes.rasterUDT + } else { + null + } + + def rasterUDTArray: ArrayType = if (isGeoToolsAvailable) { + InferrableRasterTypes.rasterUDTArray + } else { + null + } + + val rasterExtractor: Expression => InternalRow => Any = if (isGeoToolsAvailable) { + InferrableRasterTypes.rasterExtractor + } else { + _ => _ => null + } + + val rasterSerializer: Any => Any = if (isGeoToolsAvailable) { + InferrableRasterTypes.rasterSerializer + } else { + (_: Any) => null + } + + val rasterArraySerializer: Any => Any = if (isGeoToolsAvailable) { + InferrableRasterTypes.rasterArraySerializer + } else { + (_: Any) => null + } +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala index 85719ce5a3..a8baca9033 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala @@ -22,13 +22,10 @@ package org.apache.spark.sql.sedona_sql.expressions import org.apache.sedona.sql.utils.GeometrySerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} -import org.apache.spark.sql.types.{ByteType, DataTypes} +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.unsafe.types.UTF8String import org.locationtech.jts.geom.{Geometry, GeometryFactory, Point} -import java.util - object implicits { implicit class InputExpressionEnhancer(inputExpression: Expression) { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/GeometryFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/GeometryFunctions.scala index fa8390a31e..e13e81dfce 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/GeometryFunctions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/GeometryFunctions.scala @@ -22,6 +22,7 @@ import org.apache.sedona.common.raster.GeometryFunctions import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.sedona_sql.expressions.InferredExpression import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ +import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._ case class RS_ConvexHull(inputExpressions: Seq[Expression]) extends InferredExpression(GeometryFunctions.convexHull _) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala index bd30844021..42fb3fd226 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ +import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._ import org.apache.spark.sql.sedona_sql.expressions.InferredExpression /// Calculate Normalized Difference between two bands diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctionEditors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctionEditors.scala index 10ea368ecf..fc87706967 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctionEditors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctionEditors.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster import org.apache.sedona.common.raster.PixelFunctionEditors import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ +import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._ import org.apache.spark.sql.sedona_sql.expressions.InferredExpression case class RS_SetValues(inputExpressions: Seq[Expression]) extends InferredExpression( diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala index f5499f277d..22315ed944 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT} import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ +import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._ import org.apache.spark.sql.sedona_sql.expressions.InferredExpression import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, DoubleType, IntegerType, StructType} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala index f0039c6af2..b7ffbbebb7 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster import org.apache.sedona.common.raster.RasterAccessors import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ +import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._ import org.apache.spark.sql.sedona_sql.expressions.InferredExpression case class RS_NumBands(inputExpressions: Seq[Expression]) extends InferredExpression(RasterAccessors.numBands _) { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala index 11b4152401..b64a9b5bb3 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.sedona_sql.UDT.RasterUDT import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ +import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._ import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer import org.apache.spark.sql.sedona_sql.expressions.InferredExpression import org.geotools.coverage.grid.GridCoverage2D diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandEditors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandEditors.scala index de782a8883..de7f57dfce 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandEditors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandEditors.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster import org.apache.sedona.common.raster.RasterBandEditors import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ +import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._ import org.apache.spark.sql.sedona_sql.expressions.InferredExpression case class RS_SetBandNoDataValue(inputExpressions: Seq[Expression]) extends InferredExpression( diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala index ae6c9e103d..1e4a6a8ea5 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, Gener import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.sedona_sql.UDT.RasterUDT import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ +import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._ import org.apache.spark.sql.sedona_sql.expressions.InferredExpression import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.{RasterEnhancer, RasterInputExpressionEnhancer} import org.apache.spark.sql.types.{ArrayType, BooleanType, Decimal, IntegerType, NullType, StructType} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala index 3b13e03101..db77310ba9 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster import org.apache.sedona.common.raster.RasterEditors import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ +import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._ import org.apache.spark.sql.sedona_sql.expressions.InferredExpression case class RS_SetSRID(inputExpressions: Seq[Expression]) extends InferredExpression(RasterEditors.setSrid _) { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala index 4d9e375b52..07a06730c2 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster import org.apache.sedona.common.raster.RasterOutputs import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._ +import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._ import org.apache.spark.sql.sedona_sql.expressions.InferredExpression case class RS_AsGeoTiff(inputExpressions: Seq[Expression]) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala index e34b1b87c0..70868128e5 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala @@ -34,7 +34,7 @@ trait TraitJoinQueryBase { leftShapeExpr: Expression, rightRdd: RDD[UnsafeRow], rightShapeExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) = { - if (leftShapeExpr.dataType.acceptsType(RasterUDT) || rightShapeExpr.dataType.acceptsType(RasterUDT)) { + if (leftShapeExpr.dataType.isInstanceOf[RasterUDT] || rightShapeExpr.dataType.isInstanceOf[RasterUDT]) { (toWGS84EnvelopeRDD(leftRdd, leftShapeExpr), toWGS84EnvelopeRDD(rightRdd, rightShapeExpr)) } else { @@ -60,7 +60,7 @@ trait TraitJoinQueryBase { // transformation for both sides. We use expanded WGS84 envelope as the joined geometries and perform a // coarse-grained spatial join. val spatialRdd = new SpatialRDD[Geometry] - val wgs84EnvelopeRdd = if (shapeExpression.dataType.acceptsType(RasterUDT)) { + val wgs84EnvelopeRdd = if (shapeExpression.dataType.isInstanceOf[RasterUDT]) { rdd.map { row => val raster = RasterSerializer.deserialize(shapeExpression.eval(row).asInstanceOf[Array[Byte]]) val shape = JoinedGeometryRaster.rasterToWGS84Envelope(raster)