Enable retry in create prediction job (#78)
* Enable retry (sketched below)

* Create a failure case for the sync prediction job

* Create a retry-success test for the sync prediction job

* Add a case with 4 failed calls followed by 1 successful retry

* Add a case where the job is pending, then the status call fails 5 times

* Mock the retry and sleep-delay constants

* Assert the expected number of responses calls

Co-authored-by: Julio Anthony Leonard <[email protected]>
imjuanleonard and Julio Anthony Leonard authored Feb 22, 2021
1 parent f1a7bdf commit e347cc1
Showing 2 changed files with 200 additions and 6 deletions.
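The retry loop added to create_prediction_job (summarised in the bullets above and shown in the model.py diff below) keeps a budget of DEFAULT_API_CALL_RETRY consecutive failures: each successful status poll resets the budget, each exception decrements it and backs off for DEFAULT_PREDICTION_JOB_RETRY_DELAY seconds, and once the budget hits zero the job is marked failed. A minimal standalone sketch of that pattern, not the SDK code itself (poll_until_done and fetch_job are placeholder names):

    import time

    DEFAULT_API_CALL_RETRY = 5                # consecutive failures tolerated
    DEFAULT_PREDICTION_JOB_DELAY = 5          # seconds between normal polls
    DEFAULT_PREDICTION_JOB_RETRY_DELAY = 30   # back-off after a failed poll

    def poll_until_done(fetch_job, job):
        """Poll `fetch_job` until the job leaves an in-progress state."""
        retry = DEFAULT_API_CALL_RETRY
        while job.status in ("pending", "running", "terminating"):
            try:
                job = fetch_job(job.id)
                retry = DEFAULT_API_CALL_RETRY      # reset the budget on success
            except Exception:
                retry -= 1
                if retry == 0:                      # budget exhausted: give up
                    job.status = "failed"
                    break
                time.sleep(DEFAULT_PREDICTION_JOB_RETRY_DELAY)
            time.sleep(DEFAULT_PREDICTION_JOB_DELAY)
        return job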
20 changes: 16 additions & 4 deletions python/sdk/merlin/model.py
@@ -50,6 +50,9 @@

DEFAULT_MODEL_PATH = "model"
DEFAULT_MODEL_VERSION_LIMIT = 50
DEFAULT_API_CALL_RETRY = 5
DEFAULT_PREDICTION_JOB_DELAY = 5
DEFAULT_PREDICTION_JOB_RETRY_DELAY = 30
V1 = "v1"
PREDICTION_JOB = "PredictionJob"

@@ -1128,6 +1131,7 @@ def create_prediction_job(self, job_config: PredictionJobConfig, sync: bool = True
bar = pyprind.ProgBar(100, track_time=True,
title=f"Running prediction job {j.id} from model {self.model.name} version {self.id} "
f"under project {self.model.project.name}")
retry = DEFAULT_API_CALL_RETRY
while j.status == "pending" or \
j.status == "running" or \
j.status == "terminating":
@@ -1137,11 +1141,19 @@ def create_prediction_job(self, job_config: PredictionJobConfig, sync: bool = True
job_id=j.id)
return PredictionJob(j, self._api_client)
else:
j = job_client.models_model_id_versions_version_id_jobs_job_id_get(model_id=self.model.id,
version_id=self.id,
job_id=j.id)
try:
j = job_client.models_model_id_versions_version_id_jobs_job_id_get(model_id=self.model.id,
version_id=self.id,
job_id=j.id)
retry = DEFAULT_API_CALL_RETRY
except Exception:
retry -= 1
if retry == 0:
j.status = "failed"
break
sleep(DEFAULT_PREDICTION_JOB_RETRY_DELAY)
bar.update()
sleep(5)
sleep(DEFAULT_PREDICTION_JOB_DELAY)
bar.stop()

if j.status == "failed" or j.status == "failed_submission":
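With the default constants above, a persistently failing status endpoint is tolerated for roughly 4 × (30 s + 5 s) ≈ 140 s of sleeping before the job is marked failed (the fifth consecutive failure breaks out of the loop immediately). The tests below therefore patch both delay constants to zero where the loop looks them up (merlin.model), so the retry path runs without sleeping. A minimal sketch of that pattern (the test name and body are illustrative only, not part of the commit):

    from unittest.mock import patch

    # Because a replacement value (0) is supplied, patch() swaps the module
    # attribute for the duration of the test and injects no mock argument.
    @patch("merlin.model.DEFAULT_PREDICTION_JOB_DELAY", 0)
    @patch("merlin.model.DEFAULT_PREDICTION_JOB_RETRY_DELAY", 0)
    def test_retry_loop_does_not_sleep():
        ...  # exercise version.create_prediction_job(...) and assert on the outcome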
186 changes: 184 additions & 2 deletions python/sdk/test/model_test.py
@@ -13,10 +13,11 @@
# limitations under the License.

import json

import types
import pytest
from merlin.model import ModelType
from urllib3_mock import Responses
from unittest.mock import patch

import client as cl
from merlin.batch.config import PredictionJobConfig, ResultType
@@ -174,7 +175,6 @@ def test_list_secret(self, project):
secret_names = project.list_secret()
assert secret_names == [self.secret_1.name, self.secret_2.name]


class TestModelVersion:
@responses.activate
def test_endpoint(self, version):
@@ -381,6 +381,188 @@ def test_create_prediction_job(self, version):
assert actual_req["config"]["job_config"]["model"]["type"] == ModelType.PYFUNC_V2.value.upper()
assert actual_req["config"]["service_account_name"] == "my-service-account"

@patch("merlin.model.DEFAULT_PREDICTION_JOB_DELAY", 0)
@patch("merlin.model.DEFAULT_PREDICTION_JOB_RETRY_DELAY", 0)
@responses.activate
def test_create_prediction_job_with_retry_failed(self, version):
job_1.status = "pending"
responses.add("POST", '/v1/models/1/versions/1/jobs',
body=json.dumps(job_1.to_dict()),
status=200,
content_type='application/json')

for i in range(5):
responses.add("GET", '/v1/models/1/versions/1/jobs/1',
body=json.dumps(job_1.to_dict()),
status=500,
content_type='application/json')

bq_src = BigQuerySource(table="project.dataset.source_table",
features=["feature_1", "feature2"],
options={"key": "val"})

bq_sink = BigQuerySink(table="project.dataset.result_table",
result_column="prediction",
save_mode=SaveMode.OVERWRITE,
staging_bucket="gs://test",
options={"key": "val"})

job_config = PredictionJobConfig(source=bq_src,
sink=bq_sink,
service_account_name="my-service-account",
result_type=ResultType.INTEGER)

with pytest.raises(ValueError):
j = version.create_prediction_job(job_config=job_config)
assert j.id == job_1.id
assert j.error == job_1.error
assert j.name == job_1.name
assert len(responses.calls) == 6

@patch("merlin.model.DEFAULT_PREDICTION_JOB_DELAY", 0)
@patch("merlin.model.DEFAULT_PREDICTION_JOB_RETRY_DELAY", 0)
@responses.activate
def test_create_prediction_job_with_retry_success(self, version):
job_1.status = "pending"
responses.add("POST", '/v1/models/1/versions/1/jobs',
body=json.dumps(job_1.to_dict()),
status=200,
content_type='application/json')

# Patch the method as currently it is not supported in the library
# https://github.com/getsentry/responses/issues/135
def _find_match(self, request):
for match in self._urls:
if request.method == match['method'] and \
self._has_url_match(match, request.url):
return match

def _find_match_patched(self, request):
for index, match in enumerate(self._urls):
if request.method == match['method'] and \
self._has_url_match(match, request.url):
if request.method == "GET" and request.url == "/v1/models/1/versions/1/jobs/1":
return self._urls.pop(index)
else:
return match
responses._find_match = types.MethodType(_find_match_patched, responses)

for i in range(4):
responses.add("GET", '/v1/models/1/versions/1/jobs/1',
body=json.dumps(job_1.to_dict()),
status=500,
content_type='application/json')

job_1.status = "completed"
responses.add("GET", '/v1/models/1/versions/1/jobs/1',
body=json.dumps(job_1.to_dict()),
status=200,
content_type='application/json')

bq_src = BigQuerySource(table="project.dataset.source_table",
features=["feature_1", "feature2"],
options={"key": "val"})

bq_sink = BigQuerySink(table="project.dataset.result_table",
result_column="prediction",
save_mode=SaveMode.OVERWRITE,
staging_bucket="gs://test",
options={"key": "val"})

job_config = PredictionJobConfig(source=bq_src,
sink=bq_sink,
service_account_name="my-service-account",
result_type=ResultType.INTEGER)

j = version.create_prediction_job(job_config=job_config)
assert j.status == JobStatus.COMPLETED
assert j.id == job_1.id
assert j.error == job_1.error
assert j.name == job_1.name

actual_req = json.loads(responses.calls[0].request.body)
assert actual_req["config"]["job_config"]["bigquery_source"] == bq_src.to_dict()
assert actual_req["config"]["job_config"]["bigquery_sink"] == bq_sink.to_dict()
assert actual_req["config"]["job_config"]["model"]["result"]["type"] == ResultType.INTEGER.value
assert actual_req["config"]["job_config"]["model"]["uri"] == f"{version.artifact_uri}/model"
assert actual_req["config"]["job_config"]["model"]["type"] == ModelType.PYFUNC_V2.value.upper()
assert actual_req["config"]["service_account_name"] == "my-service-account"
assert len(responses.calls) == 6

# unpatch
responses._find_match = types.MethodType(_find_match, responses)

@patch("merlin.model.DEFAULT_PREDICTION_JOB_DELAY", 0)
@patch("merlin.model.DEFAULT_PREDICTION_JOB_RETRY_DELAY", 0)
@responses.activate
def test_create_prediction_job_with_retry_pending_then_failed(self, version):
job_1.status = "pending"
responses.add("POST", '/v1/models/1/versions/1/jobs',
body=json.dumps(job_1.to_dict()),
status=200,
content_type='application/json')

# Patch the method as currently it is not supported in the library
# https://github.com/getsentry/responses/issues/135
def _find_match(self, request):
for match in self._urls:
if request.method == match['method'] and \
self._has_url_match(match, request.url):
return match

def _find_match_patched(self, request):
for index, match in enumerate(self._urls):
if request.method == match['method'] and \
self._has_url_match(match, request.url):
if request.method == "GET" and request.url == "/v1/models/1/versions/1/jobs/1":
return self._urls.pop(index)
else:
return match
responses._find_match = types.MethodType(_find_match_patched, responses)

for i in range(3):
responses.add("GET", '/v1/models/1/versions/1/jobs/1',
body=json.dumps(job_1.to_dict()),
status=500,
content_type='application/json')

responses.add("GET", '/v1/models/1/versions/1/jobs/1',
body=json.dumps(job_1.to_dict()),
status=200,
content_type='application/json')

job_1.status = "failed"
for i in range(5):
responses.add("GET", '/v1/models/1/versions/1/jobs/1',
body=json.dumps(job_1.to_dict()),
status=500,
content_type='application/json')

bq_src = BigQuerySource(table="project.dataset.source_table",
features=["feature_1", "feature2"],
options={"key": "val"})

bq_sink = BigQuerySink(table="project.dataset.result_table",
result_column="prediction",
save_mode=SaveMode.OVERWRITE,
staging_bucket="gs://test",
options={"key": "val"})

job_config = PredictionJobConfig(source=bq_src,
sink=bq_sink,
service_account_name="my-service-account",
result_type=ResultType.INTEGER)

with pytest.raises(ValueError):
j = version.create_prediction_job(job_config=job_config)
assert j.id == job_1.id
assert j.error == job_1.error
assert j.name == job_1.name

# unpatch
responses._find_match = types.MethodType(_find_match, responses)
assert len(responses.calls) == 10

@responses.activate
def test_stop_prediction_job(self, version):
job_1.status = "pending"
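Both retry tests work around the fact that urllib3_mock's Responses always serves the first registered match for a URL (the linked responses issue #135) by monkeypatching _find_match so that each matching GET registration is popped once it has been used, making successive polls consume the 500 and 200 registrations in order. A reusable version of that trick, as a sketch only (make_ordered is a hypothetical helper relying on the same urllib3_mock internals, _urls and _has_url_match, that the tests above touch):

    import types

    def make_ordered(responses_mock, method, url):
        """Serve registrations for (method, url) in order, one per request."""
        original = responses_mock._find_match

        def _find_match_ordered(self, request):
            for index, match in enumerate(self._urls):
                if request.method == match['method'] and self._has_url_match(match, request.url):
                    if request.method == method and request.url == url:
                        return self._urls.pop(index)  # consume so the next call sees the next one
                    return match

        responses_mock._find_match = types.MethodType(_find_match_ordered, responses_mock)
        return original  # restore afterwards with: responses_mock._find_match = original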
