From 7c835f58f94c42cb361baff2d62e9a06579393ad Mon Sep 17 00:00:00 2001 From: Khor Shu Heng <32997938+khorshuheng@users.noreply.github.com> Date: Wed, 31 Jan 2024 14:16:38 +0800 Subject: [PATCH] feat: Add support for multiple sinks (#513) # Description Allow the observation publisher to publish to multiple sinks. Supported sinks: Arize, BigQuery. Illustrative usage sketches for the new configuration and publishing flow are appended after the diff. # Modifications - Modified the configuration format to support multiple sinks - Updated the Python requirements - Fixed mypy linter errors - Added the BigQuery sink implementation # Tests # Checklist - [x] Added PR label - [x] Added unit test, integration, and/or e2e tests - [x] Tested locally - [ ] Updated documentation - [ ] Update Swagger spec if the PR introduces API changes - [ ] Regenerated Golang and Python client if the PR introduces API changes # Release Notes ```release-note ``` --- .../conf/environment/example-override.yaml | 29 +- .../publisher/__main__.py | 26 +- .../observation-publisher/publisher/config.py | 46 +--- .../observation-publisher/publisher/metric.py | 60 +++++ .../publisher/observability_backend.py | 113 -------- .../publisher/observation_sink.py | 249 ++++++++++++++++++ .../publisher/prediction_log_consumer.py | 96 +++++-- .../publisher/prediction_log_parser.py | 89 +++++-- python/observation-publisher/pyproject.toml | 7 + .../requirements-dev.txt | 7 +- python/observation-publisher/requirements.in | 3 + python/observation-publisher/requirements.txt | 19 +- .../tests/test_config.py | 35 ++- ...ty_backend.py => test_observation_sink.py} | 12 +- python/sdk/merlin/observability/inference.py | 31 ++- 15 files changed, 581 insertions(+), 241 deletions(-) create mode 100644 python/observation-publisher/publisher/metric.py delete mode 100644 python/observation-publisher/publisher/observability_backend.py create mode 100644 python/observation-publisher/publisher/observation_sink.py rename python/observation-publisher/tests/{test_observability_backend.py => test_observation_sink.py} (94%) diff --git a/python/observation-publisher/conf/environment/example-override.yaml b/python/observation-publisher/conf/environment/example-override.yaml index 276352df2..d645b450d 100644 --- a/python/observation-publisher/conf/environment/example-override.yaml +++ b/python/observation-publisher/conf/environment/example-override.yaml @@ -16,20 +16,31 @@ inference_schema: distance: "int64" transaction: "float64" prediction_id_column: "prediction_id" -observability_backend: - # Supported backend types: +observation_sinks: + # Supported sink types: # - ARIZE - type: "ARIZE" - # Required if observability_backend.type is ARIZE - arize_config: - api_key: "SECRET_API_KEY" - space_key: "SECRET_SPACE_KEY" + # - BIGQUERY + - type: "ARIZE" + config: + api_key: "SECRET_API_KEY" + space_key: "SECRET_SPACE_KEY" + - type: "BIGQUERY" + config: + # GCP project for the dataset + project: "test-project" + # GCP dataset in which to store the observation data + dataset: "test-dataset" + # Number of days before the created table will expire + ttl_days: 14 observation_source: # Supported consumer types: # - KAFKA type: "KAFKA" - # Required if consumer.type is KAFKA - kafka_config: + # (Optional) Number of messages to be kept in-memory before being sent to the sinks. Default: 10 + buffer_capacity: 10 + # (Optional) Maximum duration in seconds to keep messages in-memory before being sent to the sinks, if the capacity is not met.
Default: 60 + buffer_max_duration_seconds: 60 + config: topic: "test-topic" bootstrap_servers: "localhost:9092" group_id: "test-group" diff --git a/python/observation-publisher/publisher/__main__.py b/python/observation-publisher/publisher/__main__.py index 6454173bb..b048ce147 100644 --- a/python/observation-publisher/publisher/__main__.py +++ b/python/observation-publisher/publisher/__main__.py @@ -1,9 +1,11 @@ import hydra from merlin.observability.inference import InferenceSchema from omegaconf import OmegaConf +from prometheus_client import start_http_server from publisher.config import PublisherConfig -from publisher.observability_backend import new_observation_sink +from publisher.metric import MetricWriter +from publisher.observation_sink import new_observation_sink from publisher.prediction_log_consumer import new_consumer @@ -12,18 +14,26 @@ def start_consumer(cfg: PublisherConfig) -> None: missing_keys: set[str] = OmegaConf.missing_keys(cfg) if missing_keys: raise RuntimeError(f"Got missing keys in config:\n{missing_keys}") + + start_http_server(cfg.environment.prometheus_port) + MetricWriter().setup( + model_id=cfg.environment.model_id, model_version=cfg.environment.model_version + ) prediction_log_consumer = new_consumer(cfg.environment.observation_source) inference_schema = InferenceSchema.from_dict( OmegaConf.to_container(cfg.environment.inference_schema) ) - observation_sink = new_observation_sink( - config=cfg.environment.observability_backend, - inference_schema=inference_schema, - model_id=cfg.environment.model_id, - model_version=cfg.environment.model_version, - ) + observation_sinks = [ + new_observation_sink( + sink_config=sink_config, + inference_schema=inference_schema, + model_id=cfg.environment.model_id, + model_version=cfg.environment.model_version, + ) + for sink_config in cfg.environment.observation_sinks + ] prediction_log_consumer.start_polling( - observation_sink=observation_sink, + observation_sinks=observation_sinks, inference_schema=inference_schema, ) diff --git a/python/observation-publisher/publisher/config.py b/python/observation-publisher/publisher/config.py index f55f10da3..914496748 100644 --- a/python/observation-publisher/publisher/config.py +++ b/python/observation-publisher/publisher/config.py @@ -1,56 +1,31 @@ from dataclasses import dataclass from enum import Enum -from typing import Optional +from typing import List from hydra.core.config_store import ConfigStore -@dataclass -class ArizeConfig: - api_key: str - space_key: str - - -class ObservabilityBackendType(Enum): +class ObservationSinkType(Enum): ARIZE = "arize" + BIGQUERY = "bigquery" @dataclass -class ObservabilityBackend: - type: ObservabilityBackendType - arize_config: Optional[ArizeConfig] = None - - def __post_init__(self): - if self.type == ObservabilityBackendType.ARIZE: - assert ( - self.arize_config is not None - ), "Arize config must be set for Arize observability backend" +class ObservationSinkConfig: + type: ObservationSinkType + config: dict class ObservationSource(Enum): KAFKA = "kafka" -@dataclass -class KafkaConsumerConfig: - topic: str - bootstrap_servers: str - group_id: str - batch_size: int = 100 - poll_timeout_seconds: float = 1.0 - additional_consumer_config: Optional[dict] = None - - @dataclass class ObservationSourceConfig: type: ObservationSource - kafka_config: Optional[KafkaConsumerConfig] = None - - def __post_init__(self): - if self.type == ObservationSource.KAFKA: - assert ( - self.kafka_config is not None - ), "Kafka config must be set for Kafka 
observation source" + config: dict + buffer_capacity: int = 10 + buffer_max_duration_seconds: int = 60 @dataclass @@ -58,8 +33,9 @@ class Environment: model_id: str model_version: str inference_schema: dict - observability_backend: ObservabilityBackend + observation_sinks: List[ObservationSinkConfig] observation_source: ObservationSourceConfig + prometheus_port: int = 8000 @dataclass diff --git a/python/observation-publisher/publisher/metric.py b/python/observation-publisher/publisher/metric.py new file mode 100644 index 000000000..26d5314af --- /dev/null +++ b/python/observation-publisher/publisher/metric.py @@ -0,0 +1,60 @@ +from pandas import Timestamp +from prometheus_client import Gauge, Counter + + +class MetricWriter(object): + """ + Singleton class for writing metrics to Prometheus. + """ + + _instance = None + + def __init__(self): + if not self._initialized: + self.model_id = None + self.model_version = "" + self.last_processed_timestamp_gauge = Gauge( + "last_processed_timestamp", + "The timestamp of the last prediction log processed by the publisher", + ["model_id", "model_version"], + ) + self.total_prediction_logs_processed_counter = Counter( + "total_prediction_logs_processed", + "The total number of prediction logs processed by the publisher", + ) + self._initialized = True + + def __new__(cls): + if not cls._instance: + cls._instance = super(MetricWriter, cls).__new__(cls) + cls._instance._initialized = False + return cls._instance + + def setup(self, model_id: str, model_version: str): + """ + Needs to be run before sending metrics, so that the singleton instance has the correct properties value. + :param model_id: + :param model_version: + :return: + """ + self.model_id = model_id + self.model_version = model_version + + def update_last_processed_timestamp(self, last_processed_timestamp: Timestamp): + """ + Updates the last_processed_timestamp gauge with the given value. + :param last_processed_timestamp: + :return: + """ + self.last_processed_timestamp_gauge.labels( + model_id=self.model_id, model_version=self.model_version + ).set(last_processed_timestamp.timestamp()) + + def increment_total_prediction_logs_processed(self, value: int): + """ + Increments the total_prediction_logs_processed counter by value. 
+ :return: + """ + self.total_prediction_logs_processed_counter.labels( + model_id=self.model_id, model_version=self.model_version + ).inc(value) diff --git a/python/observation-publisher/publisher/observability_backend.py b/python/observation-publisher/publisher/observability_backend.py deleted file mode 100644 index 71463fe53..000000000 --- a/python/observation-publisher/publisher/observability_backend.py +++ /dev/null @@ -1,113 +0,0 @@ -import abc -from typing import Tuple - -import pandas as pd -from arize.pandas.logger import Client -from arize.pandas.logger import Schema as ArizeSchema -from arize.pandas.validation.errors import ValidationFailure -from arize.utils.types import Environments -from arize.utils.types import ModelTypes as ArizeModelType -from merlin.observability.inference import ( - InferenceSchema, - RegressionOutput, - BinaryClassificationOutput, - RankingOutput, - ObservationType, -) - -from publisher.config import ObservabilityBackend, ObservabilityBackendType -from publisher.prediction_log_parser import PREDICTION_LOG_TIMESTAMP_COLUMN - - -class ObservationSink(abc.ABC): - @abc.abstractmethod - def write(self, dataframe: pd.DataFrame): - raise NotImplementedError - - -class ArizeSink(ObservationSink): - def __init__( - self, - arize_client: Client, - inference_schema: InferenceSchema, - model_id: str, - model_version: str, - ): - self._client = arize_client - self._model_id = model_id - self._model_version = model_version - self._inference_schema = inference_schema - - def common_arize_schema_attributes(self) -> dict: - return dict( - feature_column_names=self._inference_schema.feature_columns, - prediction_id_column_name=self._inference_schema.prediction_id_column, - timestamp_column_name=PREDICTION_LOG_TIMESTAMP_COLUMN, - tag_column_names=self._inference_schema.tag_columns, - ) - - def to_arize_schema(self) -> Tuple[ArizeModelType, ArizeSchema]: - prediction_output = self._inference_schema.model_prediction_output - if isinstance(prediction_output, BinaryClassificationOutput): - schema_attributes = self.common_arize_schema_attributes() | dict( - prediction_label_column_name=prediction_output.prediction_label_column, - prediction_score_column_name=prediction_output.prediction_score_column, - ) - model_type = ArizeModelType.BINARY_CLASSIFICATION - elif isinstance(prediction_output, RegressionOutput): - schema_attributes = self.common_arize_schema_attributes() | dict( - prediction_score_column_name=prediction_output.prediction_score_column, - ) - model_type = ArizeModelType.REGRESSION - elif isinstance(prediction_output, RankingOutput): - schema_attributes = self.common_arize_schema_attributes() | dict( - rank_column_name=prediction_output.rank_column, - prediction_group_id_column_name=prediction_output.prediction_group_id_column, - ) - model_type = ArizeModelType.RANKING - else: - raise ValueError( - f"Unknown prediction output type: {type(prediction_output)}" - ) - - return model_type, ArizeSchema(**schema_attributes) - - def write(self, df: pd.DataFrame): - processed_df = self._inference_schema.model_prediction_output.preprocess( - df, [ObservationType.FEATURE, ObservationType.PREDICTION] - ) - model_type, arize_schema = self.to_arize_schema() - try: - self._client.log( - dataframe=processed_df, - environment=Environments.PRODUCTION, - schema=arize_schema, - model_id=self._model_id, - model_type=model_type, - model_version=self._model_version, - ) - except ValidationFailure as e: - error_mesage = "\n".join([err.error_message() for err in e.errors]) - print(f"Failed 
to log to Arize: {error_mesage}") - raise e - except Exception as e: - print(f"Failed to log to Arize: {e}") - raise e - - -def new_observation_sink( - config: ObservabilityBackend, - inference_schema: InferenceSchema, - model_id: str, - model_version: str, -) -> ObservationSink: - if config.type == ObservabilityBackendType.ARIZE: - client = Client(space_key=config.arize_config.space_key, api_key=config.arize_config.api_key) - return ArizeSink( - arize_client = client, - inference_schema=inference_schema, - model_id=model_id, - model_version=model_version, - ) - else: - raise ValueError(f"Unknown observability backend type: {config.type}") diff --git a/python/observation-publisher/publisher/observation_sink.py b/python/observation-publisher/publisher/observation_sink.py new file mode 100644 index 000000000..f501c60e0 --- /dev/null +++ b/python/observation-publisher/publisher/observation_sink.py @@ -0,0 +1,249 @@ +import abc +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import List, Tuple + +import pandas as pd +from arize.pandas.logger import Client as ArizeClient +from arize.pandas.logger import Schema as ArizeSchema +from arize.pandas.validation.errors import ValidationFailure +from arize.utils.types import Environments +from arize.utils.types import ModelTypes as ArizeModelType +from dataclasses_json import dataclass_json +from google.cloud.bigquery import Client as BigQueryClient +from google.cloud.bigquery import ( + SchemaField, + Table, + TimePartitioning, + TimePartitioningType, +) +from merlin.observability.inference import ( + BinaryClassificationOutput, + InferenceSchema, + ObservationType, + RankingOutput, + RegressionOutput, + ValueType, +) + +from publisher.config import ObservationSinkConfig, ObservationSinkType +from publisher.prediction_log_parser import PREDICTION_LOG_TIMESTAMP_COLUMN + + +class ObservationSink(abc.ABC): + """ + An abstract class for writing prediction logs to an observability backend. + """ + + def __init__( + self, + inference_schema: InferenceSchema, + model_id: str, + model_version: str, + ): + self._inference_schema = inference_schema + self._model_id = model_id + self._model_version = model_version + + @abc.abstractmethod + def write(self, dataframe: pd.DataFrame): + """ + Write a given pandas dataframe, representing a batch of prediction logs, to the observability backend. + :param dataframe: + :return: + """ + raise NotImplementedError + + +@dataclass_json +@dataclass +class ArizeConfig: + api_key: str + space_key: str + + +class ArizeSink(ObservationSink): + """ + Writes prediction logs to Arize AI.
+ """ + + def __init__( + self, + inference_schema: InferenceSchema, + model_id: str, + model_version: str, + arize_client: ArizeClient, + ): + super().__init__(inference_schema, model_id, model_version) + self._client = arize_client + + def _common_arize_schema_attributes(self) -> dict: + return dict( + feature_column_names=self._inference_schema.feature_columns, + prediction_id_column_name=self._inference_schema.prediction_id_column, + timestamp_column_name=PREDICTION_LOG_TIMESTAMP_COLUMN, + tag_column_names=self._inference_schema.tag_columns, + ) + + def _to_arize_schema(self) -> Tuple[ArizeModelType, ArizeSchema]: + prediction_output = self._inference_schema.model_prediction_output + if isinstance(prediction_output, BinaryClassificationOutput): + schema_attributes = self._common_arize_schema_attributes() | dict( + prediction_label_column_name=prediction_output.prediction_label_column, + prediction_score_column_name=prediction_output.prediction_score_column, + ) + model_type = ArizeModelType.BINARY_CLASSIFICATION + elif isinstance(prediction_output, RegressionOutput): + schema_attributes = self._common_arize_schema_attributes() | dict( + prediction_score_column_name=prediction_output.prediction_score_column, + ) + model_type = ArizeModelType.REGRESSION + elif isinstance(prediction_output, RankingOutput): + schema_attributes = self._common_arize_schema_attributes() | dict( + rank_column_name=prediction_output.rank_column, + prediction_group_id_column_name=prediction_output.prediction_group_id_column, + ) + model_type = ArizeModelType.RANKING + else: + raise ValueError( + f"Unknown prediction output type: {type(prediction_output)}" + ) + + return model_type, ArizeSchema(**schema_attributes) + + def write(self, df: pd.DataFrame): + processed_df = self._inference_schema.model_prediction_output.preprocess( + df, [ObservationType.FEATURE, ObservationType.PREDICTION] + ) + model_type, arize_schema = self._to_arize_schema() + try: + self._client.log( + dataframe=processed_df, + environment=Environments.PRODUCTION, + schema=arize_schema, + model_id=self._model_id, + model_type=model_type, + model_version=self._model_version, + ) + except ValidationFailure as e: + error_message = "\n".join([err.error_message() for err in e.errors]) + print(f"Failed to log to Arize: {error_message}") + raise e + except Exception as e: + print(f"Failed to log to Arize: {e}") + raise e + + +@dataclass_json +@dataclass +class BigQueryConfig: + project: str + dataset: str + ttl_days: int + + +class BigQuerySink(ObservationSink): + """ + Writes prediction logs to BigQuery. If the destination table doesn't exist, it will be created based on the inference schema.. 
+ """ + + def __init__( + self, + inference_schema: InferenceSchema, + model_id: str, + model_version: str, + project: str, + dataset: str, + ttl_days: int, + ): + super().__init__(inference_schema, model_id, model_version) + self._client = BigQueryClient() + self._inference_schema = inference_schema + self._model_id = model_id + self._model_version = model_version + self._project = project + self._dataset = dataset + table = Table(self.write_location, schema=self.schema_fields) + table.time_partitioning = TimePartitioning(type_=TimePartitioningType.DAY) + table.expires = datetime.now() + timedelta(days=ttl_days) + self._table: Table = self._client.create_table(exists_ok=True, table=table) + + @property + def schema_fields(self) -> List[SchemaField]: + value_type_to_bq_type = { + ValueType.INT64: "INTEGER", + ValueType.FLOAT64: "FLOAT", + ValueType.BOOLEAN: "BOOLEAN", + ValueType.STRING: "STRING", + } + + schema_fields = [ + SchemaField( + name=self._inference_schema.prediction_id_column, + field_type="STRING", + ), + SchemaField( + name=PREDICTION_LOG_TIMESTAMP_COLUMN, + field_type="TIMESTAMP", + ), + ] + for feature, feature_type in self._inference_schema.feature_types.items(): + schema_fields.append( + SchemaField( + name=feature, field_type=value_type_to_bq_type[feature_type] + ) + ) + for ( + prediction, + prediction_type, + ) in self._inference_schema.model_prediction_output.prediction_types().items(): + schema_fields.append( + SchemaField( + name=prediction, field_type=value_type_to_bq_type[prediction_type] + ) + ) + + return schema_fields + + @property + def write_location(self) -> str: + table_name = f"prediction_log_{self._model_id}_{self._model_version}".replace( + "-", "_" + ).replace(".", "_") + return f"{self._project}.{self._dataset}.{table_name}" + + def write(self, dataframe: pd.DataFrame): + self._client.insert_rows_from_dataframe(dataframe=dataframe, table=self._table) + + +def new_observation_sink( + sink_config: ObservationSinkConfig, + inference_schema: InferenceSchema, + model_id: str, + model_version: str, +) -> ObservationSink: + match sink_config.type: + case ObservationSinkType.BIGQUERY: + bq_config: BigQueryConfig = BigQueryConfig.from_dict(sink_config.config) # type: ignore[attr-defined] + + return BigQuerySink( + inference_schema=inference_schema, + model_id=model_id, + model_version=model_version, + project=bq_config.project, + dataset=bq_config.dataset, + ttl_days=bq_config.ttl_days, + ) + case ObservationSinkType.ARIZE: + arize_config: ArizeConfig = ArizeConfig.from_dict(sink_config.config) # type: ignore[attr-defined] + client = ArizeClient( + space_key=arize_config.space_key, api_key=arize_config.api_key + ) + return ArizeSink( + inference_schema=inference_schema, + model_id=model_id, + model_version=model_version, + arize_client=client, + ) + case _: + raise ValueError(f"Unknown observability backend type: {sink_config.type}") diff --git a/python/observation-publisher/publisher/prediction_log_consumer.py b/python/observation-publisher/publisher/prediction_log_consumer.py index c2e8f21f4..1d6199dde 100644 --- a/python/observation-publisher/publisher/prediction_log_consumer.py +++ b/python/observation-publisher/publisher/prediction_log_consumer.py @@ -1,26 +1,31 @@ import abc -from typing import List, Tuple +from dataclasses import dataclass +from datetime import datetime +from threading import Thread +from typing import List, Optional, Tuple import numpy as np import pandas as pd from caraml.upi.v1.prediction_log_pb2 import PredictionLog from 
confluent_kafka import Consumer, KafkaException +from dataclasses_json import DataClassJsonMixin, dataclass_json from merlin.observability.inference import InferenceSchema -from publisher.config import ( - KafkaConsumerConfig, - ObservationSource, - ObservationSourceConfig, -) -from publisher.observability_backend import ObservationSink +from publisher.config import ObservationSource, ObservationSourceConfig +from publisher.metric import MetricWriter +from publisher.observation_sink import ObservationSink from publisher.prediction_log_parser import ( PREDICTION_LOG_TIMESTAMP_COLUMN, - parse_struct_to_feature_table, - parse_struct_to_result_table, + PredictionLogFeatureTable, + PredictionLogResultsTable, ) class PredictionLogConsumer(abc.ABC): + def __init__(self, buffer_capacity: int, buffer_max_duration_seconds: int): + self.buffer_capacity = buffer_capacity + self.buffer_max_duration_seconds = buffer_max_duration_seconds + @abc.abstractmethod def poll_new_logs(self) -> List[PredictionLog]: raise NotImplementedError @@ -34,22 +39,71 @@ def close(self): raise NotImplementedError def start_polling( - self, observation_sink: ObservationSink, inference_schema: InferenceSchema + self, + observation_sinks: List[ObservationSink], + inference_schema: InferenceSchema, ): try: + buffered_logs: List[PredictionLog] = [] + buffer_start_time = datetime.now() while True: logs = self.poll_new_logs() if len(logs) == 0: continue - df = log_batch_to_dataframe(logs, inference_schema) - observation_sink.write(df) + buffered_logs.extend(logs) + buffered_duration = (datetime.now() - buffer_start_time).seconds + if ( + len(buffered_logs) < self.buffer_capacity + and buffered_duration < self.buffer_max_duration_seconds + ): + continue + df = log_batch_to_dataframe(buffered_logs, inference_schema) + most_recent_prediction_timestamp = df[ + PREDICTION_LOG_TIMESTAMP_COLUMN + ].max() + MetricWriter().update_last_processed_timestamp( + most_recent_prediction_timestamp + ) + MetricWriter().increment_total_prediction_logs_processed( + len(buffered_logs) + ) + write_tasks = [ + Thread(target=sink.write, args=(df,)) for sink in observation_sinks + ] + for task in write_tasks: + task.start() + for task in write_tasks: + task.join() self.commit() + buffered_logs = [] + buffer_start_time = datetime.now() finally: self.close() +@dataclass_json +@dataclass +class KafkaConsumerConfig: + topic: str + bootstrap_servers: str + group_id: str + batch_size: int = 100 + poll_timeout_seconds: float = 1.0 + additional_consumer_config: Optional[dict] = None + + class KafkaPredictionLogConsumer(PredictionLogConsumer): - def __init__(self, config: KafkaConsumerConfig): + def __init__( + self, + buffer_capacity: int, + buffer_max_duration_seconds: int, + config: KafkaConsumerConfig, + ): + super().__init__( + buffer_capacity=buffer_capacity, + buffer_max_duration_seconds=buffer_max_duration_seconds, + ) consumer_config = { "bootstrap.servers": config.bootstrap_servers, "group.id": config.group_id, @@ -86,7 +140,15 @@ def close(self): def new_consumer(config: ObservationSourceConfig) -> PredictionLogConsumer: if config.type == ObservationSource.KAFKA: - return KafkaPredictionLogConsumer(config.kafka_config) + assert issubclass(KafkaConsumerConfig, DataClassJsonMixin) + kafka_consumer_config: KafkaConsumerConfig = KafkaConsumerConfig.from_dict( + config.config ) # type: ignore[attr-defined] + return KafkaPredictionLogConsumer( + config.buffer_capacity, + config.buffer_max_duration_seconds, + kafka_consumer_config, + ) else:
raise ValueError(f"Unknown consumer type: {config.type}") @@ -101,10 +163,10 @@ def log_to_records( log: PredictionLog, inference_schema: InferenceSchema ) -> Tuple[List[List[np.int64 | np.float64 | np.bool_ | np.str_]], List[str]]: request_timestamp = log.request_timestamp.ToDatetime() - feature_table = parse_struct_to_feature_table( + feature_table = PredictionLogFeatureTable.from_struct( log.input.features_table, inference_schema ) - prediction_results_table = parse_struct_to_result_table( + prediction_results_table = PredictionLogResultsTable.from_struct( log.output.prediction_results_table, inference_schema ) @@ -130,7 +192,7 @@ def log_batch_to_dataframe( logs: List[PredictionLog], inference_schema: InferenceSchema ) -> pd.DataFrame: combined_records = [] - column_names = [] + column_names: List[str] = [] for log in logs: rows, column_names = log_to_records(log, inference_schema) combined_records.extend(rows) diff --git a/python/observation-publisher/publisher/prediction_log_parser.py b/python/observation-publisher/publisher/prediction_log_parser.py index 13024de66..669da15d2 100644 --- a/python/observation-publisher/publisher/prediction_log_parser.py +++ b/python/observation-publisher/publisher/prediction_log_parser.py @@ -1,9 +1,11 @@ from dataclasses import dataclass +from types import NoneType from typing import Dict, List, Optional, Union import numpy as np from google.protobuf.internal.well_known_types import ListValue, Struct from merlin.observability.inference import InferenceSchema, ValueType +from typing_extensions import Self PREDICTION_LOG_TIMESTAMP_COLUMN = "request_timestamp" @@ -13,6 +15,20 @@ class PredictionLogFeatureTable: columns: List[str] rows: List[List[Union[np.int64, np.float64, np.bool_, np.str_]]] + @classmethod + def from_struct( + cls, table_struct: Struct, inference_schema: InferenceSchema + ) -> Self: + assert isinstance(table_struct["columns"], ListValue) + columns = list_value_as_string_list(table_struct["columns"]) + column_types = inference_schema.feature_types + assert isinstance(table_struct["data"], ListValue) + rows = list_value_as_rows(table_struct["data"]) + return cls( + columns=columns, + rows=[list_value_as_numpy_list(row, columns, column_types) for row in rows], + ) + @dataclass class PredictionLogResultsTable: @@ -20,54 +36,69 @@ class PredictionLogResultsTable: rows: List[List[Union[np.int64, np.float64, np.bool_, np.str_]]] row_ids: List[str] + @classmethod + def from_struct( + cls, table_struct: Struct, inference_schema: InferenceSchema + ) -> Self: + assert isinstance(table_struct["columns"], ListValue) + assert isinstance(table_struct["data"], ListValue) + assert isinstance(table_struct["row_ids"], ListValue) + columns = list_value_as_string_list(table_struct["columns"]) + column_types = inference_schema.model_prediction_output.prediction_types() + rows = list_value_as_rows(table_struct["data"]) + row_ids = list_value_as_string_list(table_struct["row_ids"]) + return cls( + columns=columns, + rows=[list_value_as_numpy_list(row, columns, column_types) for row in rows], + row_ids=row_ids, + ) + def convert_to_numpy_value( col_value: Optional[int | str | float | bool], value_type: ValueType ) -> np.int64 | np.float64 | np.bool_ | np.str_: match value_type: case ValueType.INT64: + assert isinstance(col_value, (int, float)) return np.int64(col_value) case ValueType.FLOAT64: + assert isinstance(col_value, (int, float, NoneType)) return np.float64(col_value) case ValueType.BOOLEAN: + assert isinstance(col_value, bool) return 
np.bool_(col_value) case ValueType.STRING: + assert isinstance(col_value, str) return np.str_(col_value) case _: raise ValueError(f"Unknown value type: {value_type}") -def convert_list_value( +def list_value_as_string_list(list_value: ListValue) -> List[str]: + string_list: List[str] = [] + for v in list_value: + assert isinstance(v, str) + string_list.append(v) + return string_list + + +def list_value_as_rows(list_value: ListValue) -> List[ListValue]: + rows: List[ListValue] = [] + for d in list_value: + assert isinstance(d, ListValue) + rows.append(d) + return rows + + +def list_value_as_numpy_list( list_value: ListValue, column_names: List[str], column_types: Dict[str, ValueType] ) -> List[np.int64 | np.float64 | np.bool_ | np.str_]: + column_values: List[int | str | float | bool | None] = [] + for v in list_value: + assert isinstance(v, (int, str, float, bool, NoneType)) + column_values.append(v) + return [ convert_to_numpy_value(col_value, column_types[col_name]) - for col_value, col_name in zip([v for v in list_value], column_names) + for col_value, col_name in zip(column_values, column_names) ] - - -def parse_struct_to_feature_table( - table_struct: Struct, inference_schema: InferenceSchema -) -> PredictionLogFeatureTable: - columns = [c for c in table_struct["columns"]] - column_types = inference_schema.feature_types - return PredictionLogFeatureTable( - columns=columns, - rows=[ - convert_list_value(d, columns, column_types) for d in table_struct["data"] - ], - ) - - -def parse_struct_to_result_table( - table_struct: Struct, inference_schema: InferenceSchema -) -> PredictionLogResultsTable: - columns = [c for c in table_struct["columns"]] - column_types = inference_schema.model_prediction_output.prediction_types() - return PredictionLogResultsTable( - columns=columns, - rows=[ - convert_list_value(d, columns, column_types) for d in table_struct["data"] - ], - row_ids=[r for r in table_struct["row_ids"]], - ) diff --git a/python/observation-publisher/pyproject.toml b/python/observation-publisher/pyproject.toml index 2c70d6c3d..ac89db93c 100644 --- a/python/observation-publisher/pyproject.toml +++ b/python/observation-publisher/pyproject.toml @@ -2,3 +2,10 @@ addopts = [ "--import-mode=importlib", ] + +[tool.mypy] +exclude = "test.*" + +[[tool.mypy.overrides]] +module = ["arize.*", "merlin.*", "confluent_kafka.*", "caraml.upi.*", "pyarrow.*"] +ignore_missing_imports = true diff --git a/python/observation-publisher/requirements-dev.txt b/python/observation-publisher/requirements-dev.txt index 1c1df9505..d8df3e33f 100644 --- a/python/observation-publisher/requirements-dev.txt +++ b/python/observation-publisher/requirements-dev.txt @@ -1,2 +1,7 @@ pip-tools==7.3.0 -pytest==7.4.3 \ No newline at end of file +pytest==7.4.3 +types-requests==2.31.0.20231231 +types-PyYAML==6.0.12.12 +types-jmespath==1.0.2.7 +mypy==1.7.1 +mypy-extensions==1.0.0 \ No newline at end of file diff --git a/python/observation-publisher/requirements.in b/python/observation-publisher/requirements.in index 2f5ec3225..f42e3bd7f 100644 --- a/python/observation-publisher/requirements.in +++ b/python/observation-publisher/requirements.in @@ -3,4 +3,7 @@ caraml-upi-protos>=1.0.0 arize==7.7.* hydra-core>=1.3.0 pandas>=1.0.0 +google-cloud-bigquery +prometheus-client >= 0.19.0 +typing-extensions==4.9.0 -e file:../sdk \ No newline at end of file diff --git a/python/observation-publisher/requirements.txt b/python/observation-publisher/requirements.txt index 2062495ee..ebd0d6930 100644 --- 
a/python/observation-publisher/requirements.txt +++ b/python/observation-publisher/requirements.txt @@ -77,6 +77,7 @@ gitpython==3.1.40 # via mlflow google-api-core==2.15.0 # via + # google-cloud-bigquery # google-cloud-core # google-cloud-storage google-auth==2.25.2 @@ -85,8 +86,12 @@ google-auth==2.25.2 # google-api-core # google-cloud-core # google-cloud-storage +google-cloud-bigquery==3.14.1 + # via -r requirements.in google-cloud-core==2.4.1 - # via google-cloud-storage + # via + # google-cloud-bigquery + # google-cloud-storage google-cloud-storage==2.13.0 # via merlin-sdk google-crc32c==1.5.0 @@ -94,7 +99,9 @@ google-crc32c==1.5.0 # google-cloud-storage # google-resumable-media google-resumable-media==2.6.0 - # via google-cloud-storage + # via + # google-cloud-bigquery + # google-cloud-storage googleapis-common-protos==1.62.0 # via # arize @@ -153,6 +160,7 @@ omegaconf==2.3.0 packaging==21.3 # via # docker + # google-cloud-bigquery # hydra-core # marshmallow # mlflow @@ -162,7 +170,9 @@ pandas==1.5.3 # arize # mlflow prometheus-client==0.19.0 - # via prometheus-flask-exporter + # via + # -r requirements.in + # prometheus-flask-exporter prometheus-flask-exporter==0.23.0 # via mlflow protobuf==4.25.1 @@ -193,6 +203,7 @@ python-dateutil==2.8.2 # via # arrow # botocore + # google-cloud-bigquery # merlin-sdk # pandas python-slugify==8.0.1 @@ -215,6 +226,7 @@ requests==2.31.0 # databricks-cli # docker # google-api-core + # google-cloud-bigquery # google-cloud-storage # mlflow # requests-futures @@ -252,6 +264,7 @@ types-python-dateutil==2.8.19.14 # via arrow typing-extensions==4.9.0 # via + # -r requirements.in # alembic # typing-inspect typing-inspect==0.9.0 diff --git a/python/observation-publisher/tests/test_config.py b/python/observation-publisher/tests/test_config.py index 6483c7cde..3a3ba60f9 100644 --- a/python/observation-publisher/tests/test_config.py +++ b/python/observation-publisher/tests/test_config.py @@ -1,18 +1,13 @@ import dataclasses from hydra import compose, initialize -from merlin.observability.inference import ( - InferenceSchema, - ValueType, -) +from merlin.observability.inference import InferenceSchema, ValueType from omegaconf import OmegaConf from publisher.config import ( - ArizeConfig, Environment, - KafkaConsumerConfig, - ObservabilityBackend, - ObservabilityBackendType, + ObservationSinkConfig, + ObservationSinkType, ObservationSource, ObservationSourceConfig, PublisherConfig, @@ -40,20 +35,22 @@ def test_config_initialization(): score_threshold=0.5, ), ), - observability_backend=ObservabilityBackend( - type=ObservabilityBackendType.ARIZE, - arize_config=ArizeConfig( - api_key="SECRET_API_KEY", - space_key="SECRET_SPACE_KEY", - ), - ), + observation_sinks=[ + ObservationSinkConfig( + type=ObservationSinkType.ARIZE, + config=dict( + api_key="SECRET_API_KEY", + space_key="SECRET_SPACE_KEY", + ), + ) + ], observation_source=ObservationSourceConfig( type=ObservationSource.KAFKA, - kafka_config=KafkaConsumerConfig( + config=dict( topic="test-topic", bootstrap_servers="localhost:9092", group_id="test-group", - poll_timeout_seconds=1.0, + batch_size=100, additional_consumer_config={ "auto.offset.reset": "latest", }, @@ -67,8 +64,8 @@ def test_config_initialization(): assert parsed_schema == InferenceSchema.from_dict( expected_cfg.environment.inference_schema ) - assert cfg.environment.observability_backend == dataclasses.asdict( - expected_cfg.environment.observability_backend + assert cfg.environment.observation_sinks[0] == dataclasses.asdict( + 
expected_cfg.environment.observation_sinks[0] ) assert cfg.environment.observation_source == dataclasses.asdict( expected_cfg.environment.observation_source diff --git a/python/observation-publisher/tests/test_observability_backend.py b/python/observation-publisher/tests/test_observation_sink.py similarity index 94% rename from python/observation-publisher/tests/test_observability_backend.py rename to python/observation-publisher/tests/test_observation_sink.py index 2a5437440..de93b2a72 100644 --- a/python/observation-publisher/tests/test_observability_backend.py +++ b/python/observation-publisher/tests/test_observation_sink.py @@ -7,12 +7,12 @@ from merlin.observability.inference import ( BinaryClassificationOutput, InferenceSchema, - ValueType, RankingOutput, + ValueType, ) from requests import Response -from publisher.observability_backend import ArizeSink +from publisher.observation_sink import ArizeSink class MockResponse(Response): @@ -31,7 +31,9 @@ def _post_file( sync: Optional[bool], timeout: Optional[float] = None, ) -> Response: - return MockResponse(pa.ipc.open_stream(pa.OSFile(path)).read_pandas(), "Success", 200) + return MockResponse( + pa.ipc.open_stream(pa.OSFile(path)).read_pandas(), "Success", 200 + ) def test_binary_classification_model_preprocessing_for_arize(): @@ -49,10 +51,10 @@ ) arize_client = MockArizeClient(api_key="test", space_key="test") arize_sink = ArizeSink( - arize_client, inference_schema, "test-model", "0.1.0", + arize_client, ) request_timestamp = datetime.now() input_df = pd.DataFrame.from_records( @@ -98,9 +100,9 @@ ) arize_client = MockArizeClient(api_key="test", space_key="test") arize_sink = ArizeSink( - arize_client, inference_schema, "test-model", "0.1.0", + arize_client, ) arize_sink.write(input_df) diff --git a/python/sdk/merlin/observability/inference.py b/python/sdk/merlin/observability/inference.py index 7660b7ed5..edc491542 100644 --- a/python/sdk/merlin/observability/inference.py +++ b/python/sdk/merlin/observability/inference.py @@ -98,6 +98,13 @@ def preprocess( def prediction_types(self) -> Dict[str, ValueType]: raise NotImplementedError + @abc.abstractmethod + def ground_truth_types(self) -> Dict[str, ValueType]: + """ + Return a dictionary mapping the name of the ground truth output column to its value type.
+ """ + @abc.abstractmethod + def ground_truth_types(self) -> Dict[str, ValueType]: + raise NotImplementedError + @dataclass_json @dataclass @@ -120,6 +127,10 @@ def preprocess( def prediction_types(self) -> Dict[str, ValueType]: return { self.prediction_score_column: ValueType.FLOAT64, + } + + def ground_truth_types(self) -> Dict[str, ValueType]: + return { self.actual_score_column: ValueType.FLOAT64, } @@ -193,9 +204,13 @@ def preprocess( def prediction_types(self) -> Dict[str, ValueType]: return { self.prediction_score_column: ValueType.FLOAT64, - self.prediction_label_column: ValueType.STRING, + self.prediction_label_column: ValueType.STRING + } + + def ground_truth_types(self) -> Dict[str, ValueType]: + return { self.actual_score_column: ValueType.FLOAT64, - self.actual_label_column: ValueType.STRING, + self.actual_label_column: ValueType.STRING } @@ -230,6 +245,10 @@ def prediction_types(self) -> Dict[str, ValueType]: return { self.rank_column: ValueType.INT64, self.prediction_group_id_column: ValueType.STRING, + } + + def ground_truth_types(self) -> Dict[str, ValueType]: + return { self.relevance_score_column: ValueType.FLOAT64, } @@ -250,3 +269,11 @@ class InferenceSchema: @property def feature_columns(self) -> List[str]: return list(self.feature_types.keys()) + + @property + def prediction_columns(self) -> List[str]: + return list(self.model_prediction_output.prediction_types().keys()) + + @property + def ground_truth_columns(self) -> List[str]: + return list(self.model_prediction_output.ground_truth_types().keys())