Skip to content

Commit c044228

Browse files
committed
the gist
1 parent 7de9ca1 commit c044228

File tree

11 files changed

+294
-78
lines changed

11 files changed

+294
-78
lines changed

backend/src/main/scala/cromwell/backend/standard/pollmonitoring/PollResultMonitorActor.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import cromwell.backend.validation.{
99
ValidatedRuntimeAttributes
1010
}
1111
import cromwell.core.logging.JobLogger
12+
import cromwell.services.cost.{GcpCostLookupRequest, GcpCostLookupResponse, InstantiatedVmInfo}
1213
import cromwell.services.metadata.CallMetadataKeys
1314
import cromwell.services.metrics.bard.BardEventing.BardEventRequest
1415
import cromwell.services.metrics.bard.model.TaskSummaryEvent
@@ -42,6 +43,9 @@ trait PollResultMonitorActor[PollResultType] extends Actor {
4243
// Time that the user VM started spending money.
4344
def extractStartTimeFromRunState(pollStatus: PollResultType): Option[OffsetDateTime]
4445

46+
// Used to kick off a cost calculation
47+
def extractVmInfoFromRunState(pollStatus: PollResultType): Option[InstantiatedVmInfo]
48+
4549
// Time that the user VM stopped spending money.
4650
def extractEndTimeFromRunState(pollStatus: PollResultType): Option[OffsetDateTime]
4751

@@ -99,6 +103,7 @@ trait PollResultMonitorActor[PollResultType] extends Actor {
99103
Option.empty
100104
private var vmStartTime: Option[OffsetDateTime] = Option.empty
101105
private var vmEndTime: Option[OffsetDateTime] = Option.empty
106+
private var vmCostPerHour: Option[BigDecimal] = Option.empty
102107

103108
def processPollResult(pollStatus: PollResultType): Unit = {
104109
// Make sure jobStartTime remains the earliest event time ever seen
@@ -122,6 +127,16 @@ trait PollResultMonitorActor[PollResultType] extends Actor {
122127
tellMetadata(Map(CallMetadataKeys.VmEndTime -> end))
123128
}
124129
}
130+
// If we don't yet have a cost per hour and we can extract VM info, send a cost request to the catalog service.
131+
// We expect it to reply with an answer, which is handled in receive.
132+
// NB: Due to the nature of async code, we may send a few cost requests before we get a response back.
133+
if (vmCostPerHour.isEmpty) {
134+
val instantiatedVmInfo = extractVmInfoFromRunState(pollStatus)
135+
instantiatedVmInfo.foreach { vmInfo =>
136+
val request = GcpCostLookupRequest(vmInfo, self)
137+
params.serviceRegistry ! request
138+
}
139+
}
125140
}
126141

127142
// When a job finishes, the bard actor needs to know about the timing in order to record metrics.
@@ -135,4 +150,11 @@ trait PollResultMonitorActor[PollResultType] extends Actor {
135150
vmEndTime = vmEndTime.getOrElse(OffsetDateTime.now())
136151
)
137152
)
153+
154+
/**
  * Records the VM's hourly cost from a cost-lookup reply and publishes it to metadata.
  *
  * Only the first reply is processed: once `vmCostPerHour` is set, later replies are ignored.
  * If the reply carries no calculated cost, `BigDecimal(-1)` is stored instead.
  *
  * @param costLookupResponse reply to a previously sent GcpCostLookupRequest
  */
def handleCostResponse(costLookupResponse: GcpCostLookupResponse): Unit =
  // Optimization to avoid processing responses after we've stopped caring.
  if (vmCostPerHour.isEmpty) {
    val cost = costLookupResponse.calculatedCost.getOrElse(BigDecimal(-1)) // TODO: better logging here.
    vmCostPerHour = Option(cost)
    tellMetadata(Map(CallMetadataKeys.VmCostPerHour -> vmCostPerHour))
  }
138160
}

core/src/main/resources/reference.conf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ services {
607607
}
608608
}
609609

