Skip to content

Commit

Permalink
the gist
Browse files Browse the repository at this point in the history
  • Loading branch information
THWiseman committed Sep 18, 2024
1 parent 7de9ca1 commit c044228
Show file tree
Hide file tree
Showing 11 changed files with 294 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import cromwell.backend.validation.{
ValidatedRuntimeAttributes
}
import cromwell.core.logging.JobLogger
import cromwell.services.cost.{GcpCostLookupRequest, GcpCostLookupResponse, InstantiatedVmInfo}
import cromwell.services.metadata.CallMetadataKeys
import cromwell.services.metrics.bard.BardEventing.BardEventRequest
import cromwell.services.metrics.bard.model.TaskSummaryEvent
Expand Down Expand Up @@ -42,6 +43,9 @@ trait PollResultMonitorActor[PollResultType] extends Actor {
// Time that the user VM started spending money.
def extractStartTimeFromRunState(pollStatus: PollResultType): Option[OffsetDateTime]

// Used to kick off a cost calculation: when this yields VM info, processPollResult sends it to the
// service registry as a GcpCostLookupRequest. When it yields an empty Option, no request is sent.
def extractVmInfoFromRunState(pollStatus: PollResultType): Option[InstantiatedVmInfo]

// Time that the user VM stopped spending money.
def extractEndTimeFromRunState(pollStatus: PollResultType): Option[OffsetDateTime]

Expand Down Expand Up @@ -99,6 +103,7 @@ trait PollResultMonitorActor[PollResultType] extends Actor {
Option.empty
private var vmStartTime: Option[OffsetDateTime] = Option.empty
private var vmEndTime: Option[OffsetDateTime] = Option.empty
private var vmCostPerHour: Option[BigDecimal] = Option.empty

def processPollResult(pollStatus: PollResultType): Unit = {
// Make sure jobStartTime remains the earliest event time ever seen
Expand All @@ -122,6 +127,16 @@ trait PollResultMonitorActor[PollResultType] extends Actor {
tellMetadata(Map(CallMetadataKeys.VmEndTime -> end))
}
}
// If we don't yet have a cost per hour and we can extract VM info, send a cost request to the catalog service.
// We expect it to reply with an answer, which is handled in receive.
// NB: Due to the nature of async code, we may send a few cost requests before we get a response back.
if (vmCostPerHour.isEmpty) {
val instantiatedVmInfo = extractVmInfoFromRunState(pollStatus)
instantiatedVmInfo.foreach { vmInfo =>
val request = GcpCostLookupRequest(vmInfo, self)
params.serviceRegistry ! request
}
}
}

// When a job finishes, the bard actor needs to know about the timing in order to record metrics.
Expand All @@ -135,4 +150,11 @@ trait PollResultMonitorActor[PollResultType] extends Actor {
vmEndTime = vmEndTime.getOrElse(OffsetDateTime.now())
)
)

// Stores the hourly VM cost carried by a cost lookup response and publishes it as call metadata.
// Only the first response is processed: once a cost has been stored, later responses are ignored.
// A response that carries no calculated cost is recorded as -1.
def handleCostResponse(costLookupResponse: GcpCostLookupResponse): Unit =
  if (vmCostPerHour.isEmpty) {
    // TODO: better logging here.
    val hourlyCost = costLookupResponse.calculatedCost.getOrElse(BigDecimal(-1))
    vmCostPerHour = Option(hourlyCost)
    tellMetadata(Map(CallMetadataKeys.VmCostPerHour -> vmCostPerHour))
  }
}
2 changes: 1 addition & 1 deletion core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ services {
}
}

