Skip to content

Commit

Permalink
codegen: better enum handling in query params (#4385)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson authored Mar 3, 2025
1 parent 5a0e6b7 commit e5c36c7
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ class EndpointGenerator {
val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None
val queryOrPathParamRefs = m.resolvedParameters
.collect { case queryParam: OpenapiParameter if queryParam.in == "query" || queryParam.in == "path" => queryParam.schema }
.collect { case ref: OpenapiSchemaRef if ref.isSchema => ref.stripped }
.collect {
case ref: OpenapiSchemaRef if ref.isSchema => ref.stripped
case OpenapiSchemaArray(ref: OpenapiSchemaRef, _) if ref.isSchema => ref.stripped
}
.toSet
val jsonParamRefs = (m.requestBody.toSeq.flatMap(_.content.map(c => (c.contentType, c.schema))) ++
m.responses.flatMap(_.content.map(c => (c.contentType, c.schema))))
Expand Down Expand Up @@ -264,6 +267,12 @@ class EndpointGenerator {
streamingImplementation: StreamingImplementation,
doc: OpenapiDocument
)(implicit location: Location): (String, Option[String], Seq[String]) = {
def toOutType(baseType: String, isArray: Boolean, noOptionWrapper: Boolean) = (isArray, noOptionWrapper) match {
case (true, true) => s"List[$baseType]"
case (true, false) => s"Option[List[$baseType]]"
case (false, true) => baseType
case (false, false) => s"Option[$baseType]"
}
def getEnumParamDefn(param: OpenapiParameter, e: OpenapiSchemaEnum, isArray: Boolean) = {
val enumName = endpointName.capitalize + strippedToCamelCase(param.name).capitalize
val enumParamRefs = if (param.in == "query" || param.in == "path") Set(enumName) else Set.empty[String]
Expand All @@ -283,12 +292,7 @@ class EndpointGenerator {
// 'exploded' params have no distinction between an empty list and an absent value, so don't wrap in 'Option' for them
val noOptionWrapper = required || (isArray && param.isExploded)
val req = if (noOptionWrapper) tpe else s"Option[$tpe]"
val outType = (isArray, noOptionWrapper) match {
case (true, true) => s"List[$enumName]"
case (true, false) => s"Option[List[$enumName]]"
case (false, true) => enumName
case (false, false) => s"Option[$enumName]"
}
val outType = toOutType(enumName, isArray, noOptionWrapper)

def mapToList =
if (!isArray) "" else if (noOptionWrapper) s".map(_.values)($arrayType(_))" else s".map(_.map(_.values))(_.map($arrayType(_)))"
Expand Down Expand Up @@ -320,7 +324,8 @@ class EndpointGenerator {
def mapToList = if (noOptionWrapper) s".map(_.values)($arrayType(_))" else s".map(_.map(_.values))(_.map($arrayType(_)))"

val desc = param.description.map(d => JavaEscape.escapeString(d)).fold("")(d => s""".description("$d")""")
(s""".in(${param.in}[$req]("${param.name}")$mapToList$desc)""", None, req)
val outType = toOutType(t, true, noOptionWrapper)
(s""".in(${param.in}[$req]("${param.name}")$mapToList$desc)""", None, outType)
case e @ OpenapiSchemaEnum(_, _, _) => getEnumParamDefn(param, e, isArray = false)
case OpenapiSchemaArray(e: OpenapiSchemaEnum, _) => getEnumParamDefn(param, e, isArray = true)
case x => bail(s"Can't create non-simple params to input - found $x")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,24 @@ object TapirGeneratedEndpoints {
support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
}


case class EnumExtraParamSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]) extends ExtraParamSupport[T] {
// Case-insensitive mapping
def decode(s: String): sttp.tapir.DecodeResult[T] =
scala.util.Try(T.upperCaseNameValuesToMap(s.toUpperCase))
.fold(
_ =>
sttp.tapir.DecodeResult.Error(
s,
new NoSuchElementException(
s"Could not find value $s for enum ${enumName}, available values: ${T.values.mkString(", ")}"
)
),
sttp.tapir.DecodeResult.Value(_)
)
def encode(t: T): String = t.entryName
}
def extraCodecSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): ExtraParamSupport[T] =
EnumExtraParamSupport(enumName, T)
sealed trait ADTWithoutDiscriminator
sealed trait ADTWithDiscriminator
sealed trait ADTWithDiscriminatorNoMapping
Expand Down Expand Up @@ -81,36 +98,42 @@ object TapirGeneratedEndpoints {
case object Foo extends AnEnum
case object Bar extends AnEnum
case object Baz extends AnEnum
implicit val enumCodecSupportAnEnum: ExtraParamSupport[AnEnum] =
extraCodecSupport[AnEnum]("AnEnum", AnEnum)
}



lazy val putAdtTest =
type PutAdtTestEndpoint = Endpoint[Unit, ADTWithoutDiscriminator, Unit, ADTWithoutDiscriminator, Any]
lazy val putAdtTest: PutAdtTestEndpoint =
endpoint
.put
.in(("adt" / "test"))
.in(jsonBody[ADTWithoutDiscriminator])
.out(jsonBody[ADTWithoutDiscriminator].description("successful operation"))

lazy val postAdtTest =
type PostAdtTestEndpoint = Endpoint[Unit, ADTWithDiscriminatorNoMapping, Unit, ADTWithDiscriminator, Any]
lazy val postAdtTest: PostAdtTestEndpoint =
endpoint
.post
.in(("adt" / "test"))
.in(jsonBody[ADTWithDiscriminatorNoMapping])
.out(jsonBody[ADTWithDiscriminator].description("successful operation"))

lazy val getOneofOptionTest =
type GetOneofOptionTestEndpoint = Endpoint[Unit, Unit, Unit, Option[AnEnum], Any]
lazy val getOneofOptionTest: GetOneofOptionTestEndpoint =
endpoint
.get
.in(("oneof" / "option" / "test"))
.out(oneOf[Option[AnEnum]](
oneOfVariantSingletonMatcher(sttp.model.StatusCode(204), emptyOutput.description("No response"))(None),
oneOfVariantValueMatcher(sttp.model.StatusCode(200), jsonBody[Option[AnEnum]].description("An enum")){ case Some(_: AnEnum) => true }))

lazy val postGenericJson =
type PostGenericJsonEndpoint = Endpoint[Unit, (Option[List[AnEnum]], Option[io.circe.Json]), Unit, io.circe.Json, Any]
lazy val postGenericJson: PostGenericJsonEndpoint =
endpoint
.post
.in(("generic" / "json"))
.in(query[Option[CommaSeparatedValues[AnEnum]]]("aTrickyParam").map(_.map(_.values))(_.map(CommaSeparatedValues(_))).description("A very thorough description"))
.in(jsonBody[Option[io.circe.Json]])
.out(jsonBody[io.circe.Json].description("anything back"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ lazy val root = (project in file("."))
.settings(
scalaVersion := "2.13.16",
version := "0.1",
openapiJsonSerdeLib := "jsoniter"
openapiJsonSerdeLib := "jsoniter",
openapiGenerateEndpointTypes := true
)

libraryDependencies ++= Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,28 @@ paths:
$ref: '#/components/schemas/AnEnum'
'/generic/json':
post:
parameters:
- in: query
name: aTrickyParam
style: form
explode: false
required: false
description: A very thorough description
schema:
type: array
items:
$ref: '#/components/schemas/AnEnum'
requestBody:
description: anything
content:
application/json:
schema: {}
schema: { }
responses:
"200":
description: anything back
content:
application/json:
schema: {}
schema: { }

components:
schemas:
Expand Down

0 comments on commit e5c36c7

Please sign in to comment.