Skip to content

fix: Fix type checking #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion open_feature/evaluation_context/evaluation_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import typing


class EvaluationContext:
def __init__(self, targeting_key: str = None, attributes: dict = None):
def __init__(
self,
targeting_key: typing.Optional[str] = None,
attributes: typing.Optional[dict] = None,
):
self.targeting_key = targeting_key
self.attributes = attributes or {}

Expand Down
18 changes: 9 additions & 9 deletions open_feature/exception/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class OpenFeatureError(Exception):
"""

def __init__(
self, error_message: typing.Optional[str] = None, error_code: ErrorCode = None
self, error_code: ErrorCode, error_message: typing.Optional[str] = None
):
"""
Constructor for the generic OpenFeatureError.
Expand All @@ -35,7 +35,7 @@ def __init__(self, error_message: typing.Optional[str] = None):
@param error_message: an optional string message representing
why the error has been raised
"""
super().__init__(error_message, ErrorCode.FLAG_NOT_FOUND)
super().__init__(ErrorCode.FLAG_NOT_FOUND, error_message)


class GeneralError(OpenFeatureError):
Expand All @@ -51,7 +51,7 @@ def __init__(self, error_message: typing.Optional[str] = None):
@param error_message: an optional string message representing why the error
has been raised
"""
super().__init__(error_message, ErrorCode.GENERAL)
super().__init__(ErrorCode.GENERAL, error_message)


class ParseError(OpenFeatureError):
Expand All @@ -67,7 +67,7 @@ def __init__(self, error_message: typing.Optional[str] = None):
@param error_message: an optional string message representing why the
error has been raised
"""
super().__init__(error_message, ErrorCode.PARSE_ERROR)
super().__init__(ErrorCode.PARSE_ERROR, error_message)


class TypeMismatchError(OpenFeatureError):
Expand All @@ -83,7 +83,7 @@ def __init__(self, error_message: typing.Optional[str] = None):
@param error_message: an optional string message representing why the
error has been raised
"""
super().__init__(error_message, ErrorCode.TYPE_MISMATCH)
super().__init__(ErrorCode.TYPE_MISMATCH, error_message)


class TargetingKeyMissingError(OpenFeatureError):
Expand All @@ -92,14 +92,14 @@ class TargetingKeyMissingError(OpenFeatureError):
but one was not provided in the evaluation context.
"""

def __init__(self, error_message: str = None):
def __init__(self, error_message: typing.Optional[str] = None):
"""
Constructor for the TargetingKeyMissingError. The error code for this type of
exception is ErrorCode.TARGETING_KEY_MISSING.
@param error_message: a string message representing why the error has been
raised
"""
super().__init__(error_message, ErrorCode.TARGETING_KEY_MISSING)
super().__init__(ErrorCode.TARGETING_KEY_MISSING, error_message)


class InvalidContextError(OpenFeatureError):
Expand All @@ -108,11 +108,11 @@ class InvalidContextError(OpenFeatureError):
requirements.
"""

def __init__(self, error_message: str = None):
def __init__(self, error_message: typing.Optional[str]):
"""
Constructor for the InvalidContextError. The error code for this type of
exception is ErrorCode.INVALID_CONTEXT.
@param error_message: a string message representing why the error has been
raised
"""
super().__init__(error_message, ErrorCode.INVALID_CONTEXT)
super().__init__(ErrorCode.INVALID_CONTEXT, error_message)
12 changes: 7 additions & 5 deletions open_feature/flag_evaluation/flag_evaluation_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from open_feature.exception.error_code import ErrorCode
from open_feature.flag_evaluation.reason import Reason

T = typing.TypeVar("T", covariant=True)


