Skip to content

Commit

Permalink
Move recursion + add test
Browse files Browse the repository at this point in the history
  • Loading branch information
lucymcnatt committed Jan 23, 2025
1 parent 57a77bf commit 15c220a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,36 +57,30 @@ case class GcsUriDownloader(gcsUrl: String,
downloadAttempt: Int = 0
): IO[DownloadResult] = {

// Necessary function to handle the throwable when trying to recover a failed download
def handleDownloadFailure(t: Throwable): IO[DownloadResult] =
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)

logger.info(s"Attempting download attempt $downloadAttempt of $downloadRetries for a GCS url")

if (downloadAttempt < downloadRetries) {
backoff foreach { b => Thread.sleep(b.backoffMillis) }
logger.warn(s"Attempting download retry $downloadAttempt of $downloadRetries for a GCS url")
downloadWithRetries(downloadRetries,
backoff map {
_.next
},
downloadAttempt + 1
runDownloadCommand.redeemWith(
recover = handleDownloadFailure,
bind = {
case s: DownloadSuccess.type =>
IO.pure(s)
case _: RecognizedRetryableDownloadFailure =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
case _: UnrecognizedRetryableDownloadFailure =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
case _ =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
}
)
} else {
IO.raiseError(new RuntimeException(s"Exhausted $downloadRetries resolution retries to download GCS file"))
}

// Necessary function to handle the throwable when trying to recover a failed download
def handleDownloadFailure(t: Throwable): IO[DownloadResult] =
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)

runDownloadCommand.redeemWith(
recover = handleDownloadFailure,
bind = {
case s: DownloadSuccess.type =>
IO.pure(s)
case _: RecognizedRetryableDownloadFailure =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
case _: UnrecognizedRetryableDownloadFailure =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
case _ =>
downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1)
}
)
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package drs.localizer.downloaders

import common.assertion.CromwellTimeoutSpec
import org.mockito.Mockito.{spy, times, verify}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

Expand Down Expand Up @@ -96,4 +97,25 @@ class GcsUriDownloaderSpec extends AnyFlatSpec with CromwellTimeoutSpec with Mat

downloader.generateDownloadScript(gcsUrl, Option(fakeSAJsonPath)) shouldBe expectedDownloadScript
}

it should "fail to download GCS URL after 5 attempts" in {
val gcsUrl = "gs://foo/bar.bam"
val downloader = spy(new GcsUriDownloader(
gcsUrl = gcsUrl,
downloadLoc = fakeDownloadLocation,
requesterPaysProjectIdOption = Option(fakeRequesterPaysId),
serviceAccountJson = None
))

val result = downloader.downloadWithRetries(5, None).attempt.unsafeRunSync()

result.isLeft shouldBe true
// attempts to download the 1st time and the 5th time, but doesn't attempt a 6th
verify(downloader, times(1)).downloadWithRetries(5, None, 1)
verify(downloader, times(1)).downloadWithRetries(5, None, 5)
verify(downloader, times(0)).downloadWithRetries(5, None, 6)
// attempts the actual download command 5 times
verify(downloader, times(5)).runDownloadCommand

}
}

0 comments on commit 15c220a

Please sign in to comment.