Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ import com.angrypodo.wisp.model.ClassRouteInfo
import com.angrypodo.wisp.model.ObjectRouteInfo
import com.angrypodo.wisp.model.ParameterInfo
import com.angrypodo.wisp.model.RouteInfo
import com.angrypodo.wisp.util.WispClassName.INVALID_PARAMETER_ERROR
import com.angrypodo.wisp.util.WispClassName.MISSING_PARAMETER_ERROR
import com.angrypodo.wisp.util.WispClassName.ROUTE_FACTORY
import com.angrypodo.wisp.util.WispClassName
import com.google.devtools.ksp.processing.KSPLogger
import com.squareup.kotlinpoet.ANY
import com.squareup.kotlinpoet.BOOLEAN
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.DOUBLE
import com.squareup.kotlinpoet.FLOAT
Expand All @@ -19,13 +18,23 @@ import com.squareup.kotlinpoet.INT
import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.LONG
import com.squareup.kotlinpoet.MAP
import com.squareup.kotlinpoet.MemberName
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.STRING
import com.squareup.kotlinpoet.TypeSpec

internal class RouteFactoryGenerator(
private val logger: KSPLogger
) {
private val jsonClass = ClassName("kotlinx.serialization.json", "Json")
private val jsonObjectClass = ClassName("kotlinx.serialization.json", "JsonObject")
private val jsonPrimitiveClass = ClassName("kotlinx.serialization.json", "JsonPrimitive")
private val jsonElementClass = ClassName("kotlinx.serialization.json", "JsonElement")
private val decodeFromJsonElement = MemberName(
"kotlinx.serialization.json",
"decodeFromJsonElement"
)

fun generate(routeInfo: RouteInfo): FileSpec {
val createFun = FunSpec.builder("create")
.addModifiers(KModifier.OVERRIDE)
Expand All @@ -36,7 +45,7 @@ internal class RouteFactoryGenerator(

val factoryObject = TypeSpec.objectBuilder(routeInfo.factoryClassName)
.addModifiers(KModifier.INTERNAL)
.addSuperinterface(ROUTE_FACTORY)
.addSuperinterface(WispClassName.ROUTE_FACTORY)
.addFunction(createFun)
.build()

Expand All @@ -49,66 +58,54 @@ internal class RouteFactoryGenerator(
}

private fun buildCreateFunctionBody(routeInfo: RouteInfo): CodeBlock {
return when (routeInfo) {
is ObjectRouteInfo -> CodeBlock.of("return %T", routeInfo.routeClassName)
is ClassRouteInfo -> {
val block = CodeBlock.builder()
routeInfo.parameters.forEach { parameter ->
val conversion = buildConversionCode(parameter, routeInfo.wispPath)
block.addStatement("val %L = %L", parameter.name, conversion)
}
val constructorArgs = routeInfo.parameters.joinToString(", ") {
"${it.name} = ${it.name}"
}
block.addStatement("return %T(%L)", routeInfo.routeClassName, constructorArgs)
block.build()
}
if (routeInfo is ObjectRouteInfo) {
return CodeBlock.of("return %T", routeInfo.routeClassName)
}
}

private fun buildConversionCode(param: ParameterInfo, wispPath: String): CodeBlock {
val rawAccess = CodeBlock.of("params[%S]", param.name)
val conversionLogic = getConversionLogic(param, rawAccess)
val info = routeInfo as ClassRouteInfo
val block = CodeBlock.builder()

if (param.isNullable) {
return conversionLogic
}
// 1. Prepare JSON fields map
block.addStatement(
"val jsonFields = mutableMapOf<%T, %T>()",
STRING,
jsonElementClass
)

val nonNullableType = param.typeName.copy(nullable = false)
val errorType = when (nonNullableType) {
STRING -> MISSING_PARAMETER_ERROR
else -> INVALID_PARAMETER_ERROR
// 2. Iterate parameters and populate map
info.parameters.forEach { param ->
val conversion = getJsonConversion(param)
block.beginControlFlow("params[%S]?.let", param.name)
block.addStatement("jsonFields[%S] = %L", param.name, conversion)
block.endControlFlow()
}

return CodeBlock.of(
"(%L ?: throw %T(%S, %S))",
conversionLogic,
errorType,
wispPath,
param.name
// 3. Decode
block.addStatement("val jsonObject = %T(jsonFields)", jsonObjectClass)

// Use default Json instance and the extension function
block.addStatement(
"return %T.Default.%M<%T>(jsonObject)",
jsonClass,
decodeFromJsonElement,
routeInfo.routeClassName
)

return block.build()
}

private fun getConversionLogic(param: ParameterInfo, rawAccess: CodeBlock): CodeBlock {
val nonNullableType = param.typeName.copy(nullable = false)
return when {
param.isEnum -> CodeBlock.of(
"runCatching { %T.valueOf(%L!!.uppercase()) }.getOrNull()",
nonNullableType,
rawAccess
)
nonNullableType == STRING -> rawAccess
nonNullableType == INT -> CodeBlock.of("%L?.toIntOrNull()", rawAccess)
nonNullableType == LONG -> CodeBlock.of("%L?.toLongOrNull()", rawAccess)
nonNullableType == BOOLEAN -> CodeBlock.of("%L?.toBooleanStrictOrNull()", rawAccess)
nonNullableType == FLOAT -> CodeBlock.of("%L?.toFloatOrNull()", rawAccess)
nonNullableType == DOUBLE -> CodeBlock.of("%L?.toDoubleOrNull()", rawAccess)
private fun getJsonConversion(param: ParameterInfo): CodeBlock {
val type = param.typeName.copy(nullable = false)
return when (type) {
STRING -> CodeBlock.of("%T(it)", jsonPrimitiveClass)
INT -> CodeBlock.of("%T(it.toInt())", jsonPrimitiveClass)
LONG -> CodeBlock.of("%T(it.toLong())", jsonPrimitiveClass)
BOOLEAN -> CodeBlock.of("%T(it.toBoolean())", jsonPrimitiveClass)
FLOAT -> CodeBlock.of("%T(it.toFloat())", jsonPrimitiveClass)
DOUBLE -> CodeBlock.of("%T(it.toDouble())", jsonPrimitiveClass)
else -> {
logger.error(
"Wisp Error: Unsupported type '${param.typeName}'" +
" for parameter '${param.name}'."
)
CodeBlock.of("null")
// Fallback for Enum or others
CodeBlock.of("%T(it)", jsonPrimitiveClass)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ private fun KSClassDeclaration.extractParameters(): List<ParameterInfo> {
name = parameterName,
typeName = resolvedType.toTypeName(),
isNullable = resolvedType.isMarkedNullable,
isEnum = isEnum
isEnum = isEnum,
hasDefault = parameter.hasDefault
)
} ?: emptyList()
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@ internal data class ParameterInfo(
val name: String,
val typeName: TypeName,
val isNullable: Boolean,
val isEnum: Boolean
val isEnum: Boolean,
val hasDefault: Boolean
)
Loading
Loading