CostCatalogService {
GcpCostCatalogService {
class = "cromwell.services.cost.GcpCostCatalogService"
config {
catalogExpirySeconds = 86400
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
package cromwell.services.cost

import akka.actor.{Actor, ActorRef}
import com.google.`type`.Money
import com.google.cloud.billing.v1._
import com.typesafe.config.Config
import com.typesafe.scalalogging.LazyLogging
import common.util.StringUtil.EnhancedToStringable
import cromwell.services.ServiceRegistryActor.ServiceRegistryMessage
import cromwell.services.cost.GcpCostCatalogService.{COMPUTE_ENGINE_SERVICE_NAME, DEFAULT_CURRENCY_CODE}
import cromwell.util.GracefulShutdownHelper.ShutdownCommand

import java.time.{Duration, Instant}
import scala.jdk.CollectionConverters.IterableHasAsScala
import java.time.temporal.ChronoUnit.SECONDS
import scala.util.{Failure, Success, Try}

// Lookup key for one entry of the processed cost catalog.
// The optional fields are whatever the corresponding `fromSku` parsers could extract from a SKU;
// `region` is one of the SKU's service regions, so a SKU offered in several regions yields several keys.
case class CostCatalogKey(machineType: Option[MachineType],
                          usageType: Option[UsageType],
                          machineCustomization: Option[MachineCustomization],
                          resourceGroup: Option[ResourceGroup],
                          region: String
)
// Message asking the cost catalog service for the hourly cost of a VM; the answer
// (a GcpCostLookupResponse) is sent to `replyTo`.
case class GcpCostLookupRequest(vmInfo: InstantiatedVmInfo, replyTo: ActorRef) extends ServiceRegistryMessage {
  // Must be the name the service is configured under (`GcpCostCatalogService` in reference.conf).
  // NB: `GcpCostCatalogService.getClass.getSimpleName` refers to the companion *object*, whose
  // runtime class name carries a trailing '$', so it would not match that name.
  override def serviceName: String = "GcpCostCatalogService"
}
case class GcpCostLookupResponse(calculatedCost: Option[BigDecimal])
case class CostCatalogValue(catalogObject: Sku)
case class ExpiringGcpCostCatalog(catalog: Map[CostCatalogKey, CostCatalogValue], fetchTime: Instant)

Expand Down Expand Up @@ -88,21 +96,78 @@ class GcpCostCatalogService(serviceConfig: Config, globalConfig: Config, service
* Ideally, we don't want to have an entire, unprocessed, cost catalog in memory at once since it's ~20MB.
*/
private def processCostCatalog(skus: Iterable[Sku]): Map[CostCatalogKey, CostCatalogValue] =
  // Each SKU contributes one entry per service region (the region is part of the key).
  // If two SKUs still produce an identical key, the one seen later replaces the earlier one.
  // TODO: reduce memory footprint of returned map (don't store entire SKU object)
  skus.foldLeft(Map.empty[CostCatalogKey, CostCatalogValue]) { case (acc, sku) =>
    acc ++ convertSkuToKeyValuePairs(sku)
  }

// Expands a SKU into one (key, value) pair per region the SKU is offered in.
// All pairs share the same value (the SKU itself); the keys differ only in `region`.
// A SKU with no service regions yields an empty list.
private def convertSkuToKeyValuePairs(sku: Sku): List[(CostCatalogKey, CostCatalogValue)] = {
  // These attributes depend only on the SKU, so parse them once rather than once per region.
  val machineType = MachineType.fromSku(sku)
  val usageType = UsageType.fromSku(sku)
  val machineCustomization = MachineCustomization.fromSku(sku)
  val resourceGroup = ResourceGroup.fromSku(sku)
  val catalogValue = CostCatalogValue(sku)

  sku.getServiceRegionsList.asScala.toList.map { region =>
    CostCatalogKey(
      machineType = machineType,
      usageType = usageType,
      machineCustomization = machineCustomization,
      resourceGroup = resourceGroup,
      region = region
    ) -> catalogValue
  }
}

// See: https://cloud.google.com/billing/v1/how-tos/catalog-api
// Computes the hourly price of `coreCount` CPU cores from the most recent pricing info of `cpuSku`.
// Returns a Failure when the SKU is not priced per hour ('h'), or when the pricing data cannot be
// read (for example the SKU has no pricing info or no tiered rates).
private def calculateCpuPricePerHour(cpuSku: Sku, coreCount: Int): Try[BigDecimal] =
  Try(getMostRecentPricingInfo(cpuSku).getPricingExpression).flatMap { pricingExpression =>
    val usageUnit = pricingExpression.getUsageUnit
    if (usageUnit != "h") {
      Failure(new UnsupportedOperationException(s"Expected usage units of CPUs to be 'h'. Got ${usageUnit}"))
    } else {
      Try {
        // Price per hour of a single core
        // NB: Ignoring "TieredRates" here (the idea that stuff gets cheaper the more you use).
        // Technically, we should write code that determines which tier(s) to use.
        // In practice, from what I've seen, CPU cores and RAM don't have more than a single tier.
        val costPerUnit: Money = pricingExpression.getTieredRates(0).getUnitPrice
        // Money = whole units + nanos, where one nano is 1e-9 of a unit. Build the value exactly
        // as a BigDecimal (scale 9) instead of going through floating point.
        val costPerCorePerHour: BigDecimal =
          BigDecimal(costPerUnit.getUnits) + BigDecimal(costPerUnit.getNanos.toLong, 9)
        costPerCorePerHour * coreCount
      }
    }
  }

// Placeholder RAM pricing: the SKU is not consulted yet and every MB is charged a flat 0.25.
private def calculateRamPricePerHour(ramSku: Sku, ramMbCount: Int): Try[BigDecimal] = {
  // TODO
  val placeholderPrice: BigDecimal = ramMbCount.toLong * 0.25
  Success(placeholderPrice)
}

// Builds a single (key, value) pair for a SKU, using the SKU itself as the value.
// NOTE(review): this builds CostCatalogKey without a `region`, which the key requires, and
// convertSkuToKeyValuePairs covers the same ground per region. This looks like the pre-change
// version left over from the diff; confirm it should be removed.
private def convertSkuToKeyValuePair(sku: Sku): (CostCatalogKey, CostCatalogValue) = CostCatalogKey(
machineType = MachineType.fromSku(sku),
usageType = UsageType.fromSku(sku),
machineCustomization = MachineCustomization.fromSku(sku),
resourceGroup = ResourceGroup.fromSku(sku)
) -> CostCatalogValue(sku)
// Returns the last entry of the SKU's pricing info list, which is treated as the most recent one.
private def getMostRecentPricingInfo(sku: Sku): PricingInfo =
  sku.getPricingInfo(sku.getPricingInfoCount - 1)

// Estimates the hourly cost of a VM as (CPU price for its cores) + (RAM price for its memory).
// Every step that can go wrong (unparseable core/RAM count, no matching SKU in the catalog,
// unusable pricing data) is reported as a Failure rather than thrown, so callers can safely
// turn the result into an Option.
private def calculateVmCostPerHour(instantiatedVmInfo: InstantiatedVmInfo): Try[BigDecimal] = {
  val machineTypeString = instantiatedVmInfo.machineType
  val machineType = MachineType.fromGoogleMachineTypeString(machineTypeString)
  val usageType = UsageType.fromBoolean(instantiatedVmInfo.preemptible)
  val machineCustomization = MachineCustomization.fromMachineTypeString(machineTypeString)
  val region = instantiatedVmInfo.region

  // Looks up the SKU for the given resource group on this VM's machine family / usage type /
  // customization / region, and prices it with `price`. A missing SKU becomes a Failure.
  def priceFor(resourceGroup: ResourceGroup)(price: Sku => Try[BigDecimal]): Try[BigDecimal] = {
    val key =
      CostCatalogKey(machineType, Option(usageType), Option(machineCustomization), Option(resourceGroup), region)
    getSku(key)
      .map(catalogValue => price(catalogValue.catalogObject))
      .getOrElse(Failure(new NoSuchElementException(s"No SKU found in the cost catalog for $key")))
  }

  for {
    coreCount <- MachineType.extractCoreCountFromMachineTypeString(machineTypeString)
    ramMbCount <- MachineType.extractRamMbFromMachineTypeString(machineTypeString)
    // TODO: Investigate the situation in which the resource group is n1
    cpuCost <- priceFor(Cpu)(calculateCpuPricePerHour(_, coreCount))
    ramCost <- priceFor(Ram)(calculateRamPricePerHour(_, ramMbCount))
  } yield cpuCost + ramCost
}

def serviceRegistryActor: ActorRef = serviceRegistry
override def receive: Receive = {
case GcpCostLookupRequest(vmInfo, replyTo) =>
val calculatedCost = calculateVmCostPerHour(vmInfo).toOption
val response = GcpCostLookupResponse(calculatedCost)
replyTo ! response
case ShutdownCommand =>
googleClient.foreach(client => client.shutdownNow())
context stop self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ package cromwell.services.cost

import com.google.cloud.billing.v1.Sku

import java.util.regex.{Matcher, Pattern}
import scala.util.{Failure, Success, Try}

/*
 * Case class that contains information retrieved from Google about a VM that cromwell has started
 *   region:      region the VM runs in; used as the region of the cost catalog lookup key
 *   machineType: Google machine type string, e.g. "n1-standard-1" or "custom-1-4096"
 *   preemptible: whether the VM is preemptible (mapped to the Preemptible / OnDemand usage type)
 */
case class InstantiatedVmInfo(region: String, machineType: String, preemptible: Boolean)
/*
* These types reflect hardcoded strings found in a google cost catalog.
*/
Expand All @@ -13,6 +20,32 @@ object MachineType {
else if (tokenizedDescription.contains(N2d.machineTypeName)) Some(N2d)
else Option.empty
}

// Maps a Google machine type string (something like "n1-standard-1" or "custom-1-4096") to its
// machine family by prefix. "n2d" is tested before "n2" because the latter is a prefix of the former.
// Custom and unrecognized machine types yield None; unrecognized ones are also reported on stdout.
def fromGoogleMachineTypeString(machineTypeString: String): Option[MachineType] =
  machineTypeString match {
    case n1 if n1.startsWith("n1") => Some(N1)
    case n2d if n2d.startsWith("n2d") => Some(N2d)
    case n2 if n2.startsWith("n2") => Some(N2)
    case custom if custom.startsWith("custom") =>
      None // TODO: should this be n1? Make a 'custom' type? Combine with MachineCustomization?
    case _ =>
      println(s"Error: Unrecognized machine type: $machineTypeString")
      None
  }

// Extracts the core count from a machine type string: the first run of digits that directly
// follows a '-' (e.g. "n1-standard-1" -> 1, "custom-4-4096" -> 4).
// Returns a Failure(IllegalArgumentException) when there is no such run, or when the digits do
// not fit in an Int.
def extractCoreCountFromMachineTypeString(machineTypeString: String): Try[Int] = {
  val pattern: Pattern = Pattern.compile("-(\\d+)")
  val matcher: Matcher = pattern.matcher(machineTypeString)
  if (matcher.find()) {
    // `toInt` throws on digit runs too large for an Int; keep that inside the Try contract.
    Try(matcher.group(1).toInt).recoverWith { case e: NumberFormatException =>
      Failure(new IllegalArgumentException(s"Could not extract core count from ${machineTypeString}", e))
    }
  } else {
    Failure(new IllegalArgumentException(s"Could not extract core count from ${machineTypeString}"))
  }
}

// Placeholder: the machine type string is not inspected yet; a fixed 4096 MB is always reported.
def extractRamMbFromMachineTypeString(machineTypeString: String): Try[Int] = {
  // TODO
  val placeholderRamMb = 4096
  Success(placeholderRamMb)
}
}
sealed trait MachineType { def machineTypeName: String }
case object N1 extends MachineType { override val machineTypeName = "N1" }
Expand All @@ -26,12 +59,23 @@ object UsageType {
case Preemptible.typeName => Some(Preemptible)
case _ => Option.empty
}
// Maps a preemptibility flag to its usage type: true -> Preemptible, false -> OnDemand.
def fromBoolean(isPreemptible: Boolean): UsageType =
  if (isPreemptible) Preemptible else OnDemand

}
sealed trait UsageType { def typeName: String }
case object OnDemand extends UsageType { override val typeName = "OnDemand" }
case object Preemptible extends UsageType { override val typeName = "Preemptible" }

object MachineCustomization {
// Classifies a machine type string as Custom when it contains "custom" (case-insensitively),
// and as Predefined otherwise.
// TODO: I think this is right but I am not 100% sure. Needs testing.
// Do custom machine types always have the word "custom"?
// Does Cromwell ever assign Predefined Machines?
// Is it possible to have a custom machine that *doesn't* contain the word custom?
def fromMachineTypeString(machineTypeString: String): MachineCustomization = {
  val mentionsCustom = machineTypeString.toLowerCase.contains("custom")
  if (mentionsCustom) Custom else Predefined
}
def fromSku(sku: Sku): Option[MachineCustomization] = {
val tokenizedDescription = sku.getDescription.split(" ")
if (tokenizedDescription.contains(Predefined.customizationName)) Some(Predefined)
Expand All @@ -55,4 +99,10 @@ object ResourceGroup {
sealed trait ResourceGroup { def groupName: String }
case object Cpu extends ResourceGroup { override val groupName = "CPU" }
case object Ram extends ResourceGroup { override val groupName = "RAM" }

// TODO: What is the deal with this? It seems out of place.
// Need to figure out how to reconcile with the Cpu resource group.
// Current theory is that the n1 machines are legacy machines,
// and are therefore categorized differently.
// Unfortunately, N1 is Cromwell's default machine.
case object N1Standard extends ResourceGroup { override val groupName = "N1Standard" }
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ class GcpCostCatalogServiceSpec
machineType = Some(N2),
usageType = Some(Preemptible),
machineCustomization = Some(Predefined),
resourceGroup = Some(Cpu)
resourceGroup = Some(Cpu),
region = "europe-west9"
)

val freshActor = constructTestActor
Expand Down Expand Up @@ -103,9 +104,10 @@ class GcpCostCatalogServiceSpec
machineType = Some(N2d),
usageType = Some(Preemptible),
machineCustomization = None,
resourceGroup = Some(Ram)
resourceGroup = Some(Ram),
region = "europe-west9"
)
val foundValue = testActorRef.getSku(expectedKey)
foundValue.get.catalogObject.getDescription shouldBe "Spot Preemptible N2D AMD Instance Ram running in Johannesburg"
foundValue.get.catalogObject.getDescription shouldBe "Spot Preemptible N2D AMD Instance Ram running in Paris"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,10 @@ package cromwell.backend.google.batch.actors
import akka.actor.{ActorRef, Props}
import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform}
import cromwell.backend.google.batch.models.RunStatus
import cromwell.backend.standard.pollmonitoring.{
AsyncJobHasFinished,
PollMonitorParameters,
PollResultMessage,
PollResultMonitorActor,
ProcessThisPollResult
}
import cromwell.backend.standard.pollmonitoring.{AsyncJobHasFinished, PollMonitorParameters, PollResultMessage, PollResultMonitorActor, ProcessThisPollResult}
import cromwell.backend.validation.ValidatedRuntimeAttributes
import cromwell.core.logging.JobLogger
import cromwell.services.cost.InstantiatedVmInfo
import cromwell.services.metadata.CallMetadataKeys

import java.time.OffsetDateTime
Expand Down Expand Up @@ -78,4 +73,6 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters)
}

override def params: PollMonitorParameters = pollMonitorParameters

// Not implemented for this backend yet: always returns an empty Option, whatever the poll status.
// NOTE(review): presumably this means no cost lookup is ever requested for these jobs — confirm.
override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] = Option.empty //TODO
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
package cromwell.backend.google.pipelines.common