@dataclass
class FlagEvaluationDetails:
class FlagEvaluationDetails(typing.Generic[T]):
flag_key: str
value: typing.Any
variant: str = None
reason: Reason = None
error_code: ErrorCode = None
value: T
variant: typing.Optional[str] = None
reason: typing.Optional[Reason] = None
error_code: typing.Optional[ErrorCode] = None
error_message: typing.Optional[str] = None
4 changes: 2 additions & 2 deletions open_feature/hooks/hook_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ class HookContext:
flag_type: FlagType
default_value: typing.Any
evaluation_context: EvaluationContext
client_metadata: dict = None
provider_metadata: dict = None
client_metadata: typing.Optional[dict] = None
provider_metadata: typing.Optional[dict] = None
8 changes: 4 additions & 4 deletions open_feature/hooks/hook_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def error_hooks(
hook_context: HookContext,
exception: Exception,
hooks: typing.List[Hook],
hints: dict,
hints: typing.Optional[typing.Mapping] = None,
):
kwargs = {"hook_context": hook_context, "exception": exception, "hints": hints}
_execute_hooks(
Expand All @@ -27,7 +27,7 @@ def after_all_hooks(
flag_type: FlagType,
hook_context: HookContext,
hooks: typing.List[Hook],
hints: dict,
hints: typing.Optional[typing.Mapping] = None,
):
kwargs = {"hook_context": hook_context, "hints": hints}
_execute_hooks(
Expand All @@ -40,7 +40,7 @@ def after_hooks(
hook_context: HookContext,
details: FlagEvaluationDetails,
hooks: typing.List[Hook],
hints: dict,
hints: typing.Optional[typing.Mapping] = None,
):
kwargs = {"hook_context": hook_context, "details": details, "hints": hints}
_execute_hooks_unchecked(
Expand All @@ -52,7 +52,7 @@ def before_hooks(
flag_type: FlagType,
hook_context: HookContext,
hooks: typing.List[Hook],
hints: dict,
hints: typing.Optional[typing.Mapping] = None,
) -> EvaluationContext:
kwargs = {"hook_context": hook_context, "hints": hints}
executed_hooks = _execute_hooks_unchecked(
Expand Down
6 changes: 4 additions & 2 deletions open_feature/open_feature_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from open_feature.open_feature_client import OpenFeatureClient
from open_feature.provider.provider import AbstractProvider

_provider = None
_provider: typing.Optional[AbstractProvider] = None


def get_client(name: str = None, version: str = None) -> OpenFeatureClient:
def get_client(
name: typing.Optional[str] = None, version: typing.Optional[str] = None
) -> OpenFeatureClient:
if _provider is None:
raise GeneralError(
error_message="Provider not set. Call set_provider before using get_client"
Expand Down
80 changes: 49 additions & 31 deletions open_feature/open_feature_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,34 @@
from open_feature.provider.no_op_provider import NoOpProvider
from open_feature.provider.provider import AbstractProvider

NUMERIC_TYPES = [FlagType.FLOAT, FlagType.INTEGER]

GetDetailCallable = typing.Union[
typing.Callable[
[str, bool, typing.Optional[EvaluationContext]], FlagEvaluationDetails[bool]
],
typing.Callable[
[str, int, typing.Optional[EvaluationContext]], FlagEvaluationDetails[int]
],
typing.Callable[
[str, float, typing.Optional[EvaluationContext]], FlagEvaluationDetails[float]
],
typing.Callable[
[str, str, typing.Optional[EvaluationContext]], FlagEvaluationDetails[str]
],
typing.Callable[
[str, dict, typing.Optional[EvaluationContext]], FlagEvaluationDetails[dict]
],
]


class OpenFeatureClient:
def __init__(
self,
name: str,
version: str,
context: EvaluationContext = None,
hooks: typing.List[Hook] = None,
provider: AbstractProvider = None,
name: typing.Optional[str],
version: typing.Optional[str],
provider: AbstractProvider,
context: typing.Optional[EvaluationContext] = None,
hooks: typing.Optional[typing.List[Hook]] = None,
):
self.name = name
self.version = version
Expand All @@ -49,8 +66,8 @@ def get_boolean_value(
self,
flag_key: str,
default_value: bool,
evaluation_context: EvaluationContext = None,
flag_evaluation_options: FlagEvaluationOptions = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> bool:
return self.evaluate_flag_details(
FlagType.BOOLEAN,
Expand All @@ -64,8 +81,8 @@ def get_boolean_details(
self,
flag_key: str,
default_value: bool,
evaluation_context: EvaluationContext = None,
flag_evaluation_options: FlagEvaluationOptions = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
return self.evaluate_flag_details(
FlagType.BOOLEAN,
Expand All @@ -79,8 +96,8 @@ def get_string_value(
self,
flag_key: str,
default_value: str,
evaluation_context: EvaluationContext = None,
flag_evaluation_options: FlagEvaluationOptions = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> str:
return self.evaluate_flag_details(
FlagType.STRING,
Expand All @@ -94,8 +111,8 @@ def get_string_details(
self,
flag_key: str,
default_value: str,
evaluation_context: EvaluationContext = None,
flag_evaluation_options: FlagEvaluationOptions = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
return self.evaluate_flag_details(
FlagType.STRING,
Expand All @@ -109,8 +126,8 @@ def get_integer_value(
self,
flag_key: str,
default_value: int,
evaluation_context: EvaluationContext = None,
flag_evaluation_options: FlagEvaluationOptions = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> int:
return self.get_integer_details(
flag_key,
Expand All @@ -123,8 +140,8 @@ def get_integer_details(
self,
flag_key: str,
default_value: int,
evaluation_context: EvaluationContext = None,
flag_evaluation_options: FlagEvaluationOptions = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
return self.evaluate_flag_details(
FlagType.INTEGER,
Expand All @@ -138,8 +155,8 @@ def get_float_value(
self,
flag_key: str,
default_value: float,
evaluation_context: EvaluationContext = None,
flag_evaluation_options: FlagEvaluationOptions = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> float:
return self.get_float_details(
flag_key,
Expand All @@ -152,8 +169,8 @@ def get_float_details(
self,
flag_key: str,
default_value: float,
evaluation_context: EvaluationContext = None,
flag_evaluation_options: FlagEvaluationOptions = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
return self.evaluate_flag_details(
FlagType.FLOAT,
Expand All @@ -167,8 +184,8 @@ def get_object_value(
self,
flag_key: str,
default_value: dict,
evaluation_context: EvaluationContext = None,
flag_evaluation_options: FlagEvaluationOptions = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> dict:
return self.evaluate_flag_details(
FlagType.OBJECT,
Expand All @@ -182,8 +199,8 @@ def get_object_details(
self,
flag_key: str,
default_value: dict,
evaluation_context: EvaluationContext = None,
flag_evaluation_options: FlagEvaluationOptions = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
return self.evaluate_flag_details(
FlagType.OBJECT,
Expand All @@ -198,8 +215,8 @@ def evaluate_flag_details(
flag_type: FlagType,
flag_key: str,
default_value: typing.Any,
evaluation_context: EvaluationContext = None,
flag_evaluation_options: FlagEvaluationOptions = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails:
"""
Evaluate the flag requested by the user from the clients provider.
Expand Down Expand Up @@ -302,7 +319,7 @@ def _create_provider_evaluation(
flag_type: FlagType,
flag_key: str,
default_value: typing.Any,
evaluation_context: EvaluationContext = None,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagEvaluationDetails:
"""
Encapsulated method to create a FlagEvaluationDetail from a specific provider.
Expand All @@ -324,14 +341,15 @@ def _create_provider_evaluation(
logging.info("No provider configured, using no-op provider.")
self.provider = NoOpProvider()

get_details_callable = {
get_details_callables: typing.Mapping[FlagType, GetDetailCallable] = {
FlagType.BOOLEAN: self.provider.resolve_boolean_details,
FlagType.INTEGER: self.provider.resolve_integer_details,
FlagType.FLOAT: self.provider.resolve_float_details,
FlagType.OBJECT: self.provider.resolve_object_details,
FlagType.STRING: self.provider.resolve_string_details,
}.get(flag_type)
}

get_details_callable = get_details_callables.get(flag_type)
if not get_details_callable:
raise GeneralError(error_message="Unknown flag type")

Expand Down
Loading