Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AN-146] Emit VM cost for GCP Batch #7582

Merged
merged 14 commits into from
Dec 9, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
case event if event.name == CallMetadataKeys.VmEndTime => event.offsetDateTime
}

/** Returns the VM info carried on the given poll status, or `None` when the status has none. */
override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] =
pollStatus.instantiatedVmInfo

Check warning on line 51 in supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala

View check run for this annotation

Codecov / codecov/patch

supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala#L51

Added line #L51 was not covered by tests

override def handleVmCostLookup(vmInfo: InstantiatedVmInfo) = {
val request = GcpCostLookupRequest(vmInfo, self)
params.serviceRegistry ! request
Expand All @@ -69,6 +72,7 @@
}

override def receive: Receive = {
case costResponse: GcpCostLookupResponse => handleCostResponse(costResponse)

Check warning on line 75 in supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala

View check run for this annotation

Codecov / codecov/patch

supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala#L75

Added line #L75 was not covered by tests
case message: PollResultMessage =>
message match {
case ProcessThisPollResult(pollResult: RunStatus) => processPollResult(pollResult)
Expand All @@ -93,5 +97,4 @@

/** Supplies the poll monitor parameters by returning the `pollMonitorParameters` value. */
override def params: PollMonitorParameters = pollMonitorParameters

override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] = Option.empty // TODO
}
Original file line number Diff line number Diff line change
Expand Up @@ -1058,7 +1058,7 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
Future.fromTry {
Try {
runStatus match {
case RunStatus.Aborted(_) => AbortedExecutionHandle
case RunStatus.Aborted(_, _) => AbortedExecutionHandle
case failedStatus: RunStatus.UnsuccessfulRunStatus => handleFailedRunStatus(failedStatus)
case unknown =>
throw new RuntimeException(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
package cromwell.backend.google.batch.models

import cromwell.core.ExecutionEvent
import cromwell.services.cost.InstantiatedVmInfo

/**
 * Base type for the run statuses defined in the `RunStatus` companion object.
 *
 * Sealed, so all direct subtypes are declared in this file.
 */
sealed trait RunStatus {
// Execution events attached to this status.
def eventList: Seq[ExecutionEvent]
// Each concrete status visible in this file overrides this with its own name (e.g. "Running").
def toString: String

// Info about the instantiated VM, when available; the concrete statuses visible in this file
// default it to `Option.empty`.
val instantiatedVmInfo: Option[InstantiatedVmInfo]
}

object RunStatus {

case class Initializing(eventList: Seq[ExecutionEvent]) extends RunStatus { override def toString = "Initializing" }
case class AwaitingCloudQuota(eventList: Seq[ExecutionEvent]) extends RunStatus {
case class Initializing(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty)
extends RunStatus { override def toString = "Initializing" }
case class AwaitingCloudQuota(eventList: Seq[ExecutionEvent],
instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty
) extends RunStatus {
override def toString = "AwaitingCloudQuota"
}

case class Running(eventList: Seq[ExecutionEvent]) extends RunStatus { override def toString = "Running" }
case class Running(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty)
extends RunStatus { override def toString = "Running" }

sealed trait TerminalRunStatus extends RunStatus

case class Success(eventList: Seq[ExecutionEvent]) extends TerminalRunStatus {
case class Success(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty)
extends TerminalRunStatus {
override def toString = "Success"
}

Expand All @@ -29,7 +37,8 @@ object RunStatus {

final case class Failed(
exitCode: Option[GcpBatchExitCode],
eventList: Seq[ExecutionEvent]
eventList: Seq[ExecutionEvent],
instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty
) extends UnsuccessfulRunStatus {
override def toString = "Failed"

Expand Down Expand Up @@ -58,7 +67,9 @@ object RunStatus {
}
}

final case class Aborted(eventList: Seq[ExecutionEvent]) extends UnsuccessfulRunStatus {
final case class Aborted(eventList: Seq[ExecutionEvent],
instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty
) extends UnsuccessfulRunStatus {
override def toString = "Aborted"

override val exitCode: Option[GcpBatchExitCode] = None
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package cromwell.backend.google.batch.actors

import akka.actor.{ActorRef, ActorSystem, Props}
import akka.testkit.{TestKit, TestProbe}
import cats.data.Validated.Valid
import common.mock.MockSugar
import cromwell.backend.google.batch.models.GcpBatchRuntimeAttributes
import cromwell.backend.{BackendJobDescriptor, BackendJobDescriptorKey, RuntimeAttributeDefinition}
import cromwell.core.callcaching.NoDocker
import cromwell.core.{ExecutionEvent, WorkflowOptions}
import cromwell.core.logging.JobLogger
import cromwell.services.cost.{GcpCostLookupRequest, GcpCostLookupResponse, InstantiatedVmInfo}
import cromwell.services.keyvalue.InMemoryKvServiceActor
import org.scalatest.flatspec.AnyFlatSpecLike
import org.scalatest.matchers.should.Matchers
import cromwell.backend.google.batch.models.GcpBatchTestConfig._
import wom.graph.CommandCallNode
import cromwell.backend._
import cromwell.backend.google.batch.models._
import cromwell.backend.io.TestWorkflows
import cromwell.backend.standard.pollmonitoring.ProcessThisPollResult
import cromwell.services.metadata.CallMetadataKeys
import cromwell.services.metadata.MetadataService.PutMetadataAction
import org.slf4j.helpers.NOPLogger
import wom.values.WomString

import java.time.{Instant, OffsetDateTime}
import java.time.temporal.ChronoUnit
import scala.concurrent.duration.DurationInt

/**
 * Exercises [[BatchPollResultMonitorActor]] with a [[TestProbe]] standing in for the service registry.
 *
 * All test cases share a single actor instance and a single probe, so each case consumes the next
 * message the probe received, in declaration order.
 */
class BatchPollResultMonitorActorSpec
    extends TestKit(ActorSystem("BatchPollResultMonitorActorSpec"))
    with AnyFlatSpecLike
    with BackendSpec
    with Matchers
    with MockSugar
    with org.scalatest.BeforeAndAfterAll {

  // Not referenced by the test cases below; kept because creating the actor is a side effect of
  // constructing the spec. Never reassigned, hence a `val`.
  val kvService: ActorRef = system.actorOf(Props(new InMemoryKvServiceActor), "kvService")
  val runtimeAttributesBuilder = GcpBatchRuntimeAttributes.runtimeAttributesBuilder(gcpBatchConfiguration)
  val jobLogger = mock[JobLogger]
  // Receives everything the actor under test sends to its service registry.
  val serviceRegistry = TestProbe()

  val workflowDescriptor = buildWdlWorkflowDescriptor(TestWorkflows.HelloWorld)
  val call: CommandCallNode = workflowDescriptor.callable.taskCallNodes.head
  val jobKey = BackendJobDescriptorKey(call, None, 1)

  val jobDescriptor = BackendJobDescriptor(workflowDescriptor,
                                           jobKey,
                                           runtimeAttributes = Map.empty,
                                           evaluatedTaskInputs = Map.empty,
                                           NoDocker,
                                           None,
                                           Map.empty
  )

  val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"))

  val staticRuntimeAttributeDefinitions: Set[RuntimeAttributeDefinition] =
    GcpBatchRuntimeAttributes.runtimeAttributesBuilder(GcpBatchTestConfig.gcpBatchConfiguration).definitions.toSet

  val defaultedAttributes =
    RuntimeAttributeDefinition.addDefaultsToAttributes(staticRuntimeAttributeDefinitions,
                                                       WorkflowOptions.fromMap(Map.empty).get
    )(
      runtimeAttributes
    )
  val validatedRuntimeAttributes = runtimeAttributesBuilder.build(defaultedAttributes, NOPLogger.NOP_LOGGER)

  // The actor under test, wired to the probe instead of a real service registry.
  val actor = system.actorOf(
    BatchPollResultMonitorActor.props(serviceRegistry.ref,
                                      workflowDescriptor,
                                      jobDescriptor,
                                      validatedRuntimeAttributes,
                                      Some(Gcp),
                                      jobLogger
    )
  )
  val vmInfo = InstantiatedVmInfo("europe-west9", "custom-16-32768", false)

  /** Stops the actor system created for this spec so it does not outlive the suite. */
  override protected def afterAll(): Unit = {
    TestKit.shutdownActorSystem(system)
    super.afterAll()
  }

  behavior of "BatchPollResultMonitorActor"

  it should "send a cost lookup request with the correct vm info after receiving a success pollResult" in {

    val terminalPollResult =
      RunStatus.Success(Seq(ExecutionEvent("fakeEvent", OffsetDateTime.now().truncatedTo(ChronoUnit.MILLIS))),
                        Some(vmInfo)
      )
    val message = ProcessThisPollResult(terminalPollResult)

    actor ! message

    serviceRegistry.expectMsgPF(1.seconds) { case m: GcpCostLookupRequest =>
      m.vmInfo shouldBe vmInfo
    }
  }

  it should "emit the correct cost metadata after receiving a costLookupResponse" in {

    val costLookupResponse = GcpCostLookupResponse(vmInfo, Valid(BigDecimal(0.1)))

    actor ! costLookupResponse

    serviceRegistry.expectMsgPF(1.seconds) { case m: PutMetadataAction =>
      // Check the size first so an empty event list fails as an assertion rather than on `.head`.
      m.events.size shouldBe 1
      val event = m.events.head
      event.key.key shouldBe CallMetadataKeys.VmCostPerHour
      event.value.map(_.value) shouldBe Some("0.1")
    }
  }

  it should "emit the correct start time after receiving a running pollResult" in {

    val vmStartTime = OffsetDateTime.now().minus(2, ChronoUnit.HOURS)
    val pollResult = RunStatus.Running(
      Seq(ExecutionEvent(CallMetadataKeys.VmStartTime, vmStartTime)),
      Some(vmInfo)
    )
    val message = ProcessThisPollResult(pollResult)

    actor ! message

    serviceRegistry.expectMsgPF(1.seconds) { case m: PutMetadataAction =>
      m.events.size shouldBe 1
      val event = m.events.head
      event.key.key shouldBe CallMetadataKeys.VmStartTime
      // The expected instant is truncated to milliseconds before comparing with the emitted value.
      event.value.map(v => Instant.parse(v.value)) shouldBe
        Some(vmStartTime.toInstant.truncatedTo(ChronoUnit.MILLIS))
    }
  }

  it should "emit the correct end time after receiving a running pollResult" in {

    val vmEndTime = OffsetDateTime.now().minus(2, ChronoUnit.HOURS)
    val pollResult = RunStatus.Running(
      Seq(ExecutionEvent(CallMetadataKeys.VmEndTime, vmEndTime)),
      Some(vmInfo)
    )
    val message = ProcessThisPollResult(pollResult)

    actor ! message

    serviceRegistry.expectMsgPF(1.seconds) { case m: PutMetadataAction =>
      m.events.size shouldBe 1
      val event = m.events.head
      event.key.key shouldBe CallMetadataKeys.VmEndTime
      // The expected instant is truncated to milliseconds before comparing with the emitted value.
      event.value.map(v => Instant.parse(v.value)) shouldBe
        Some(vmEndTime.toInstant.truncatedTo(ChronoUnit.MILLIS))
    }
  }
}
Loading