import akka.actor.{ActorRef, Props}
import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform}
import cromwell.backend.google.pipelines.common.api.RunStatus
import cromwell.backend.standard.pollmonitoring.{
AsyncJobHasFinished,
PollMonitorParameters,
PollResultMessage,
PollResultMonitorActor,
ProcessThisPollResult
}
import cromwell.backend.standard.pollmonitoring._
import cromwell.backend.validation.ValidatedRuntimeAttributes
import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform}
import cromwell.core.logging.JobLogger
import cromwell.services.cost.{GcpCostLookupResponse, InstantiatedVmInfo}
import cromwell.services.metadata.CallMetadataKeys

import java.time.OffsetDateTime
Expand Down Expand Up @@ -50,7 +45,13 @@ class PapiPollResultMonitorActor(parameters: PollMonitorParameters) extends Poll
case event if event.name == CallMetadataKeys.VmEndTime => event.offsetDateTime
}

// The run status already carries the (optional) VM info, so simply pass it through.
override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] =
pollStatus.instantiatedVmInfo

override def params: PollMonitorParameters = parameters

override def receive: Receive = {
case costResponse: GcpCostLookupResponse => handleCostResponse(costResponse)
case message: PollResultMessage =>
message match {
case ProcessThisPollResult(pollResult: RunStatus) => processPollResult(pollResult)
Expand All @@ -75,5 +76,4 @@ class PapiPollResultMonitorActor(parameters: PollMonitorParameters) extends Poll
)
)
}
override def params: PollMonitorParameters = parameters
}
Loading

0 comments on commit c044228

Please sign in to comment.