diff --git a/optimizely/cmab/cmab_client.py b/optimizely/cmab/cmab_client.py new file mode 100644 index 00000000..dfcffa78 --- /dev/null +++ b/optimizely/cmab/cmab_client.py @@ -0,0 +1,193 @@ +# Copyright 2025 Optimizely +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import time +import requests +import math +from typing import Dict, Any, Optional +from optimizely import logger as _logging +from optimizely.helpers.enums import Errors +from optimizely.exceptions import CmabFetchError, CmabInvalidResponseError + +# Default constants for CMAB requests +DEFAULT_MAX_RETRIES = 3 +DEFAULT_INITIAL_BACKOFF = 0.1 # in seconds (100 ms) +DEFAULT_MAX_BACKOFF = 10 # in seconds +DEFAULT_BACKOFF_MULTIPLIER = 2.0 +MAX_WAIT_TIME = 10.0 + + +class CmabRetryConfig: + """Configuration for retrying CMAB requests. + + Contains parameters for maximum retries, backoff intervals, and multipliers. + """ + def __init__( + self, + max_retries: int = DEFAULT_MAX_RETRIES, + initial_backoff: float = DEFAULT_INITIAL_BACKOFF, + max_backoff: float = DEFAULT_MAX_BACKOFF, + backoff_multiplier: float = DEFAULT_BACKOFF_MULTIPLIER, + ): + self.max_retries = max_retries + self.initial_backoff = initial_backoff + self.max_backoff = max_backoff + self.backoff_multiplier = backoff_multiplier + + +class DefaultCmabClient: + """Client for interacting with the CMAB service. + + Provides methods to fetch decisions with optional retry logic. + """ + def __init__(self, http_client: Optional[requests.Session] = None, + retry_config: Optional[CmabRetryConfig] = None, + logger: Optional[_logging.Logger] = None): + """Initialize the CMAB client. + + Args: + http_client (Optional[requests.Session]): HTTP client for making requests. + retry_config (Optional[CmabRetryConfig]): Configuration for retry logic. + logger (Optional[_logging.Logger]): Logger for logging messages. + """ + self.http_client = http_client or requests.Session() + self.retry_config = retry_config + self.logger = _logging.adapt_logger(logger or _logging.NoOpLogger()) + + def fetch_decision( + self, + rule_id: str, + user_id: str, + attributes: Dict[str, Any], + cmab_uuid: str, + timeout: float = MAX_WAIT_TIME + ) -> str: + """Fetch a decision from the CMAB prediction service. + + Args: + rule_id (str): The rule ID for the experiment. + user_id (str): The user ID for the request. + attributes (Dict[str, Any]): User attributes for the request. + cmab_uuid (str): Unique identifier for the CMAB request. + timeout (float): Maximum wait time for request to respond in seconds. Defaults to 10 seconds. + + Returns: + str: The variation ID. + """ + url = f"https://prediction.cmab.optimizely.com/predict/{rule_id}" + cmab_attributes = [ + {"id": key, "value": value, "type": "custom_attribute"} + for key, value in attributes.items() + ] + + request_body = { + "instances": [{ + "visitorId": user_id, + "experimentId": rule_id, + "attributes": cmab_attributes, + "cmabUUID": cmab_uuid, + }] + } + if self.retry_config: + variation_id = self._do_fetch_with_retry(url, request_body, self.retry_config, timeout) + else: + variation_id = self._do_fetch(url, request_body, timeout) + return variation_id + + def _do_fetch(self, url: str, request_body: Dict[str, Any], timeout: float) -> str: + """Perform a single fetch request to the CMAB prediction service. + + Args: + url (str): The endpoint URL. + request_body (Dict[str, Any]): The request payload. + timeout (float): Maximum wait time for request to respond in seconds. + Returns: + str: The variation ID + """ + headers = {'Content-Type': 'application/json'} + try: + response = self.http_client.post(url, data=json.dumps(request_body), headers=headers, timeout=timeout) + except requests.exceptions.RequestException as e: + error_message = Errors.CMAB_FETCH_FAILED.format(str(e)) + self.logger.error(error_message) + raise CmabFetchError(error_message) + + if not 200 <= response.status_code < 300: + error_message = Errors.CMAB_FETCH_FAILED.format(str(response.status_code)) + self.logger.error(error_message) + raise CmabFetchError(error_message) + + try: + body = response.json() + except json.JSONDecodeError: + error_message = Errors.INVALID_CMAB_FETCH_RESPONSE + self.logger.error(error_message) + raise CmabInvalidResponseError(error_message) + + if not self.validate_response(body): + error_message = Errors.INVALID_CMAB_FETCH_RESPONSE + self.logger.error(error_message) + raise CmabInvalidResponseError(error_message) + + return str(body['predictions'][0]['variation_id']) + + def validate_response(self, body: Dict[str, Any]) -> bool: + """Validate the response structure from the CMAB service. + + Args: + body (Dict[str, Any]): The response body to validate. + + Returns: + bool: True if the response is valid, False otherwise. + """ + return ( + isinstance(body, dict) and + 'predictions' in body and + isinstance(body['predictions'], list) and + len(body['predictions']) > 0 and + isinstance(body['predictions'][0], dict) and + "variation_id" in body["predictions"][0] + ) + + def _do_fetch_with_retry( + self, + url: str, + request_body: Dict[str, Any], + retry_config: CmabRetryConfig, + timeout: float + ) -> str: + """Perform a fetch request with retry logic. + + Args: + url (str): The endpoint URL. + request_body (Dict[str, Any]): The request payload. + retry_config (CmabRetryConfig): Configuration for retry logic. + timeout (float): Maximum wait time for request to respond in seconds. + Returns: + str: The variation ID + """ + backoff = retry_config.initial_backoff + for attempt in range(retry_config.max_retries + 1): + try: + variation_id = self._do_fetch(url, request_body, timeout) + return variation_id + except: + if attempt < retry_config.max_retries: + self.logger.info(f"Retrying CMAB request (attempt: {attempt + 1}) after {backoff} seconds...") + time.sleep(backoff) + backoff = min(backoff * math.pow(retry_config.backoff_multiplier, attempt + 1), + retry_config.max_backoff) + + error_message = Errors.CMAB_FETCH_FAILED.format('Exhausted all retries for CMAB request.') + self.logger.error(error_message) + raise CmabFetchError(error_message) diff --git a/optimizely/exceptions.py b/optimizely/exceptions.py index e7644064..b17b1397 100644 --- a/optimizely/exceptions.py +++ b/optimizely/exceptions.py @@ -82,3 +82,21 @@ class OdpInvalidData(Exception): """ Raised when passing invalid ODP data. """ pass + + +class CmabError(Exception): + """Base exception for CMAB client errors.""" + + pass + + +class CmabFetchError(CmabError): + """Exception raised when CMAB fetch fails.""" + + pass + + +class CmabInvalidResponseError(CmabError): + """Exception raised when CMAB response is invalid.""" + + pass diff --git a/optimizely/helpers/enums.py b/optimizely/helpers/enums.py index fe90946e..2d6febab 100644 --- a/optimizely/helpers/enums.py +++ b/optimizely/helpers/enums.py @@ -127,6 +127,8 @@ class Errors: ODP_INVALID_DATA: Final = 'ODP data is not valid.' ODP_INVALID_ACTION: Final = 'ODP action is not valid (cannot be empty).' MISSING_SDK_KEY: Final = 'SDK key not provided/cannot be found in the datafile.' + CMAB_FETCH_FAILED: Final = 'CMAB decision fetch failed with status: {}' + INVALID_CMAB_FETCH_RESPONSE = 'Invalid CMAB fetch response' class ForcedDecisionLogs: diff --git a/tests/test_cmab_client.py b/tests/test_cmab_client.py new file mode 100644 index 00000000..0e15b3f4 --- /dev/null +++ b/tests/test_cmab_client.py @@ -0,0 +1,235 @@ +import unittest +import json +from unittest.mock import MagicMock, patch, call +from optimizely.cmab.cmab_client import DefaultCmabClient, CmabRetryConfig +from requests.exceptions import RequestException +from optimizely.helpers.enums import Errors +from optimizely.exceptions import CmabFetchError, CmabInvalidResponseError + + +class TestDefaultCmabClient(unittest.TestCase): + def setUp(self): + self.mock_http_client = MagicMock() + self.mock_logger = MagicMock() + self.retry_config = CmabRetryConfig(max_retries=3, initial_backoff=0.01, max_backoff=1, backoff_multiplier=2) + self.client = DefaultCmabClient( + http_client=self.mock_http_client, + logger=self.mock_logger, + retry_config=None + ) + self.rule_id = 'test_rule' + self.user_id = 'user123' + self.attributes = {'attr1': 'value1', 'attr2': 'value2'} + self.cmab_uuid = 'uuid-1234' + self.expected_url = f"https://prediction.cmab.optimizely.com/predict/{self.rule_id}" + self.expected_body = { + "instances": [{ + "visitorId": self.user_id, + "experimentId": self.rule_id, + "attributes": [ + {"id": "attr1", "value": "value1", "type": "custom_attribute"}, + {"id": "attr2", "value": "value2", "type": "custom_attribute"} + ], + "cmabUUID": self.cmab_uuid, + }] + } + self.expected_headers = {'Content-Type': 'application/json'} + + def test_fetch_decision_returns_success_no_retry(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'predictions': [{'variation_id': 'abc123'}] + } + self.mock_http_client.post.return_value = mock_response + result = self.client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid) + self.assertEqual(result, 'abc123') + self.mock_http_client.post.assert_called_once_with( + self.expected_url, + data=json.dumps(self.expected_body), + headers=self.expected_headers, + timeout=10.0 + ) + + def test_fetch_decision_returns_http_exception_no_retry(self): + self.mock_http_client.post.side_effect = RequestException('Connection error') + + with self.assertRaises(CmabFetchError) as context: + self.client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid) + + self.mock_http_client.post.assert_called_once() + self.mock_logger.error.assert_called_with(Errors.CMAB_FETCH_FAILED.format('Connection error')) + self.assertIn('Connection error', str(context.exception)) + + def test_fetch_decision_returns_non_2xx_status_no_retry(self): + mock_response = MagicMock() + mock_response.status_code = 500 + self.mock_http_client.post.return_value = mock_response + + with self.assertRaises(CmabFetchError) as context: + self.client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid) + + self.mock_http_client.post.assert_called_once_with( + self.expected_url, + data=json.dumps(self.expected_body), + headers=self.expected_headers, + timeout=10.0 + ) + self.mock_logger.error.assert_called_with(Errors.CMAB_FETCH_FAILED.format(str(mock_response.status_code))) + self.assertIn(str(mock_response.status_code), str(context.exception)) + + def test_fetch_decision_returns_invalid_json_no_retry(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + self.mock_http_client.post.return_value = mock_response + + with self.assertRaises(CmabInvalidResponseError) as context: + self.client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid) + + self.mock_http_client.post.assert_called_once_with( + self.expected_url, + data=json.dumps(self.expected_body), + headers=self.expected_headers, + timeout=10.0 + ) + self.mock_logger.error.assert_called_with(Errors.INVALID_CMAB_FETCH_RESPONSE) + self.assertIn(Errors.INVALID_CMAB_FETCH_RESPONSE, str(context.exception)) + + def test_fetch_decision_returns_invalid_response_structure_no_retry(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'no_predictions': []} + self.mock_http_client.post.return_value = mock_response + + with self.assertRaises(CmabInvalidResponseError) as context: + self.client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid) + + self.mock_http_client.post.assert_called_once_with( + self.expected_url, + data=json.dumps(self.expected_body), + headers=self.expected_headers, + timeout=10.0 + ) + self.mock_logger.error.assert_called_with(Errors.INVALID_CMAB_FETCH_RESPONSE) + self.assertIn(Errors.INVALID_CMAB_FETCH_RESPONSE, str(context.exception)) + + @patch('time.sleep', return_value=None) + def test_fetch_decision_returns_success_with_retry_on_first_try(self, mock_sleep): + # Create client with retry + client_with_retry = DefaultCmabClient( + http_client=self.mock_http_client, + logger=self.mock_logger, + retry_config=self.retry_config + ) + + # Mock successful response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'predictions': [{'variation_id': 'abc123'}] + } + self.mock_http_client.post.return_value = mock_response + + result = client_with_retry.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid) + + # Verify result and request parameters + self.assertEqual(result, 'abc123') + self.mock_http_client.post.assert_called_once_with( + self.expected_url, + data=json.dumps(self.expected_body), + headers=self.expected_headers, + timeout=10.0 + ) + self.assertEqual(self.mock_http_client.post.call_count, 1) + mock_sleep.assert_not_called() + + @patch('time.sleep', return_value=None) + def test_fetch_decision_returns_success_with_retry_on_third_try(self, mock_sleep): + client_with_retry = DefaultCmabClient( + http_client=self.mock_http_client, + logger=self.mock_logger, + retry_config=self.retry_config + ) + + # Create failure and success responses + failure_response = MagicMock() + failure_response.status_code = 500 + + success_response = MagicMock() + success_response.status_code = 200 + success_response.json.return_value = { + 'predictions': [{'variation_id': 'xyz456'}] + } + + # First two calls fail, third succeeds + self.mock_http_client.post.side_effect = [ + failure_response, + failure_response, + success_response + ] + + result = client_with_retry.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid) + + self.assertEqual(result, 'xyz456') + self.assertEqual(self.mock_http_client.post.call_count, 3) + + # Verify all HTTP calls used correct parameters + self.mock_http_client.post.assert_called_with( + self.expected_url, + data=json.dumps(self.expected_body), + headers=self.expected_headers, + timeout=10.0 + ) + + # Verify retry logging + self.mock_logger.info.assert_has_calls([ + call("Retrying CMAB request (attempt: 1) after 0.01 seconds..."), + call("Retrying CMAB request (attempt: 2) after 0.02 seconds...") + ]) + + # Verify sleep was called with correct backoff times + mock_sleep.assert_has_calls([ + call(0.01), + call(0.02) + ]) + + @patch('time.sleep', return_value=None) + def test_fetch_decision_exhausts_all_retry_attempts(self, mock_sleep): + client_with_retry = DefaultCmabClient( + http_client=self.mock_http_client, + logger=self.mock_logger, + retry_config=self.retry_config + ) + + # Create failure response + failure_response = MagicMock() + failure_response.status_code = 500 + + # All attempts fail + self.mock_http_client.post.return_value = failure_response + + with self.assertRaises(CmabFetchError): + client_with_retry.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid) + + # Verify all attempts were made (1 initial + 3 retries) + self.assertEqual(self.mock_http_client.post.call_count, 4) + + # Verify retry logging + self.mock_logger.info.assert_has_calls([ + call("Retrying CMAB request (attempt: 1) after 0.01 seconds..."), + call("Retrying CMAB request (attempt: 2) after 0.02 seconds..."), + call("Retrying CMAB request (attempt: 3) after 0.08 seconds...") + ]) + + # Verify sleep was called for each retry + mock_sleep.assert_has_calls([ + call(0.01), + call(0.02), + call(0.08) + ]) + + # Verify final error + self.mock_logger.error.assert_called_with( + Errors.CMAB_FETCH_FAILED.format('Exhausted all retries for CMAB request.') + )