Skip to content

Commit

Permalink
batch: add on-demand retry for preemption
Browse files Browse the repository at this point in the history
The Batch backend currently doesn’t retry preempted jobs with on-demand
VMs, which is problematic since our goal is to save costs and avoid
failing large tasks due to a few preempted machines. This patch
reintroduces, in part, functionality removed in commit 49d675d,
enabling the job to restart on a STANDARD VM after encountering a
VMPreemption error.
  • Loading branch information
juimonen authored and mcovarr committed Jan 30, 2025
1 parent 2c134ec commit 9c0f6d9
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import com.google.cloud.batch.v1.BatchServiceSettings
import com.google.common.collect.ImmutableMap
import com.typesafe.scalalogging.StrictLogging
import cromwell.backend._
import cromwell.backend.google.batch.GcpBatchBackendLifecycleActorFactory.{
preemptionCountKey
}
import cromwell.backend.google.batch.actors._
import cromwell.backend.google.batch.api.request.{BatchRequestExecutor, RequestHandler}
import cromwell.backend.google.batch.authentication.GcpBatchDockerCredentials
Expand All @@ -30,6 +33,7 @@ class GcpBatchBackendLifecycleActorFactory(override val name: String,
) extends StandardLifecycleActorFactory
with GcpPlatform {

override val requestedKeyValueStoreKeys: Seq[String] = Seq(preemptionCountKey)
import GcpBatchBackendLifecycleActorFactory._

override def jobIdKey: String = "__gcp_batch"
Expand Down Expand Up @@ -133,6 +137,7 @@ class GcpBatchBackendLifecycleActorFactory(override val name: String,
}

object GcpBatchBackendLifecycleActorFactory extends StrictLogging {
val preemptionCountKey = "PreemptionCount"

private[batch] def robustBuildAttributes(buildAttributes: () => GcpBatchConfigurationAttributes,
maxAttempts: Int = 3,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ import cromwell.backend.async.{
AbortedExecutionHandle,
ExecutionHandle,
FailedNonRetryableExecutionHandle,
FailedRetryableExecutionHandle,
PendingExecutionHandle
}
import cromwell.backend.google.batch.GcpBatchBackendLifecycleActorFactory
import cromwell.backend.google.batch.api.GcpBatchRequestFactory._
import cromwell.backend.google.batch.io._
import cromwell.backend.google.batch.models.GcpBatchConfigurationAttributes.GcsTransferConfiguration
import cromwell.backend.google.batch.models.GcpBatchJobPaths.GcsTransferLibraryName
import cromwell.backend.google.batch.models.RunStatus.TerminalRunStatus
import cromwell.backend.google.batch.models.{GcpBatchExitCode, RunStatus}
import cromwell.backend.google.batch.models._
import cromwell.backend.google.batch.monitoring.{BatchInstrumentation, CheckpointingConfiguration, MonitoringImage}
import cromwell.backend.google.batch.runnable.WorkflowOptionKeys
Expand All @@ -46,7 +49,7 @@ import cromwell.filesystems.gcs.GcsPath
import cromwell.filesystems.http.HttpPath
import cromwell.filesystems.sra.SraPath
import cromwell.services.instrumentation.CromwellInstrumentation
import cromwell.services.keyvalue.KeyValueServiceActor.KvJobKey
import cromwell.services.keyvalue.KeyValueServiceActor.{KvJobKey, KvPair, ScopedKey}
import cromwell.services.metadata.CallMetadataKeys
import mouse.all._
import shapeless.Coproduct
Expand Down Expand Up @@ -175,6 +178,15 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar

override def dockerImageUsed: Option[String] = Option(jobDockerImage)

override lazy val preemptible: Int = jobDescriptor.prefetchedKvStoreEntries.get(GcpBatchBackendLifecycleActorFactory.preemptionCountKey) match {
case Some(KvPair(_, v)) =>
Try(v.toInt) match {
case Success(m) => m
case Failure(_) => 0
}
case _ => runtimeAttributes.preemptible
}

override def tryAbort(job: StandardAsyncJob): Unit =
abortJob(workflowId = workflowId,
jobName = JobName.parse(job.jobId),
Expand Down Expand Up @@ -644,6 +656,7 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
projectId = googleProject(jobDescriptor.workflowDescriptor),
computeServiceAccount = computeServiceAccount(jobDescriptor.workflowDescriptor),
googleLabels = backendLabels ++ customLabels,
preemptible = preemptible,
batchTimeout = batchConfiguration.batchTimeout,
jobShell = batchConfiguration.jobShell,
privateDockerKeyAndEncryptedToken = dockerKeyAndToken,
Expand Down Expand Up @@ -851,7 +864,7 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
override def executeAsync(): Future[ExecutionHandle] = {

// Want to force runtimeAttributes to evaluate so we can fail quickly now if we need to:
def evaluateRuntimeAttributes = Future.fromTry(Try(runtimeAttributes))
def evaluateRuntimeAttributes = Future.fromTry(Try(runtimeAttributes.copy(preemptible = preemptible)))

def generateInputOutputParameters: Future[InputOutputParameters] = Future.fromTry(Try {
val rcFileOutput = GcpBatchFileOutput(
Expand Down Expand Up @@ -911,7 +924,7 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
})

val runBatchResponse = for {
_ <- evaluateRuntimeAttributes
runtimeAttributes <- evaluateRuntimeAttributes
_ <- uploadScriptFile()
customLabels <- Future.fromTry(GcpLabel.fromWorkflowOptions(workflowDescriptor.workflowOptions))
batchParameters <- generateInputOutputParameters
Expand Down Expand Up @@ -1070,14 +1083,23 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
// returnCode is provided by cromwell, so far, this is empty for all the tests I ran
override def handleExecutionFailure(runStatus: RunStatus, returnCode: Option[Int]): Future[ExecutionHandle] = {
def handleFailedRunStatus(runStatus: RunStatus.UnsuccessfulRunStatus): ExecutionHandle =
FailedNonRetryableExecutionHandle(
StandardException(
if (runStatus.exitCode == Some(GcpBatchExitCode.VMPreemption)) {
FailedRetryableExecutionHandle(
StandardException(
message = runStatus.prettyPrintedError,
jobTag = jobTag),
returnCode,
Option(Seq(KvPair(ScopedKey(workflowId, futureKvJobKey, GcpBatchBackendLifecycleActorFactory.preemptionCountKey), "0")))
)
} else {
FailedNonRetryableExecutionHandle(
StandardException(
message = runStatus.prettyPrintedError,
jobTag = jobTag
),
returnCode,
None
)
jobTag = jobTag),
returnCode,
None
)
}

Future.fromTry {
Try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ trait GcpBatchJobCachingActorHelper extends StandardCachingActorHelper {
batchConfiguration.runtimeConfig
)

val preemptible: Int

lazy val workingDisk: GcpBatchAttachedDisk = runtimeAttributes.disks.find(_.name == GcpBatchWorkingDisk.Name).get

lazy val callRootPath: Path = gcpBatchCallPaths.callExecutionRoot
Expand Down Expand Up @@ -71,9 +73,10 @@ trait GcpBatchJobCachingActorHelper extends StandardCachingActorHelper {
.get(WorkflowOptionKeys.GoogleProject)
.getOrElse(batchAttributes.project)

Map[String, String](
Map[String, Any](
GcpBatchMetadataKeys.GoogleProject -> googleProject,
GcpBatchMetadataKeys.ExecutionBucket -> initializationData.workflowPaths.executionRootString
GcpBatchMetadataKeys.ExecutionBucket -> initializationData.workflowPaths.executionRootString,
"preemptible" -> preemptible
) ++ originalLabelEvents
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ object GcpBatchRequestFactory {
projectId: String,
computeServiceAccount: String,
googleLabels: Seq[GcpLabel],
preemptible: Int,
batchTimeout: FiniteDuration,
jobShell: String,
privateDockerKeyAndEncryptedToken: Option[CreateBatchDockerKeyAndToken],
Expand Down

0 comments on commit 9c0f6d9

Please sign in to comment.