Skip to content

Feat: (WIP) Stdlib functions #102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -192,6 +192,9 @@ to create `TypedColumn`s and with those a new Dataset from pieces of another usi
```kotlin
val dataset: Dataset<YourClass> = ...
val newDataset: Dataset<Pair<TypeA, TypeB>> = dataset.selectTyped(col(YourClass::colA), col(YourClass::colB))

// Alternatively, for instance when working with a Dataset<Row>
val typedDataset: Dataset<Pair<String, Int>> = otherDataset.selectTyped(col("a").`as`<String>(), col("b").`as`<Int>())
```

### Overload resolution ambiguity
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@ object KSparkExtensions {

/** Collects every element of `ds` and hands the result back as a `java.util.List`. */
def collectAsList[T](ds: Dataset[T]): util.List[T] = {
  val collected = ds.collect()
  JavaConverters.seqAsJavaList(collected)
}

/** Exposes the array returned by `ds.tail(n)` as a `java.util.List`. */
def tailAsList[T](ds: Dataset[T], n: Int): util.List[T] = {
  val trailing = ds.tail(n)
  util.Arrays.asList(trailing: _*)
}

def debugCodegen(df: Dataset[_]): Unit = {
import org.apache.spark.sql.execution.debug._
Original file line number Diff line number Diff line change
@@ -647,12 +647,19 @@ operator fun Column.get(key: Any): Column = getItem(key)
/**
* Shorthand for `functions.lit(a)`: forwards [a] unchanged and returns whatever that call returns.
*/
fun lit(a: Any) = functions.lit(a)

/**
 * Attaches the expected result type [T] to this column, producing a [TypedColumn].
 *
 * The encoder for [T] is obtained from `encoder<T>()`. Operations such as `select` on a
 * [Dataset] can use this type information to automatically convert their results into the
 * correct JVM types.
 *
 * ```
 * val df: Dataset<Row> = ...
 * val ints: Dataset<Int> = df.selectTyped( col("a").`as`<Int>() )
 * ```
 */
@Suppress("UNCHECKED_CAST")
inline fun <reified T> Column.`as`(): TypedColumn<Any, T> {
    val typeEncoder = encoder<T>()
    return `as`(typeEncoder)
}


/**
* Alias for [Dataset.joinWith] which passes "left" argument
* and respects the fact that in result of left join right relation is nullable
@@ -809,45 +816,74 @@ fun <T> Dataset<T>.showDS(numRows: Int = 20, truncate: Boolean = true) = apply {
/**
 * Selects a single typed column from this Dataset.
 *
 * The column is accepted with any input type and is cast (unchecked) to a column over [T]
 * before being passed to `select`.
 *
 * @param c1 the typed column expression to compute for each element
 * @return a [Dataset] holding the values produced by [c1]
 */
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1> Dataset<T>.selectTyped(
    c1: TypedColumn<out Any, U1>,
): Dataset<U1> {
    val first = c1 as TypedColumn<T, U1>
    return select(first)
}

/**
* Returns a new Dataset by computing the two given [TypedColumn] expressions for each element
* and wrapping each pair of results in a [Pair].
*
* NOTE(review): this listing shows both the earlier `TypedColumn<T, …>` parameters with the
* one-line `select(c1, c2)` body and the `TypedColumn<out Any, …>` parameters with the casting
* `select(...)` body; presumably the second of each is the current code — confirm against the
* merged file.
*
* @param c1 column producing the first component of each [Pair]
* @param c2 column producing the second component of each [Pair]
* @return a [Dataset] of [Pair]s built from `_1()` and `_2()` of every selected tuple
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2> Dataset<T>.selectTyped(
c1: TypedColumn<T, U1>,
c2: TypedColumn<T, U2>,
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
): Dataset<Pair<U1, U2>> =
select(c1, c2).map { Pair(it._1(), it._2()) }
select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
).map { Pair(it._1(), it._2()) }

/**
* Returns a new Dataset by computing the three given [TypedColumn] expressions for each element
* and wrapping each group of results in a [Triple].
*
* NOTE(review): this listing shows both the earlier `TypedColumn<T, …>` parameters with the
* one-line `select(c1, c2, c3)` body and the `TypedColumn<out Any, …>` parameters with the
* casting `select(...)` body; presumably the second of each is the current code — confirm
* against the merged file.
*
* @param c1 column producing the first component of each [Triple]
* @param c2 column producing the second component of each [Triple]
* @param c3 column producing the third component of each [Triple]
* @return a [Dataset] of [Triple]s built from `_1()`, `_2()` and `_3()` of every selected tuple
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2, reified U3> Dataset<T>.selectTyped(
c1: TypedColumn<T, U1>,
c2: TypedColumn<T, U2>,
c3: TypedColumn<T, U3>,
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
c3: TypedColumn<out Any, U3>,
): Dataset<Triple<U1, U2, U3>> =
select(c1, c2, c3).map { Triple(it._1(), it._2(), it._3()) }
select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
c3 as TypedColumn<T, U3>,
).map { Triple(it._1(), it._2(), it._3()) }

/**
* Returns a new Dataset by computing the four given [TypedColumn] expressions for each element
* and wrapping each group of results in an [Arity4].
*
* NOTE(review): this listing shows both the earlier `TypedColumn<T, …>` parameters with the
* one-line `select(c1, c2, c3, c4)` body and the `TypedColumn<out Any, …>` parameters with the
* casting `select(...)` body; presumably the second of each is the current code — confirm
* against the merged file.
*
* @param c1 column producing the first component of each [Arity4]
* @param c2 column producing the second component of each [Arity4]
* @param c3 column producing the third component of each [Arity4]
* @param c4 column producing the fourth component of each [Arity4]
* @return a [Dataset] of [Arity4]s built from `_1()` through `_4()` of every selected tuple
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2, reified U3, reified U4> Dataset<T>.selectTyped(
c1: TypedColumn<T, U1>,
c2: TypedColumn<T, U2>,
c3: TypedColumn<T, U3>,
c4: TypedColumn<T, U4>,
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
c3: TypedColumn<out Any, U3>,
c4: TypedColumn<out Any, U4>,
): Dataset<Arity4<U1, U2, U3, U4>> =
select(c1, c2, c3, c4).map { Arity4(it._1(), it._2(), it._3(), it._4()) }
select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
c3 as TypedColumn<T, U3>,
c4 as TypedColumn<T, U4>,
).map { Arity4(it._1(), it._2(), it._3(), it._4()) }

/**
* Returns a new Dataset by computing the five given [TypedColumn] expressions for each element
* and wrapping each group of results in an [Arity5].
*
* NOTE(review): this listing shows both the earlier `TypedColumn<T, …>` parameters with the
* one-line `select(c1, c2, c3, c4, c5)` body and the `TypedColumn<out Any, …>` parameters with
* the casting `select(...)` body; presumably the second of each is the current code — confirm
* against the merged file.
*
* @param c1 column producing the first component of each [Arity5]
* @param c2 column producing the second component of each [Arity5]
* @param c3 column producing the third component of each [Arity5]
* @param c4 column producing the fourth component of each [Arity5]
* @param c5 column producing the fifth component of each [Arity5]
* @return a [Dataset] of [Arity5]s built from `_1()` through `_5()` of every selected tuple
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2, reified U3, reified U4, reified U5> Dataset<T>.selectTyped(
c1: TypedColumn<T, U1>,
c2: TypedColumn<T, U2>,
c3: TypedColumn<T, U3>,
c4: TypedColumn<T, U4>,
c5: TypedColumn<T, U5>,
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
c3: TypedColumn<out Any, U3>,
c4: TypedColumn<out Any, U4>,
c5: TypedColumn<out Any, U5>,
): Dataset<Arity5<U1, U2, U3, U4, U5>> =
select(c1, c2, c3, c4, c5).map { Arity5(it._1(), it._2(), it._3(), it._4(), it._5()) }

select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
c3 as TypedColumn<T, U3>,
c4 as TypedColumn<T, U4>,
c5 as TypedColumn<T, U5>,
).map { Arity5(it._1(), it._2(), it._3(), it._4(), it._5()) }

/**
* Convenience overload that resolves the reified type [T] to a [KType] via `typeOf<T>()` and
* forwards to `schema(typeOf<T>(), map)`, returning that call's result.
*
* @param map [String]-to-[KType] mapping passed through unchanged; defaults to an empty map.
* NOTE(review): how the entries are interpreted is defined by the delegated `schema` overload,
* which is not shown here.
*/
@OptIn(ExperimentalStdlibApi::class)
inline fun <reified T> schema(map: Map<String, KType> = mapOf()) = schema(typeOf<T>(), map)
Original file line number Diff line number Diff line change
@@ -339,31 +339,34 @@ class ApiTest : ShouldSpec({
SomeClass(intArrayOf(1, 2, 4), 5),
)

val typedColumnA: TypedColumn<Any, IntArray> = dataset.col("a").`as`(encoder())
val newDS1WithAs: Dataset<Int> = dataset.selectTyped(
col("b").`as`<Int>(),
)
newDS1WithAs.show()

val newDS2 = dataset.selectTyped(
val newDS2: Dataset<Pair<Int, Int>> = dataset.selectTyped(
// col(SomeClass::a), NOTE that this doesn't work on 2.4, returning a data class with an array in it
col(SomeClass::b),
col(SomeClass::b),
)
newDS2.show()

val newDS3 = dataset.selectTyped(
val newDS3: Dataset<Triple<Int, Int, Int>> = dataset.selectTyped(
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),
)
newDS3.show()

val newDS4 = dataset.selectTyped(
val newDS4: Dataset<Arity4<Int, Int, Int, Int>> = dataset.selectTyped(
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),
)
newDS4.show()

val newDS5 = dataset.selectTyped(
val newDS5: Dataset<Arity5<Int, Int, Int, Int, Int>> = dataset.selectTyped(
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),
Loading