From c044228de00625f05bd770b3857867af3dca4641 Mon Sep 17 00:00:00 2001 From: Tom Wiseman Date: Wed, 18 Sep 2024 13:26:29 -0400 Subject: [PATCH] the gist --- .../PollResultMonitorActor.scala | 22 +++++ core/src/main/resources/reference.conf | 2 +- .../services/cost/GcpCostCatalogService.scala | 83 +++++++++++++++++-- .../services/cost/GcpCostCatalogTypes.scala | 50 +++++++++++ .../cost/GcpCostCatalogServiceSpec.scala | 8 +- .../actors/BatchPollResultMonitorActor.scala | 11 +-- .../common/PapiPollResultMonitorActor.scala | 18 ++-- .../pipelines/common/api/RunStatus.scala | 32 ++++--- .../pipelines/common/CostLookupSpec.scala | 62 ++++++++++++++ .../v2beta/api/request/ErrorReporter.scala | 2 +- .../api/request/GetRequestHandler.scala | 82 +++++++++--------- 11 files changed, 294 insertions(+), 78 deletions(-) create mode 100644 supportedBackends/google/pipelines/common/src/test/scala/cromwell/backend/google/pipelines/common/CostLookupSpec.scala diff --git a/backend/src/main/scala/cromwell/backend/standard/pollmonitoring/PollResultMonitorActor.scala b/backend/src/main/scala/cromwell/backend/standard/pollmonitoring/PollResultMonitorActor.scala index d5b8110d5dc..43d7132f999 100644 --- a/backend/src/main/scala/cromwell/backend/standard/pollmonitoring/PollResultMonitorActor.scala +++ b/backend/src/main/scala/cromwell/backend/standard/pollmonitoring/PollResultMonitorActor.scala @@ -9,6 +9,7 @@ import cromwell.backend.validation.{ ValidatedRuntimeAttributes } import cromwell.core.logging.JobLogger +import cromwell.services.cost.{GcpCostLookupRequest, GcpCostLookupResponse, InstantiatedVmInfo} import cromwell.services.metadata.CallMetadataKeys import cromwell.services.metrics.bard.BardEventing.BardEventRequest import cromwell.services.metrics.bard.model.TaskSummaryEvent @@ -42,6 +43,9 @@ trait PollResultMonitorActor[PollResultType] extends Actor { // Time that the user VM started spending money. 
def extractStartTimeFromRunState(pollStatus: PollResultType): Option[OffsetDateTime] + // Used to kick off a cost calculation + def extractVmInfoFromRunState(pollStatus: PollResultType): Option[InstantiatedVmInfo] + // Time that the user VM stopped spending money. def extractEndTimeFromRunState(pollStatus: PollResultType): Option[OffsetDateTime] @@ -99,6 +103,7 @@ trait PollResultMonitorActor[PollResultType] extends Actor { Option.empty private var vmStartTime: Option[OffsetDateTime] = Option.empty private var vmEndTime: Option[OffsetDateTime] = Option.empty + private var vmCostPerHour: Option[BigDecimal] = Option.empty def processPollResult(pollStatus: PollResultType): Unit = { // Make sure jobStartTime remains the earliest event time ever seen @@ -122,6 +127,16 @@ trait PollResultMonitorActor[PollResultType] extends Actor { tellMetadata(Map(CallMetadataKeys.VmEndTime -> end)) } } + // If we don't yet have a cost per hour and we can extract VM info, send a cost request to the catalog service. + // We expect it to reply with an answer, which is handled in receive. + // NB: Due to the nature of async code, we may send a few cost requests before we get a response back. + if (vmCostPerHour.isEmpty) { + val instantiatedVmInfo = extractVmInfoFromRunState(pollStatus) + instantiatedVmInfo.foreach { vmInfo => + val request = GcpCostLookupRequest(vmInfo, self) + params.serviceRegistry ! request + } + } } // When a job finishes, the bard actor needs to know about the timing in order to record metrics. @@ -135,4 +150,11 @@ trait PollResultMonitorActor[PollResultType] extends Actor { vmEndTime = vmEndTime.getOrElse(OffsetDateTime.now()) ) ) + + def handleCostResponse(costLookupResponse: GcpCostLookupResponse): Unit = { + if (vmCostPerHour.isDefined) { return } // Optimization to avoid processing responses after we've stopped caring. + val cost = costLookupResponse.calculatedCost.getOrElse(BigDecimal(-1)) // TODO: better logging here. 
+ vmCostPerHour = Option(cost) + tellMetadata(Map(CallMetadataKeys.VmCostPerHour -> vmCostPerHour)) + } } diff --git a/core/src/main/resources/reference.conf b/core/src/main/resources/reference.conf index 465f43ad7a1..7b776c87e4e 100644 --- a/core/src/main/resources/reference.conf +++ b/core/src/main/resources/reference.conf @@ -607,7 +607,7 @@ services { } } - CostCatalogService { + GcpCostCatalogService { class = "cromwell.services.cost.GcpCostCatalogService" config { catalogExpirySeconds = 86400 diff --git a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala index 35012b7e424..7878c146b97 100644 --- a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala +++ b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala @@ -1,22 +1,30 @@ package cromwell.services.cost import akka.actor.{Actor, ActorRef} +import com.google.`type`.Money import com.google.cloud.billing.v1._ import com.typesafe.config.Config import com.typesafe.scalalogging.LazyLogging import common.util.StringUtil.EnhancedToStringable +import cromwell.services.ServiceRegistryActor.ServiceRegistryMessage import cromwell.services.cost.GcpCostCatalogService.{COMPUTE_ENGINE_SERVICE_NAME, DEFAULT_CURRENCY_CODE} import cromwell.util.GracefulShutdownHelper.ShutdownCommand import java.time.{Duration, Instant} import scala.jdk.CollectionConverters.IterableHasAsScala import java.time.temporal.ChronoUnit.SECONDS +import scala.util.{Failure, Success, Try} case class CostCatalogKey(machineType: Option[MachineType], usageType: Option[UsageType], machineCustomization: Option[MachineCustomization], - resourceGroup: Option[ResourceGroup] + resourceGroup: Option[ResourceGroup], + region: String ) +case class GcpCostLookupRequest(vmInfo: InstantiatedVmInfo, replyTo: ActorRef) extends ServiceRegistryMessage { + override def serviceName: String = 
"GcpCostCatalogService" // must equal the services key in reference.conf; getClass.getSimpleName on an object yields "GcpCostCatalogService$" +} +case class GcpCostLookupResponse(calculatedCost: Option[BigDecimal]) case class CostCatalogValue(catalogObject: Sku) case class ExpiringGcpCostCatalog(catalog: Map[CostCatalogKey, CostCatalogValue], fetchTime: Instant) @@ -88,21 +96,78 @@ class GcpCostCatalogService(serviceConfig: Config, globalConfig: Config, service * Ideally, we don't want to have an entire, unprocessed, cost catalog in memory at once since it's ~20MB. */ private def processCostCatalog(skus: Iterable[Sku]): Map[CostCatalogKey, CostCatalogValue] = - // TODO: Account for key collisions (same key can be in multiple regions) // TODO: reduce memory footprint of returned map (don't store entire SKU object) skus.foldLeft(Map.empty[CostCatalogKey, CostCatalogValue]) { case (acc, sku) => - acc + convertSkuToKeyValuePair(sku) + acc ++ convertSkuToKeyValuePairs(sku) + } + + private def convertSkuToKeyValuePairs(sku: Sku): List[(CostCatalogKey, CostCatalogValue)] = { + val allAvailableRegions = sku.getServiceRegionsList.asScala.toList + allAvailableRegions.map(region => + CostCatalogKey( + machineType = MachineType.fromSku(sku), + usageType = UsageType.fromSku(sku), + machineCustomization = MachineCustomization.fromSku(sku), + resourceGroup = ResourceGroup.fromSku(sku), + region = region + ) -> CostCatalogValue(sku) + ) + } + + // See: https://cloud.google.com/billing/v1/how-tos/catalog-api + private def calculateCpuPricePerHour(cpuSku: Sku, coreCount: Int): Try[BigDecimal] = { + val pricingInfo = getMostRecentPricingInfo(cpuSku) + val usageUnit = pricingInfo.getPricingExpression.getUsageUnit + if (usageUnit != "h") { + return Failure(new UnsupportedOperationException(s"Expected usage units of CPUs to be 'h'. Got ${usageUnit}")) } + // Price per hour of a single core + // NB: Ignoring "TieredRates" here (the idea that stuff gets cheaper the more you use). + // Technically, we should write code that determines which tier(s) to use. 
+ // In practice, from what I've seen, CPU cores and RAM don't have more than a single tier. + val costPerUnit: Money = pricingInfo.getPricingExpression.getTieredRates(0).getUnitPrice + val costPerCorePerHour: BigDecimal = + BigDecimal(costPerUnit.getUnits) + BigDecimal(costPerUnit.getNanos.toLong, 9) // whole units plus nanos (10^-9 of a unit), summed exactly + Success(costPerCorePerHour * coreCount) + } + + private def calculateRamPricePerHour(ramSku: Sku, ramMbCount: Int): Try[BigDecimal] = + // TODO + Success(ramMbCount.toLong * 0.25) - private def convertSkuToKeyValuePair(sku: Sku): (CostCatalogKey, CostCatalogValue) = CostCatalogKey( - machineType = MachineType.fromSku(sku), - usageType = UsageType.fromSku(sku), - machineCustomization = MachineCustomization.fromSku(sku), - resourceGroup = ResourceGroup.fromSku(sku) - ) -> CostCatalogValue(sku) + private def getMostRecentPricingInfo(sku: Sku): PricingInfo = { + val mostRecentPricingInfoIndex = sku.getPricingInfoCount - 1 + sku.getPricingInfo(mostRecentPricingInfoIndex) + } + + private def calculateVmCostPerHour(instantiatedVmInfo: InstantiatedVmInfo): Try[BigDecimal] = { + val machineType = MachineType.fromGoogleMachineTypeString(instantiatedVmInfo.machineType) + val usageType = UsageType.fromBoolean(instantiatedVmInfo.preemptible) + val machineCustomization = MachineCustomization.fromMachineTypeString(instantiatedVmInfo.machineType) + val region = instantiatedVmInfo.region + val coreCount = MachineType.extractCoreCountFromMachineTypeString(instantiatedVmInfo.machineType) + val ramMbCount = MachineType.extractRamMbFromMachineTypeString(instantiatedVmInfo.machineType) + + val cpuResourceGroup = Cpu // TODO: Investigate the situation in which the resource group is n1 + val cpuKey = + CostCatalogKey(machineType, Option(usageType), Option(machineCustomization), Option(cpuResourceGroup), region) + val cpuSku = getSku(cpuKey) + val cpuCost = cpuSku.map(sku => calculateCpuPricePerHour(sku.catalogObject, coreCount.get)) // TODO .get + + val ramResourceGroup = 
Ram + val ramKey = + CostCatalogKey(machineType, Option(usageType), Option(machineCustomization), Option(ramResourceGroup), region) + val ramSku = getSku(ramKey) + val ramCost = ramSku.map(sku => calculateRamPricePerHour(sku.catalogObject, ramMbCount.get)) // TODO .get + Try(cpuCost.get.get + ramCost.get.get) // Try(...) turns a missing SKU or failed price calculation into a Failure instead of throwing + } def serviceRegistryActor: ActorRef = serviceRegistry override def receive: Receive = { + case GcpCostLookupRequest(vmInfo, replyTo) => + val calculatedCost = calculateVmCostPerHour(vmInfo).toOption + val response = GcpCostLookupResponse(calculatedCost) + replyTo ! response case ShutdownCommand => googleClient.foreach(client => client.shutdownNow()) context stop self diff --git a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala index 7507560c810..57fed6d888c 100644 --- a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala +++ b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala @@ -2,6 +2,13 @@ package cromwell.services.cost import com.google.cloud.billing.v1.Sku +import java.util.regex.{Matcher, Pattern} +import scala.util.{Failure, Success, Try} + +/* + * Case class that contains information retrieved from Google about a VM that cromwell has started + */ +case class InstantiatedVmInfo(region: String, machineType: String, preemptible: Boolean) /* * These types reflect hardcoded strings found in a google cost catalog. 
*/ @@ -13,6 +20,32 @@ object MachineType { else if (tokenizedDescription.contains(N2d.machineTypeName)) Some(N2d) else Option.empty } + + // expects a string that looks something like "n1-standard-1" or "custom-1-4096" + def fromGoogleMachineTypeString(machineTypeString: String): Option[MachineType] = + if (machineTypeString.startsWith("n1")) Some(N1) + else if (machineTypeString.startsWith("n2d")) Some(N2d) + else if (machineTypeString.startsWith("n2")) Some(N2) + else if (machineTypeString.startsWith("custom")) + None // TODO: should this be n1? Make a 'custom' type? Combine with MachineCustomization? + else { + println(s"Error: Unrecognized machine type: $machineTypeString") + None + } + + def extractCoreCountFromMachineTypeString(machineTypeString: String): Try[Int] = { + val pattern: Pattern = Pattern.compile("-(\\d+)") + val matcher: Matcher = pattern.matcher(machineTypeString) + if (matcher.find()) { + Success(matcher.group(1).toInt) + } else { + Failure(new IllegalArgumentException(s"Could not extract core count from ${machineTypeString}")) + } + } + + def extractRamMbFromMachineTypeString(machineTypeString: String): Try[Int] = + // TODO + Success(4096) } sealed trait MachineType { def machineTypeName: String } case object N1 extends MachineType { override val machineTypeName = "N1" } @@ -26,12 +59,23 @@ object UsageType { case Preemptible.typeName => Some(Preemptible) case _ => Option.empty } + def fromBoolean(isPreemptible: Boolean): UsageType = isPreemptible match { + case true => Preemptible + case false => OnDemand + } + } sealed trait UsageType { def typeName: String } case object OnDemand extends UsageType { override val typeName = "OnDemand" } case object Preemptible extends UsageType { override val typeName = "Preemptible" } object MachineCustomization { + // TODO: I think this is right but I am not 100% sure. Needs testing. + // Do custom machine types always have the word "custom"? + // Does Cromwell ever assign Predefined Machines? 
+ // Is it possible to have a custom machine that *doesn't* contain the word custom? + def fromMachineTypeString(machineTypeString: String): MachineCustomization = + if (machineTypeString.toLowerCase.contains("custom")) Custom else Predefined def fromSku(sku: Sku): Option[MachineCustomization] = { val tokenizedDescription = sku.getDescription.split(" ") if (tokenizedDescription.contains(Predefined.customizationName)) Some(Predefined) @@ -55,4 +99,10 @@ object ResourceGroup { sealed trait ResourceGroup { def groupName: String } case object Cpu extends ResourceGroup { override val groupName = "CPU" } case object Ram extends ResourceGroup { override val groupName = "RAM" } + +// TODO: What is the deal with this? It seems out of place. +// Need to figure out how to reconcile with the Cpu resource group. +// Current theory is that the n1 machines are legacy machines, +// and are therefore categorized differently. +// Unfortunately, N1 is Cromwell's default machine. case object N1Standard extends ResourceGroup { override val groupName = "N1Standard" } diff --git a/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala b/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala index 19d9e505ac6..2c5298d0cb3 100644 --- a/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala +++ b/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala @@ -75,7 +75,8 @@ class GcpCostCatalogServiceSpec machineType = Some(N2), usageType = Some(Preemptible), machineCustomization = Some(Predefined), - resourceGroup = Some(Cpu) + resourceGroup = Some(Cpu), + region = "europe-west9" ) val freshActor = constructTestActor @@ -103,9 +104,10 @@ class GcpCostCatalogServiceSpec machineType = Some(N2d), usageType = Some(Preemptible), machineCustomization = None, - resourceGroup = Some(Ram) + resourceGroup = Some(Ram), + region = "europe-west9" ) val foundValue = testActorRef.getSku(expectedKey) - 
foundValue.get.catalogObject.getDescription shouldBe "Spot Preemptible N2D AMD Instance Ram running in Johannesburg" + foundValue.get.catalogObject.getDescription shouldBe "Spot Preemptible N2D AMD Instance Ram running in Paris" } } diff --git a/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala b/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala index 8b05bf4057b..0f935b17c79 100644 --- a/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala +++ b/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala @@ -3,15 +3,10 @@ package cromwell.backend.google.batch.actors import akka.actor.{ActorRef, Props} import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform} import cromwell.backend.google.batch.models.RunStatus -import cromwell.backend.standard.pollmonitoring.{ - AsyncJobHasFinished, - PollMonitorParameters, - PollResultMessage, - PollResultMonitorActor, - ProcessThisPollResult -} +import cromwell.backend.standard.pollmonitoring.{AsyncJobHasFinished, PollMonitorParameters, PollResultMessage, PollResultMonitorActor, ProcessThisPollResult} import cromwell.backend.validation.ValidatedRuntimeAttributes import cromwell.core.logging.JobLogger +import cromwell.services.cost.InstantiatedVmInfo import cromwell.services.metadata.CallMetadataKeys import java.time.OffsetDateTime @@ -78,4 +73,6 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters) } override def params: PollMonitorParameters = pollMonitorParameters + + override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] = Option.empty //TODO } diff --git a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PapiPollResultMonitorActor.scala 
b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PapiPollResultMonitorActor.scala index 597c0ed8d35..c1bbef4edc2 100644 --- a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PapiPollResultMonitorActor.scala +++ b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PapiPollResultMonitorActor.scala @@ -1,17 +1,12 @@ package cromwell.backend.google.pipelines.common import akka.actor.{ActorRef, Props} -import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform} import cromwell.backend.google.pipelines.common.api.RunStatus -import cromwell.backend.standard.pollmonitoring.{ - AsyncJobHasFinished, - PollMonitorParameters, - PollResultMessage, - PollResultMonitorActor, - ProcessThisPollResult -} +import cromwell.backend.standard.pollmonitoring._ import cromwell.backend.validation.ValidatedRuntimeAttributes +import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform} import cromwell.core.logging.JobLogger +import cromwell.services.cost.{GcpCostLookupResponse, InstantiatedVmInfo} import cromwell.services.metadata.CallMetadataKeys import java.time.OffsetDateTime @@ -50,7 +45,13 @@ class PapiPollResultMonitorActor(parameters: PollMonitorParameters) extends Poll case event if event.name == CallMetadataKeys.VmEndTime => event.offsetDateTime } + override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] = + pollStatus.instantiatedVmInfo + + override def params: PollMonitorParameters = parameters + override def receive: Receive = { + case costResponse: GcpCostLookupResponse => handleCostResponse(costResponse) case message: PollResultMessage => message match { case ProcessThisPollResult(pollResult: RunStatus) => processPollResult(pollResult) @@ -75,5 +76,4 @@ class PapiPollResultMonitorActor(parameters: PollMonitorParameters) extends Poll ) ) } - override def 
params: PollMonitorParameters = parameters } diff --git a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/api/RunStatus.scala b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/api/RunStatus.scala index 03e49e5c1c1..d5be6707161 100644 --- a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/api/RunStatus.scala +++ b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/api/RunStatus.scala @@ -3,20 +3,26 @@ package cromwell.backend.google.pipelines.common.api import _root_.io.grpc.Status import cromwell.backend.google.pipelines.common.PipelinesApiAsyncBackendJobExecutionActor import cromwell.core.ExecutionEvent +import cromwell.services.cost.InstantiatedVmInfo import scala.util.Try - sealed trait RunStatus { def eventList: Seq[ExecutionEvent] def toString: String + + val instantiatedVmInfo: Option[InstantiatedVmInfo] } object RunStatus { - case class Initializing(eventList: Seq[ExecutionEvent]) extends RunStatus { override def toString = "Initializing" } - case class AwaitingCloudQuota(eventList: Seq[ExecutionEvent]) extends RunStatus { + case class Initializing(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty) + extends RunStatus { override def toString = "Initializing" } + case class AwaitingCloudQuota(eventList: Seq[ExecutionEvent], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty + ) extends RunStatus { override def toString = "AwaitingCloudQuota" } - case class Running(eventList: Seq[ExecutionEvent]) extends RunStatus { override def toString = "Running" } + case class Running(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty) + extends RunStatus { override def toString = "Running" } sealed trait TerminalRunStatus extends RunStatus { def machineType: Option[String] @@ -38,7 
+44,8 @@ object RunStatus { case class Success(eventList: Seq[ExecutionEvent], machineType: Option[String], zone: Option[String], - instanceName: Option[String] + instanceName: Option[String], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty ) extends TerminalRunStatus { override def toString = "Success" } @@ -88,7 +95,8 @@ object RunStatus { eventList, machineType, zone, - instanceName + instanceName, + Option.empty ) } } @@ -99,7 +107,8 @@ object RunStatus { eventList: Seq[ExecutionEvent], machineType: Option[String], zone: Option[String], - instanceName: Option[String] + instanceName: Option[String], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty ) extends UnsuccessfulRunStatus { override def toString = "Failed" } @@ -113,7 +122,8 @@ object RunStatus { eventList: Seq[ExecutionEvent], machineType: Option[String], zone: Option[String], - instanceName: Option[String] + instanceName: Option[String], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty ) extends UnsuccessfulRunStatus { override def toString = "Cancelled" } @@ -124,7 +134,8 @@ object RunStatus { eventList: Seq[ExecutionEvent], machineType: Option[String], zone: Option[String], - instanceName: Option[String] + instanceName: Option[String], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty ) extends UnsuccessfulRunStatus { override def toString = "Preempted" } @@ -139,7 +150,8 @@ object RunStatus { eventList: Seq[ExecutionEvent], machineType: Option[String], zone: Option[String], - instanceName: Option[String] + instanceName: Option[String], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty ) extends UnsuccessfulRunStatus { override def toString = "QuotaFailed" } diff --git a/supportedBackends/google/pipelines/common/src/test/scala/cromwell/backend/google/pipelines/common/CostLookupSpec.scala b/supportedBackends/google/pipelines/common/src/test/scala/cromwell/backend/google/pipelines/common/CostLookupSpec.scala new file mode 100644 
index 00000000000..de9490d86e7 --- /dev/null +++ b/supportedBackends/google/pipelines/common/src/test/scala/cromwell/backend/google/pipelines/common/CostLookupSpec.scala @@ -0,0 +1,62 @@ +package cromwell.backend.google.pipelines.common + +import akka.testkit.{ImplicitSender, TestActorRef, TestProbe} +import cromwell.core.TestKitSuite +import org.scalatest.flatspec.AnyFlatSpecLike +import org.scalatest.matchers.should.Matchers +import cromwell.services.cost._ +import org.scalatest.concurrent.Eventually + +class CostLookupSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with Eventually with ImplicitSender { + behavior of "CostLookup" + + def constructTestActor: GcpCostCatalogServiceTestActor = + TestActorRef( + new GcpCostCatalogServiceTestActor(GcpCostCatalogServiceSpec.config, + GcpCostCatalogServiceSpec.config, + TestProbe().ref + ) + ).underlyingActor + + val testCatalogService = constructTestActor + + it should "find a CPU sku" in { + val machineType = Some(N2) + val usageType = Some(OnDemand) + val customization = Some(Custom) + val resourceGroup = Some(Cpu) + val region = "europe-west9" + val key = CostCatalogKey(machineType, usageType, customization, resourceGroup, region) + val result = testCatalogService.getSku(key).get.catalogObject.getDescription + result shouldBe "N2 Custom Instance Core running in Paris" + } + + it should "find a RAM sku" in { + val machineType = Some(N2) + val usageType = Some(OnDemand) + val customization = Some(Custom) + val resourceGroup = Some(Ram) + val region = "europe-west9" + val key = CostCatalogKey(machineType, usageType, customization, resourceGroup, region) + val result = testCatalogService.getSku(key).get.catalogObject.getDescription + result shouldBe "N2 Custom Instance Ram running in Paris" + } + + it should "find CPU skus for all supported machine types" in { + val legalMachineTypes: List[MachineType] = List(N1, N2, N2d) + val legalUsageTypes: List[UsageType] = List(Preemptible, OnDemand) + val 
legalCustomizations: List[MachineCustomization] = List(Custom, Predefined) + val resourceGroup: Option[ResourceGroup] = Some(Cpu) + val region = "us-west1" + for (machineType <- legalMachineTypes) + for (usageType <- legalUsageTypes) + for (customization <- legalCustomizations) { + val key = CostCatalogKey(Some(machineType), Some(usageType), Some(customization), resourceGroup, region) + val result = testCatalogService.getSku(key) + if (!result.isEmpty) { + println("Success") + } + result.isEmpty shouldBe false + } + } +} diff --git a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/ErrorReporter.scala b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/ErrorReporter.scala index 77e53176df3..1c9be388c7d 100644 --- a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/ErrorReporter.scala +++ b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/ErrorReporter.scala @@ -86,7 +86,7 @@ class ErrorReporter(machineType: Option[String], // Reverse the list because the first failure (likely the most relevant, will appear last otherwise) val unexpectedExitEvents: List[String] = unexpectedExitStatusErrorStrings(events, actions).reverse - builder(status, None, failed.toList ++ unexpectedExitEvents, executionEvents, machineType, zone, instanceName) + builder(status, None, failed.toList ++ unexpectedExitEvents, executionEvents, machineType, zone, instanceName, Option.empty) } // There's maybe one FailedEvent per operation with a summary error message diff --git a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala index 
9cfbf62ca4f..d9f622b8af5 100644 --- a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala +++ b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala @@ -9,19 +9,14 @@ import common.validation.Validation._ import cromwell.backend.google.pipelines.common.action.ActionLabels._ import cromwell.backend.google.pipelines.common.api.PipelinesApiRequestManager._ import cromwell.backend.google.pipelines.common.api.RunStatus -import cromwell.backend.google.pipelines.common.api.RunStatus.{ - AwaitingCloudQuota, - Initializing, - Running, - Success, - UnsuccessfulRunStatus -} +import cromwell.backend.google.pipelines.common.api.RunStatus.{AwaitingCloudQuota, Initializing, Running, Success, UnsuccessfulRunStatus} import cromwell.backend.google.pipelines.common.errors.isQuotaMessage import cromwell.backend.google.pipelines.v2beta.PipelinesConversions._ import cromwell.backend.google.pipelines.v2beta.api.Deserialization._ import cromwell.backend.google.pipelines.v2beta.api.request.ErrorReporter._ import cromwell.cloudsupport.gcp.auth.GoogleAuthMode import cromwell.core.ExecutionEvent +import cromwell.services.cost.InstantiatedVmInfo import cromwell.services.metadata.CallMetadataKeys import io.grpc.Status import org.apache.commons.lang3.exception.ExceptionUtils @@ -29,7 +24,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils import scala.jdk.CollectionConverters._ import scala.concurrent.{ExecutionContext, Future} import scala.language.postfixOps -import scala.util.{Failure, Success => TrySuccess, Try} +import scala.util.{Failure, Try, Success => TrySuccess} trait GetRequestHandler { this: RequestHandler => // the Genomics batch endpoint doesn't seem to be able to handle get requests on V2 operations at the moment @@ -81,33 +76,44 @@ trait GetRequestHandler { this: RequestHandler => .toList .flatten val 
executionEvents = getEventList(metadata, events, actions) + val workerAssignedEvent: Option[WorkerAssignedEvent] = + events.collectFirst { + case event if event.getWorkerAssigned != null => event.getWorkerAssigned + } + val virtualMachineOption = for { + pipelineValue <- pipeline + resources <- Option(pipelineValue.getResources) + virtualMachine <- Option(resources.getVirtualMachine) + } yield virtualMachine + + // Correlate `executionEvents` to `actions` to potentially assign a grouping into the appropriate events. + val machineType = virtualMachineOption.flatMap(virtualMachine => Option(virtualMachine.getMachineType)) + /* + preemptible is only used if the job fails, as a heuristic to guess if the VM was preempted. + If we can't get the value of preempted we still need to return something, returning false will not make the + failure count as a preemption which seems better than saying that it was preemptible when we really don't know + */ + val preemptibleOption = for { + pipelineValue <- pipeline + resources <- Option(pipelineValue.getResources) + virtualMachine <- Option(resources.getVirtualMachine) + preemptible <- Option(virtualMachine.getPreemptible) + } yield preemptible + val preemptible = preemptibleOption.exists(_.booleanValue) + val instanceName = + workerAssignedEvent.flatMap(workerAssignedEvent => Option(workerAssignedEvent.getInstance())) + val zone = workerAssignedEvent.flatMap(workerAssignedEvent => Option(workerAssignedEvent.getZone)) + val region = zone.map { zoneString => + val lastDashIndex = zoneString.lastIndexOf("-") + if (lastDashIndex != -1) zoneString.substring(0, lastDashIndex) else zoneString + } + + val instantiatedVmInfo: Option[InstantiatedVmInfo] = (region, machineType) match { + case (Some(instantiatedRegion), Some(instantiatedMachineType)) => + Option(InstantiatedVmInfo(instantiatedRegion, instantiatedMachineType, preemptible)) + case _ => Option.empty + } if (operation.getDone) { - val workerAssignedEvent: 
Option[WorkerAssignedEvent] = - events.collectFirst { - case event if event.getWorkerAssigned != null => event.getWorkerAssigned - } - val virtualMachineOption = for { - pipelineValue <- pipeline - resources <- Option(pipelineValue.getResources) - virtualMachine <- Option(resources.getVirtualMachine) - } yield virtualMachine - // Correlate `executionEvents` to `actions` to potentially assign a grouping into the appropriate events. - val machineType = virtualMachineOption.flatMap(virtualMachine => Option(virtualMachine.getMachineType)) - /* - preemptible is only used if the job fails, as a heuristic to guess if the VM was preempted. - If we can't get the value of preempted we still need to return something, returning false will not make the - failure count as a preemption which seems better than saying that it was preemptible when we really don't know - */ - val preemptibleOption = for { - pipelineValue <- pipeline - resources <- Option(pipelineValue.getResources) - virtualMachine <- Option(resources.getVirtualMachine) - preemptible <- Option(virtualMachine.getPreemptible) - } yield preemptible - val preemptible = preemptibleOption.exists(_.booleanValue) - val instanceName = - workerAssignedEvent.flatMap(workerAssignedEvent => Option(workerAssignedEvent.getInstance())) - val zone = workerAssignedEvent.flatMap(workerAssignedEvent => Option(workerAssignedEvent.getZone)) // If there's an error, generate an unsuccessful status. Otherwise, we were successful! 
Option(operation.getError) match { case Some(error) => @@ -122,14 +128,14 @@ trait GetRequestHandler { this: RequestHandler => pollingRequest.workflowId ) errorReporter.toUnsuccessfulRunStatus(error, events) - case None => Success(executionEvents, machineType, zone, instanceName) + case None => Success(executionEvents, machineType, zone, instanceName, instantiatedVmInfo) } } else if (isQuotaDelayed(events)) { - AwaitingCloudQuota(executionEvents) + AwaitingCloudQuota(executionEvents, instantiatedVmInfo) } else if (operation.hasStarted) { - Running(executionEvents) + Running(executionEvents, instantiatedVmInfo) } else { - Initializing(executionEvents) + Initializing(executionEvents, instantiatedVmInfo) } } catch { case nullPointerException: NullPointerException =>