610-
CostCatalogService {
610+
GcpCostCatalogService {
611611
class = "cromwell.services.cost.GcpCostCatalogService"
612612
config {
613613
catalogExpirySeconds = 86400

services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,30 @@
11
package cromwell.services.cost
22

33
import akka.actor.{Actor, ActorRef}
4+
import com.google.`type`.Money
45
import com.google.cloud.billing.v1._
56
import com.typesafe.config.Config
67
import com.typesafe.scalalogging.LazyLogging
78
import common.util.StringUtil.EnhancedToStringable
9+
import cromwell.services.ServiceRegistryActor.ServiceRegistryMessage
810
import cromwell.services.cost.GcpCostCatalogService.{COMPUTE_ENGINE_SERVICE_NAME, DEFAULT_CURRENCY_CODE}
911
import cromwell.util.GracefulShutdownHelper.ShutdownCommand
1012

1113
import java.time.{Duration, Instant}
1214
import scala.jdk.CollectionConverters.IterableHasAsScala
1315
import java.time.temporal.ChronoUnit.SECONDS
16+
import scala.util.{Failure, Success, Try}
1417

1518
case class CostCatalogKey(machineType: Option[MachineType],
1619
usageType: Option[UsageType],
1720
machineCustomization: Option[MachineCustomization],
18-
resourceGroup: Option[ResourceGroup]
21+
resourceGroup: Option[ResourceGroup],
22+
region: String
1923
)
24+
/**
  * Message asking the GCP cost catalog service for the hourly cost of a VM.
  *
  * @param vmInfo  description of the VM to price
  * @param replyTo actor that should receive the resulting GcpCostLookupResponse
  */
case class GcpCostLookupRequest(vmInfo: InstantiatedVmInfo, replyTo: ActorRef) extends ServiceRegistryMessage {
  // Must equal the service's key in the `services` config stanza. Deriving it with
  // `GcpCostCatalogService.getClass.getSimpleName` would resolve against the companion object,
  // whose runtime class name has a trailing '$' ("GcpCostCatalogService$") and so never matches.
  override def serviceName: String = "GcpCostCatalogService"
}
27+
case class GcpCostLookupResponse(calculatedCost: Option[BigDecimal])
2028
case class CostCatalogValue(catalogObject: Sku)
2129
case class ExpiringGcpCostCatalog(catalog: Map[CostCatalogKey, CostCatalogValue], fetchTime: Instant)
2230

@@ -88,21 +96,78 @@ class GcpCostCatalogService(serviceConfig: Config, globalConfig: Config, service
8896
* Ideally, we don't want to have an entire, unprocessed, cost catalog in memory at once since it's ~20MB.
8997
*/
9098
private def processCostCatalog(skus: Iterable[Sku]): Map[CostCatalogKey, CostCatalogValue] =
91-
// TODO: Account for key collisions (same key can be in multiple regions)
9299
// TODO: reduce memory footprint of returned map (don't store entire SKU object)
93100
skus.foldLeft(Map.empty[CostCatalogKey, CostCatalogValue]) { case (acc, sku) =>
94-
acc + convertSkuToKeyValuePair(sku)
101+
acc ++ convertSkuToKeyValuePairs(sku)
102+
}
103+
104+
/**
  * Expands one SKU into a catalog entry per service region it is offered in.
  *
  * @param sku the SKU to index
  * @return one (key, value) pair per region listed on the SKU; empty if the SKU lists no regions
  */
private def convertSkuToKeyValuePairs(sku: Sku): List[(CostCatalogKey, CostCatalogValue)] = {
  // Everything except the region depends only on the SKU, so classify it once
  // instead of re-parsing the SKU for every region.
  val machineType = MachineType.fromSku(sku)
  val usageType = UsageType.fromSku(sku)
  val machineCustomization = MachineCustomization.fromSku(sku)
  val resourceGroup = ResourceGroup.fromSku(sku)
  val catalogValue = CostCatalogValue(sku)

  sku.getServiceRegionsList.asScala.toList.map { region =>
    CostCatalogKey(
      machineType = machineType,
      usageType = usageType,
      machineCustomization = machineCustomization,
      resourceGroup = resourceGroup,
      region = region
    ) -> catalogValue
  }
}
116+
117+
// See: https://cloud.google.com/billing/v1/how-tos/catalog-api
118+
/**
  * Computes the hourly price of `coreCount` CPU cores from a CPU SKU.
  *
  * @param cpuSku    SKU whose last pricing-info entry supplies the per-core price
  * @param coreCount number of cores to price
  * @return the hourly price, or a Failure when the SKU is not priced per hour
  *         (UnsupportedOperationException) or its pricing data cannot be read
  */
private def calculateCpuPricePerHour(cpuSku: Sku, coreCount: Int): Try[BigDecimal] =
  // Reading pricing info indexes into SKU lists that may be empty; keep any exception inside the Try.
  Try(getMostRecentPricingInfo(cpuSku).getPricingExpression).flatMap { pricingExpression =>
    val usageUnit = pricingExpression.getUsageUnit
    if (usageUnit != "h") {
      Failure(new UnsupportedOperationException(s"Expected usage units of CPUs to be 'h'. Got ${usageUnit}"))
    } else {
      Try {
        // Price per hour of a single core
        // NB: Ignoring "TieredRates" here (the idea that stuff gets cheaper the more you use).
        // Technically, we should write code that determines which tier(s) to use.
        // In practice, from what I've seen, CPU cores and RAM don't have more than a single tier.
        val costPerUnit: Money = pricingExpression.getTieredRates(0).getUnitPrice
        // Money carries whole units plus nanos (10^-9 of a unit). Scale the nanos by exactly 9
        // decimal places using BigDecimal so no floating-point error is introduced.
        val costPerCorePerHour: BigDecimal =
          BigDecimal(costPerUnit.getUnits) + BigDecimal(BigInt(costPerUnit.getNanos), 9)
        costPerCorePerHour * coreCount
      }
    }
  }
133+
134+
/**
  * Placeholder RAM pricing: charges a flat 0.25 per MB and does not consult `ramSku` yet.
  *
  * @param ramSku     RAM SKU (currently unused)
  * @param ramMbCount amount of RAM in MB
  * @return always a Success holding `ramMbCount * 0.25`
  */
private def calculateRamPricePerHour(ramSku: Sku, ramMbCount: Int): Try[BigDecimal] = {
  // TODO
  val placeholderPrice: BigDecimal = BigDecimal.decimal(ramMbCount.toLong * 0.25)
  Success(placeholderPrice)
}
96137

97-
private def convertSkuToKeyValuePair(sku: Sku): (CostCatalogKey, CostCatalogValue) = CostCatalogKey(
98-
machineType = MachineType.fromSku(sku),
99-
usageType = UsageType.fromSku(sku),
100-
machineCustomization = MachineCustomization.fromSku(sku),
101-
resourceGroup = ResourceGroup.fromSku(sku)
102-
) -> CostCatalogValue(sku)
138+
/**
  * Returns the last entry of the SKU's pricing-info list, which this service treats as the most recent one.
  * Throws if the SKU has no pricing info (the index would be -1).
  */
private def getMostRecentPricingInfo(sku: Sku): PricingInfo =
  sku.getPricingInfo(sku.getPricingInfoCount - 1)
142+
143+
/**
  * Estimates the hourly cost of a VM as the sum of its CPU and RAM prices.
  *
  * Every failure mode (unparsable machine type, SKU missing from the catalog, unsupported pricing)
  * is returned as a Failure rather than thrown, so callers can safely use `.toOption`.
  *
  * @param instantiatedVmInfo region, machine type string and preemptibility of the VM
  * @return the combined hourly cost, or a Failure describing the first step that could not be completed
  */
private def calculateVmCostPerHour(instantiatedVmInfo: InstantiatedVmInfo): Try[BigDecimal] = {
  val machineTypeString = instantiatedVmInfo.machineType
  val machineType = MachineType.fromGoogleMachineTypeString(machineTypeString)
  val usageType = UsageType.fromBoolean(instantiatedVmInfo.preemptible)
  val machineCustomization = MachineCustomization.fromMachineTypeString(machineTypeString)
  val region = instantiatedVmInfo.region

  // Looks up the SKU for one resource group of this VM, turning a catalog miss into a Failure.
  def lookupSku(resourceGroup: ResourceGroup): Try[Sku] = {
    val key =
      CostCatalogKey(machineType, Option(usageType), Option(machineCustomization), Option(resourceGroup), region)
    getSku(key) match {
      case Some(catalogValue) => Success(catalogValue.catalogObject)
      case None => Failure(new NoSuchElementException(s"No SKU found in the cost catalog for $key"))
    }
  }

  for {
    coreCount <- MachineType.extractCoreCountFromMachineTypeString(machineTypeString)
    ramMbCount <- MachineType.extractRamMbFromMachineTypeString(machineTypeString)
    cpuSku <- lookupSku(Cpu) // TODO: Investigate the situation in which the resource group is n1
    cpuCost <- calculateCpuPricePerHour(cpuSku, coreCount)
    ramSku <- lookupSku(Ram)
    ramCost <- calculateRamPricePerHour(ramSku, ramMbCount)
  } yield cpuCost + ramCost
}
103164

104165
def serviceRegistryActor: ActorRef = serviceRegistry
105166
override def receive: Receive = {
167+
case GcpCostLookupRequest(vmInfo, replyTo) =>
168+
val calculatedCost = calculateVmCostPerHour(vmInfo).toOption
169+
val response = GcpCostLookupResponse(calculatedCost)
170+
replyTo ! response
106171
case ShutdownCommand =>
107172
googleClient.foreach(client => client.shutdownNow())
108173
context stop self

services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@ package cromwell.services.cost
22

33
import com.google.cloud.billing.v1.Sku
44

5+
import java.util.regex.{Matcher, Pattern}
6+
import scala.util.{Failure, Success, Try}
7+
8+
/*
9+
* Case class that contains information retrieved from Google about a VM that cromwell has started
10+
*/
11+
case class InstantiatedVmInfo(region: String, machineType: String, preemptible: Boolean)
512
/*
613
* These types reflect hardcoded strings found in a google cost catalog.
714
*/
@@ -13,6 +20,32 @@ object MachineType {
1320
else if (tokenizedDescription.contains(N2d.machineTypeName)) Some(N2d)
1421
else Option.empty
1522
}
23+
24+
// expects a string that looks something like "n1-standard-1" or "custom-1-4096"
25+
/**
  * Maps a Google machine type string to its machine family by prefix.
  * "n2d" is tested before "n2" because the former also starts with the latter.
  *
  * @return the matching family; None for "custom..." strings and for unrecognized ones
  *         (the latter also prints an error line)
  */
def fromGoogleMachineTypeString(machineTypeString: String): Option[MachineType] =
  machineTypeString match {
    case s if s.startsWith("n1") => Some(N1)
    case s if s.startsWith("n2d") => Some(N2d)
    case s if s.startsWith("n2") => Some(N2)
    case s if s.startsWith("custom") =>
      None // TODO: should this be n1? Make a 'custom' type? Combine with MachineCustomization?
    case _ =>
      println(s"Error: Unrecognized machine type: $machineTypeString")
      None
  }
35+
36+
/**
  * Extracts the core count from a machine type string such as "n1-standard-1" or "custom-1-4096".
  * The core count is taken to be the first run of digits that directly follows a dash.
  *
  * @param machineTypeString Google machine type string
  * @return the parsed core count; a Failure(IllegalArgumentException) when no "-<digits>" segment exists,
  *         or a Failure(NumberFormatException) when the digits do not fit in an Int
  */
def extractCoreCountFromMachineTypeString(machineTypeString: String): Try[Int] = {
  val matcher: Matcher = Pattern.compile("-(\\d+)").matcher(machineTypeString)
  if (matcher.find()) {
    // toInt throws on digit runs beyond Int range; capture that in the Try instead of letting it escape.
    Try(matcher.group(1).toInt)
  } else {
    Failure(new IllegalArgumentException(s"Could not extract core count from ${machineTypeString}"))
  }
}
45+
46+
/**
  * Placeholder: reports 4096 MB of RAM for every machine type; the argument is not inspected yet.
  */
def extractRamMbFromMachineTypeString(machineTypeString: String): Try[Int] = {
  // TODO
  val placeholderRamMb = 4096
  Success(placeholderRamMb)
}
1649
}
1750
sealed trait MachineType { def machineTypeName: String }
1851
case object N1 extends MachineType { override val machineTypeName = "N1" }
@@ -26,12 +59,23 @@ object UsageType {
2659
case Preemptible.typeName => Some(Preemptible)
2760
case _ => Option.empty
2861
}
62+
/** Converts a preemptibility flag into a usage type: `true` is Preemptible, `false` is OnDemand. */
def fromBoolean(isPreemptible: Boolean): UsageType =
  if (isPreemptible) Preemptible else OnDemand
66+
2967
}
3068
sealed trait UsageType { def typeName: String }
3169
case object OnDemand extends UsageType { override val typeName = "OnDemand" }
3270
case object Preemptible extends UsageType { override val typeName = "Preemptible" }
3371

3472
object MachineCustomization {
73+
// TODO: I think this is right but I am not 100% sure. Needs testing.
74+
// Do custom machine types always have the word "custom"?
75+
// Does Cromwell ever assign Predefined Machines?
76+
// Is it possible to have a custom machine that *doesn't* contain the word custom?
77+
/**
  * Classifies a machine type string as Custom when it contains "custom" (case-insensitive),
  * and as Predefined otherwise.
  */
def fromMachineTypeString(machineTypeString: String): MachineCustomization = {
  val mentionsCustom = machineTypeString.toLowerCase.contains("custom")
  if (mentionsCustom) Custom else Predefined
}
3579
def fromSku(sku: Sku): Option[MachineCustomization] = {
3680
val tokenizedDescription = sku.getDescription.split(" ")
3781
if (tokenizedDescription.contains(Predefined.customizationName)) Some(Predefined)
@@ -55,4 +99,10 @@ object ResourceGroup {
5599
sealed trait ResourceGroup { def groupName: String }
56100
case object Cpu extends ResourceGroup { override val groupName = "CPU" }
57101
case object Ram extends ResourceGroup { override val groupName = "RAM" }
102+
103+
// TODO: What is the deal with this? It seems out of place.
104+
// Need to figure out how to reconcile with the Cpu resource group.
105+
// Current theory is that the n1 machines are legacy machines,
106+
// and are therefore categorized differently.
107+
// Unfortunately, N1 is Cromwell's default machine.
58108
case object N1Standard extends ResourceGroup { override val groupName = "N1Standard" }

services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ class GcpCostCatalogServiceSpec
7575
machineType = Some(N2),
7676
usageType = Some(Preemptible),
7777
machineCustomization = Some(Predefined),
78-
resourceGroup = Some(Cpu)
78+
resourceGroup = Some(Cpu),
79+
region = "europe-west9"
7980
)
8081

8182
val freshActor = constructTestActor
@@ -103,9 +104,10 @@ class GcpCostCatalogServiceSpec
103104
machineType = Some(N2d),
104105
usageType = Some(Preemptible),
105106
machineCustomization = None,
106-
resourceGroup = Some(Ram)
107+
resourceGroup = Some(Ram),
108+
region = "europe-west9"
107109
)
108110
val foundValue = testActorRef.getSku(expectedKey)
109-
foundValue.get.catalogObject.getDescription shouldBe "Spot Preemptible N2D AMD Instance Ram running in Johannesburg"
111+
foundValue.get.catalogObject.getDescription shouldBe "Spot Preemptible N2D AMD Instance Ram running in Paris"
110112
}
111113
}

supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,10 @@ package cromwell.backend.google.batch.actors
33
import akka.actor.{ActorRef, Props}
44
import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform}
55
import cromwell.backend.google.batch.models.RunStatus
6-
import cromwell.backend.standard.pollmonitoring.{
7-
AsyncJobHasFinished,
8-
PollMonitorParameters,
9-
PollResultMessage,
10-
PollResultMonitorActor,
11-
ProcessThisPollResult
12-
}
6+
import cromwell.backend.standard.pollmonitoring.{AsyncJobHasFinished, PollMonitorParameters, PollResultMessage, PollResultMonitorActor, ProcessThisPollResult}
137
import cromwell.backend.validation.ValidatedRuntimeAttributes
148
import cromwell.core.logging.JobLogger
9+
import cromwell.services.cost.InstantiatedVmInfo
1510
import cromwell.services.metadata.CallMetadataKeys
1611

1712
import java.time.OffsetDateTime
@@ -78,4 +73,6 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters)
7873
}
7974

