Skip to content

Commit

Permalink
Merge pull request #914 from Kotlin/stdlib-interpret
Browse files Browse the repository at this point in the history
[Compiler plugin] Add a mechanism to handle function calls to stdlib that can appear as df api arguments
  • Loading branch information
koperagen authored Oct 10, 2024
2 parents 0e3fdcd + 650222e commit 61865d8
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,38 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Present

class PairConstructor : AbstractInterpreter<Pair<*, *>>() {
class PairToConstructor : AbstractInterpreter<Pair<*, *>>() {
val Arguments.receiver: Any? by arg(lens = Interpreter.Id)
val Arguments.that: Any? by arg(lens = Interpreter.Id)
override fun Arguments.interpret(): Pair<*, *> {
return receiver to that
}
}

class PairConstructor : AbstractInterpreter<Pair<*, *>>() {
val Arguments.first: Any? by arg(lens = Interpreter.Id)
val Arguments.second: Any? by arg(lens = Interpreter.Id)
override fun Arguments.interpret(): Pair<*, *> {
return first to second
}
}

class TrimMargin : AbstractInterpreter<String>() {
val Arguments.receiver: String by arg(lens = Interpreter.Value)
val Arguments.marginPrefix: String by arg(lens = Interpreter.Value, defaultValue = Present("|"))

override fun Arguments.interpret(): String {
return receiver.trimMargin(marginPrefix)
}
}

class TrimIndent : AbstractInterpreter<String>() {
val Arguments.receiver: String by arg(lens = Interpreter.Value)

override fun Arguments.interpret(): String {
return receiver.trimIndent()
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,13 @@ import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
import org.jetbrains.kotlin.fir.expressions.FirResolvedQualifier
import org.jetbrains.kotlin.fir.expressions.UnresolvedExpressionTypeAccess
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.references.resolved
import org.jetbrains.kotlin.fir.references.symbol
import org.jetbrains.kotlin.fir.references.toResolvedNamedFunctionSymbol
import org.jetbrains.kotlin.fir.references.toResolvedFunctionSymbol
import org.jetbrains.kotlin.fir.resolve.fqName
import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
import org.jetbrains.kotlin.fir.types.classId
import org.jetbrains.kotlin.fir.types.coneType
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddDslStringInvoke
Expand All @@ -91,25 +88,22 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Move0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.PairConstructor
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.PairToConstructor
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ReadExcel
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameDefault
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameDsl
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameFrom
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToTop
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TrimMargin
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Update0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.UpdateWith0
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names

@OptIn(UnresolvedExpressionTypeAccess::class)
internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*>? {
if (
calleeReference.toResolvedNamedFunctionSymbol()?.callableId == Names.TO &&
coneTypeOrNull?.classId == Names.PAIR
) {
return PairConstructor()
}
val interpreter = Stdlib.interpreter(this)
if (interpreter != null) return interpreter
val symbol =
(calleeReference as? FirResolvedNamedReference)?.resolvedSymbol as? FirCallableSymbol ?: return null
val argName = Name.identifier("interpreter")
Expand All @@ -121,6 +115,32 @@ internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*
}
}

private object Stdlib {
private val map: MutableMap<Key, Interpreter<*>> = mutableMapOf()
init {
register(Names.TO, Names.PAIR, PairToConstructor())
register(Names.PAIR_CONSTRUCTOR, Names.PAIR, PairConstructor())
register(Names.TRIM_MARGIN, StandardClassIds.String, TrimMargin())
register(Names.TRIM_INDENT, StandardClassIds.String, TrimMargin())
}

@OptIn(UnresolvedExpressionTypeAccess::class)
fun interpreter(call: FirFunctionCall): Interpreter<*>? {
val id = call.calleeReference.toResolvedFunctionSymbol()?.callableId ?: return null
val returnType = call.coneTypeOrNull?.classId ?: return null
return map[Key(id, returnType)]
}

fun register(id: CallableId, returnType: ClassId, interpreter: Interpreter<*>) {
map[Key(id, returnType)] = interpreter
}
}

private data class Key(
val id: CallableId,
val returnType: ClassId,
)

internal fun FirFunctionCall.interpreterName(session: FirSession): String? {
val symbol =
(calleeReference as? FirResolvedNamedReference)?.resolvedSymbol as? FirCallableSymbol ?: return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.jetbrains.kotlinx.dataframe.plugin.utils

import org.jetbrains.kotlin.builtins.StandardNames
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
Expand Down Expand Up @@ -53,7 +54,10 @@ object Names {
val INSTANT_CLASS_ID = kotlinx.datetime.Instant::class.classId()

val PAIR = ClassId(FqName("kotlin"), Name.identifier("Pair"))
val PAIR_CONSTRUCTOR = CallableId(FqName("kotlin"), FqName("Pair"), Name.identifier("Pair"))
val TO = CallableId(FqName("kotlin"), Name.identifier("to"))
val TRIM_MARGIN = CallableId(StandardNames.TEXT_PACKAGE_FQ_NAME, Name.identifier("trimMargin"))
val TRIM_INDENT = CallableId(StandardNames.TEXT_PACKAGE_FQ_NAME, Name.identifier("trimIndent"))
}

private fun KClass<*>.classId(): ClassId {
Expand Down
2 changes: 1 addition & 1 deletion plugins/kotlin-dataframe/testData/box/dataFrameOf_to.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val df = dataFrameOf(
"a" to listOf(1, 2),
Pair("a", listOf(1, 2))
"b" to listOf("str1", "str2"),
)
val i: Int = df.a[0]
Expand Down
12 changes: 12 additions & 0 deletions plugins/kotlin-dataframe/testData/box/trimIndent.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val df = DataFrame.readJsonStr("""
{"a": 123}
""".trimIndent())
df.a
return "OK"
}
12 changes: 12 additions & 0 deletions plugins/kotlin-dataframe/testData/box/trimMargin.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val df = DataFrame.readJsonStr("""
|{"a": 123}
|""".trimMargin())
df.a
return "OK"
}
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,18 @@ public void testTransformReplaceFunctionCall() {
runTest("testData/box/transformReplaceFunctionCall.kt");
}

@Test
@TestMetadata("trimIndent.kt")
public void testTrimIndent() {
runTest("testData/box/trimIndent.kt");
}

@Test
@TestMetadata("trimMargin.kt")
public void testTrimMargin() {
runTest("testData/box/trimMargin.kt");
}

@Test
@TestMetadata("ungroup.kt")
public void testUngroup() {
Expand Down

0 comments on commit 61865d8

Please sign in to comment.