Skip to content

Commit

Permalink
bulk-cdk: improve AirbyteConnectorRunner and CliRunner
Browse files Browse the repository at this point in the history
  • Loading branch information
postamar committed Sep 10, 2024
1 parent e00eaf6 commit dd954be
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import io.airbyte.cdk.command.ConnectorCommandLinePropertySource
import io.airbyte.cdk.command.MetadataYamlPropertySource
import io.micronaut.configuration.picocli.MicronautFactory
import io.micronaut.context.ApplicationContext
import io.micronaut.context.RuntimeBeanDefinition
import io.micronaut.context.env.CommandLinePropertySource
import io.micronaut.context.env.Environment
import io.micronaut.core.cli.CommandLine as MicronautCommandLine
Expand All @@ -17,8 +18,11 @@ import picocli.CommandLine.Model.UsageMessageSpec

/** Source connector entry point. */
class AirbyteSourceRunner(
/** CLI args. */
args: Array<out String>,
) : AirbyteConnectorRunner("source", args) {
/** Micronaut bean definition overrides, used only for tests. */
vararg testBeanDefinitions: RuntimeBeanDefinition<*>,
) : AirbyteConnectorRunner("source", args, testBeanDefinitions) {
companion object {
@JvmStatic
fun run(vararg args: String) {
Expand All @@ -29,8 +33,11 @@ class AirbyteSourceRunner(

/** Destination connector entry point. */
class AirbyteDestinationRunner(
/** CLI args. */
args: Array<out String>,
) : AirbyteConnectorRunner("destination", args) {
/** Micronaut bean definition overrides, used only for tests. */
vararg testBeanDefinitions: RuntimeBeanDefinition<*>,
) : AirbyteConnectorRunner("destination", args, testBeanDefinitions) {
companion object {
@JvmStatic
fun run(vararg args: String) {
Expand All @@ -46,6 +53,7 @@ class AirbyteDestinationRunner(
sealed class AirbyteConnectorRunner(
val connectorType: String,
val args: Array<out String>,
val testBeanDefinitions: Array<out RuntimeBeanDefinition<*>>,
) {
val envs: Array<String> = arrayOf(Environment.CLI, connectorType)

Expand All @@ -65,11 +73,12 @@ sealed class AirbyteConnectorRunner(
commandLinePropertySource,
MetadataYamlPropertySource(),
)
.beanDefinitions(*testBeanDefinitions)
.start()
val isTest: Boolean = ctx.environment.activeNames.contains(Environment.TEST)
val picocliFactory: CommandLine.IFactory = MicronautFactory(ctx)
val picocliCommandLine: CommandLine =
picocliCommandLineFactory.build<AirbyteConnectorRunnable>(picocliFactory, isTest)
picocliCommandLineFactory.build<AirbyteConnectorRunnable>(picocliFactory)
val exitCode: Int = picocliCommandLine.execute(*args)
if (!isTest) {
// Required by the platform, otherwise syncs may hang.
Expand All @@ -82,10 +91,7 @@ sealed class AirbyteConnectorRunner(
class PicocliCommandLineFactory(
val runner: AirbyteConnectorRunner,
) {
inline fun <reified R : Runnable> build(
factory: CommandLine.IFactory,
isTest: Boolean,
): CommandLine {
inline fun <reified R : Runnable> build(factory: CommandLine.IFactory): CommandLine {
val commandSpec: CommandLine.Model.CommandSpec =
CommandLine.Model.CommandSpec.wrapWithoutInspection(R::class.java, factory)
.name("airbyte-${runner.connectorType}-connector")
Expand All @@ -95,10 +101,6 @@ class PicocliCommandLineFactory(
.addOption(config)
.addOption(catalog)
.addOption(state)

if (isTest) {
commandSpec.addOption(output)
}
return CommandLine(commandSpec, factory)
}

Expand Down Expand Up @@ -168,10 +170,4 @@ class PicocliCommandLineFactory(
"path to the json-encoded state file",
"Required by the following commands: read",
)
val output: OptionSpec =
fileOption(
"output",
"path to the output file",
"When present, the connector writes to this file instead of stdout",
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class ConnectorCommandLinePropertySource(
const val CONNECTOR_CONFIG_PREFIX: String = "airbyte.connector.config"
const val CONNECTOR_CATALOG_PREFIX: String = "airbyte.connector.catalog"
const val CONNECTOR_STATE_PREFIX: String = "airbyte.connector.state"
const val CONNECTOR_OUTPUT_FILE = "airbyte.connector.output.file"

private fun resolveValues(
commandLine: CommandLine,
Expand All @@ -39,7 +38,6 @@ private fun resolveValues(
}
val values: MutableMap<String, Any> = mutableMapOf()
values[Operation.PROPERTY] = ops.first()
commandLine.optionValue("output")?.let { values[CONNECTOR_OUTPUT_FILE] = it }
for ((cliOptionKey, prefix) in
mapOf(
"config" to CONNECTOR_CONFIG_PREFIX,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ import io.micronaut.context.annotation.Value
import io.micronaut.context.env.Environment
import jakarta.inject.Singleton
import java.io.ByteArrayOutputStream
import java.io.FileOutputStream
import java.io.PrintStream
import java.nio.file.Path
import java.time.Clock
import java.time.Instant
import java.util.concurrent.ConcurrentHashMap
Expand Down Expand Up @@ -104,9 +102,6 @@ interface OutputConsumer : Consumer<AirbyteMessage>, AutoCloseable {
/** Configuration properties prefix for [StdoutOutputConsumer]. */
const val CONNECTOR_OUTPUT_PREFIX = "airbyte.connector.output"

// Used for integration tests.
const val CONNECTOR_OUTPUT_FILE = "$CONNECTOR_OUTPUT_PREFIX.file"

/** Default implementation of [OutputConsumer]. */
@Singleton
@Secondary
Expand Down Expand Up @@ -293,10 +288,4 @@ private class RecordTemplate(
private class PrintStreamFactory {

@Singleton @Requires(notEnv = [Environment.TEST]) fun stdout(): PrintStream = System.out

@Singleton
@Requires(env = [Environment.TEST])
@Requires(property = CONNECTOR_OUTPUT_FILE)
fun file(@Value("\${$CONNECTOR_OUTPUT_FILE}") filePath: Path): PrintStream =
PrintStream(FileOutputStream(filePath.toFile()), false, Charsets.UTF_8)
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,15 @@ import io.airbyte.cdk.util.Jsons
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.AirbyteStateMessage
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog
import io.micronaut.context.RuntimeBeanDefinition
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.InputStream
import java.io.OutputStream
import java.io.PrintStream
import java.nio.file.Files
import java.nio.file.Path
import java.time.Clock
import kotlin.io.path.deleteIfExists

data object CliRunner {
Expand All @@ -33,54 +40,73 @@ data object CliRunner {
config: ConfigurationJsonObjectBase? = null,
catalog: ConfiguredAirbyteCatalog? = null,
state: List<AirbyteStateMessage>? = null,
): BufferingOutputConsumer =
): BufferingOutputConsumer {
val out = TestOutputStream()
runConnector(op, config, catalog, state) { args: Array<String> ->
AirbyteSourceRunner(args)
AirbyteSourceRunner(args, out.beanDefinition)
}
return out.results
}

/** Same as [runSource] but for destinations. */
fun runDestination(
op: String,
config: ConfigurationJsonObjectBase? = null,
catalog: ConfiguredAirbyteCatalog? = null,
state: List<AirbyteStateMessage>? = null,
): BufferingOutputConsumer =
inputStream: InputStream,
): BufferingOutputConsumer {
val inputBeanDefinition: RuntimeBeanDefinition<InputStream> =
RuntimeBeanDefinition.builder(InputStream::class.java) { inputStream }
.singleton(true)
.build()
val out = TestOutputStream()
runConnector(op, config, catalog, state) { args: Array<String> ->
AirbyteDestinationRunner(args)
AirbyteDestinationRunner(args, inputBeanDefinition, out.beanDefinition)
}
return out.results
}

/** Same as the other [runDestination] but with simpler input of [AirbyteMessage] instances. */
fun runDestination(
op: String,
config: ConfigurationJsonObjectBase? = null,
catalog: ConfiguredAirbyteCatalog? = null,
state: List<AirbyteStateMessage>? = null,
vararg input: AirbyteMessage,
): BufferingOutputConsumer {
val baos = ByteArrayOutputStream()
for (msg in input) {
Jsons.writeValue(baos, msg)
baos.write('\n'.code)
}
val inputStream = ByteArrayInputStream(baos.toByteArray())
return runDestination(op, config, catalog, state, inputStream)
}

private fun runConnector(
op: String,
config: ConfigurationJsonObjectBase?,
catalog: ConfiguredAirbyteCatalog?,
state: List<AirbyteStateMessage>?,
connectorRunnerConstructor: (Array<String>) -> AirbyteConnectorRunner,
): BufferingOutputConsumer {
val result = BufferingOutputConsumer(ClockFactory().fixed())
) {
val configFile: Path? = inputFile(config)
val catalogFile: Path? = inputFile(catalog)
val stateFile: Path? = inputFile(state)
val outputFile: Path = Files.createTempFile(null, null)
val args: List<String> =
listOfNotNull(
"--$op",
configFile?.let { "--config=$it" },
catalogFile?.let { "--catalog=$it" },
stateFile?.let { "--state=$it" },
"--output=$outputFile",
)
try {
connectorRunnerConstructor(args.toTypedArray()).run<AirbyteConnectorRunnable>()
Files.readAllLines(outputFile)
.filter { it.isNotBlank() }
.map { Jsons.readValue(it, AirbyteMessage::class.java) }
.forEach { result.accept(it) }
return result
} finally {
configFile?.deleteIfExists()
catalogFile?.deleteIfExists()
stateFile?.deleteIfExists()
outputFile.deleteIfExists()
}
}

Expand All @@ -90,4 +116,41 @@ data object CliRunner {
Files.writeString(file, Jsons.writeValueAsString(contents))
}
}

private val clock: Clock = ClockFactory().fixed()

class TestOutputStream : OutputStream() {

val results = BufferingOutputConsumer(clock)
private val lineStream = ByteArrayOutputStream()
private val printStream = PrintStream(this, true, Charsets.UTF_8)

val beanDefinition: RuntimeBeanDefinition<PrintStream> =
RuntimeBeanDefinition.builder(PrintStream::class.java) { printStream }
.singleton(true)
.build()

override fun write(b: Int) {
if (b == '\n'.code) {
readLine()
} else {
lineStream.write(b)
}
}

override fun close() {
readLine()
lineStream.close()
results.close()
super.close()
}

private fun readLine() {
val line: String = lineStream.toString(Charsets.UTF_8).trim()
lineStream.reset()
if (line.isNotBlank()) {
results.accept(Jsons.readValue(line, AirbyteMessage::class.java))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import java.time.Instant
/** [OutputConsumer] implementation for unit tests. Collects everything into thread-safe buffers. */
@Singleton
@Requires(notEnv = [Environment.CLI])
@Requires(missingProperty = CONNECTOR_OUTPUT_FILE)
@Replaces(OutputConsumer::class)
class BufferingOutputConsumer(
clock: Clock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import io.airbyte.cdk.message.MessageQueueWriter
import io.github.oshai.kotlinlogging.KLogger
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Requires
import io.micronaut.context.env.Environment
import jakarta.inject.Singleton
import java.io.InputStream
import java.nio.charset.StandardCharsets
Expand Down Expand Up @@ -68,6 +70,7 @@ class DefaultInputConsumer(
@Factory
class InputStreamFactory {
@Singleton
@Requires(notEnv = [Environment.TEST])
fun make(): InputStream {
return System.`in`
}
Expand Down

0 comments on commit dd954be

Please sign in to comment.