8075
override def params: PollMonitorParameters = pollMonitorParameters
76+
77+
override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] = Option.empty //TODO
8178
}

supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PapiPollResultMonitorActor.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
11
package cromwell.backend.google.pipelines.common
22

33
import akka.actor.{ActorRef, Props}
4-
import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform}
54
import cromwell.backend.google.pipelines.common.api.RunStatus
6-
import cromwell.backend.standard.pollmonitoring.{
7-
AsyncJobHasFinished,
8-
PollMonitorParameters,
9-
PollResultMessage,
10-
PollResultMonitorActor,
11-
ProcessThisPollResult
12-
}
5+
import cromwell.backend.standard.pollmonitoring._
136
import cromwell.backend.validation.ValidatedRuntimeAttributes
7+
import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform}
148
import cromwell.core.logging.JobLogger
9+
import cromwell.services.cost.{GcpCostLookupResponse, InstantiatedVmInfo}
1510
import cromwell.services.metadata.CallMetadataKeys
1611

1712
import java.time.OffsetDateTime
@@ -50,7 +45,13 @@ class PapiPollResultMonitorActor(parameters: PollMonitorParameters) extends Poll
5045
case event if event.name == CallMetadataKeys.VmEndTime => event.offsetDateTime
5146
}
5247

48+
override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] =
49+
pollStatus.instantiatedVmInfo
50+
51+
override def params: PollMonitorParameters = parameters
52+
5353
override def receive: Receive = {
54+
case costResponse: GcpCostLookupResponse => handleCostResponse(costResponse)
5455
case message: PollResultMessage =>
5556
message match {
5657
case ProcessThisPollResult(pollResult: RunStatus) => processPollResult(pollResult)
@@ -75,5 +76,4 @@ class PapiPollResultMonitorActor(parameters: PollMonitorParameters) extends Poll
7576
)
7677
)
7778
}
78-
override def params: PollMonitorParameters = parameters
7979
}

0 commit comments

Comments
 (0)