diff --git a/eppo_client/__init__.py b/eppo_client/__init__.py index 647d353..2f1ca0e 100644 --- a/eppo_client/__init__.py +++ b/eppo_client/__init__.py @@ -40,6 +40,10 @@ def init(config: Config) -> EppoClient: flag_config_store.set_configurations( config.initial_configuration._flags_configuration.flags ) + if config.initial_configuration._bandits_configuration: + bandit_config_store.set_configurations( + config.initial_configuration._bandits_configuration.bandits + ) config_requestor = ExperimentConfigurationRequestor( http_client=http_client, diff --git a/eppo_client/configuration.py b/eppo_client/configuration.py index 4df901f..0440f0e 100644 --- a/eppo_client/configuration.py +++ b/eppo_client/configuration.py @@ -1,4 +1,5 @@ -from eppo_client.models import UfcResponse +from typing import Union +from eppo_client.models import UfcResponse, BanditResponse class Configuration: @@ -7,5 +8,15 @@ class Configuration: interpret feature flags. """ - def __init__(self, flags_configuration: str): + def __init__( + self, + flags_configuration: Union[bytes, str], + bandits_configuration: Union[bytes, str, None] = None, + ) -> None: self._flags_configuration = UfcResponse.model_validate_json(flags_configuration) + + self._bandits_configuration = None + if bandits_configuration is not None: + self._bandits_configuration = BanditResponse.model_validate_json( + bandits_configuration + ) diff --git a/eppo_client/configuration_requestor.py b/eppo_client/configuration_requestor.py index 997aaf9..0ab091f 100644 --- a/eppo_client/configuration_requestor.py +++ b/eppo_client/configuration_requestor.py @@ -80,3 +80,7 @@ def _set_configuration(self, configuration: Configuration): self.__flag_config_store.set_configurations( configuration._flags_configuration.flags ) + if configuration._bandits_configuration is not None: + self.__bandit_config_store.set_configurations( + configuration._bandits_configuration.bandits + ) diff --git a/eppo_client/configuration_store.py b/eppo_client/configuration_store.py index 3301f60..f7fc9a3 100644 --- a/eppo_client/configuration_store.py +++ b/eppo_client/configuration_store.py @@ -6,7 +6,7 @@ class ConfigurationStore(Generic[T]): - def __init__(self): + def __init__(self) -> None: self.__is_initialized = False self.__cache: Dict[str, T] = {} self.__lock = ReadWriteLock() diff --git a/eppo_client/models.py b/eppo_client/models.py index cd74dbc..45c7cdf 100644 --- a/eppo_client/models.py +++ b/eppo_client/models.py @@ -105,3 +105,7 @@ class BanditData(SdkBaseModel): bandit_model_version: str = Field(alias="modelVersion") bandit_model_data: BanditModelData = Field(alias="modelData") updated_at: datetime + + +class BanditResponse(SdkBaseModel): + bandits: Dict[str, BanditData] diff --git a/eppo_client/version.py b/eppo_client/version.py index 01e2117..1a6e221 100644 --- a/eppo_client/version.py +++ b/eppo_client/version.py @@ -1,4 +1,4 @@ # Note to developers: When ready to bump to 4.0, please change # the `POLL_INTERVAL_SECONDS` constant in `eppo_client/constants.py` # to 30 seconds to match the behavior of the other server SDKs. -__version__ = "3.7.0" +__version__ = "3.8.0" diff --git a/test/configuration_test.py b/test/configuration_test.py index 1de83e4..8d3757c 100644 --- a/test/configuration_test.py +++ b/test/configuration_test.py @@ -5,7 +5,7 @@ def test_init_valid(): - Configuration(flags_configuration='{"flags": {}}') + Configuration(flags_configuration=b'{"flags": {}}') def test_init_invalid_json(): @@ -15,4 +15,4 @@ def test_init_invalid_json(): def test_init_invalid_format(): with pytest.raises(pydantic.ValidationError): - Configuration(flags_configuration='{"flags": []}') + Configuration(flags_configuration=b'{"flags": []}')