diff --git a/components/clp-py-utils/clp_py_utils/clp_config.py b/components/clp-py-utils/clp_py_utils/clp_config.py index ba8a8728cf..4d408dcc75 100644 --- a/components/clp-py-utils/clp_py_utils/clp_config.py +++ b/components/clp-py-utils/clp_py_utils/clp_config.py @@ -1,7 +1,7 @@ import os import pathlib from enum import auto -from typing import Any, Literal, Optional, Set, Union +from typing import Annotated, Any, Literal, Optional, Set, Union from dotenv import dotenv_values from pydantic import ( @@ -14,7 +14,7 @@ ) from strenum import KebabCaseStrEnum, LowercaseStrEnum -from .clp_logging import get_valid_logging_level, is_valid_logging_level +from .clp_logging import LoggingLevel from .core import ( get_config_value, make_config_path_absolute, @@ -98,12 +98,27 @@ CLP_QUEUE_PASS_ENV_VAR_NAME = "CLP_QUEUE_PASS" CLP_REDIS_PASS_ENV_VAR_NAME = "CLP_REDIS_PASS" +# Generic types +NonEmptyStr = Annotated[str, Field(min_length=1)] +PositiveFloat = Annotated[float, Field(gt=0)] +PositiveInt = Annotated[int, Field(gt=0)] +# Specific types +# TODO: Replace this with pydantic_extra_types.domain.DomainStr. +DomainStr = NonEmptyStr +Port = Annotated[int, Field(gt=0, lt=2**16)] +ZstdCompressionLevel = Annotated[int, Field(ge=1, le=19)] + class StorageEngine(KebabCaseStrEnum): CLP = auto() CLP_S = auto() +class DatabaseEngine(KebabCaseStrEnum): + MARIADB = auto() + MYSQL = auto() + + class QueryEngine(KebabCaseStrEnum): CLP = auto() CLP_S = auto() @@ -122,33 +137,9 @@ class AwsAuthType(LowercaseStrEnum): ec2 = auto() -VALID_STORAGE_ENGINES = [storage_engine.value for storage_engine in StorageEngine] -VALID_QUERY_ENGINES = [query_engine.value for query_engine in QueryEngine] - - class Package(BaseModel): - storage_engine: str = "clp" - query_engine: str = "clp" - - @field_validator("storage_engine") - @classmethod - def validate_storage_engine(cls, value): - if value not in VALID_STORAGE_ENGINES: - raise ValueError( - f"package.storage_engine must be one of the following" - f" {'|'.join(VALID_STORAGE_ENGINES)}" - ) - return value - - @field_validator("query_engine") - @classmethod - def validate_query_engine(cls, value): - if value not in VALID_QUERY_ENGINES: - raise ValueError( - f"package.query_engine must be one of the following" - f" {'|'.join(VALID_QUERY_ENGINES)}" - ) - return value + storage_engine: StorageEngine = StorageEngine.CLP + query_engine: QueryEngine = QueryEngine.CLP @model_validator(mode="after") def validate_query_engine_package_compatibility(self): @@ -172,49 +163,25 @@ def validate_query_engine_package_compatibility(self): return self + def dump_to_primitive_dict(self): + d = self.model_dump() + d["storage_engine"] = d["storage_engine"].value + d["query_engine"] = d["query_engine"].value + return d + class Database(BaseModel): - type: str = "mariadb" - host: str = "localhost" - port: int = 3306 - name: str = "clp-db" - ssl_cert: Optional[str] = None + type: DatabaseEngine = DatabaseEngine.MARIADB + host: DomainStr = "localhost" + port: Port = 3306 + name: NonEmptyStr = "clp-db" + ssl_cert: Optional[NonEmptyStr] = None auto_commit: bool = False compress: bool = True username: Optional[str] = None password: Optional[str] = None - @field_validator("type") - @classmethod - def validate_type(cls, value): - supported_database_types = ["mysql", "mariadb"] - if value not in supported_database_types: - raise ValueError( - f"database.type must be one of the following {'|'.join(supported_database_types)}" - ) - return value - - @field_validator("name") - @classmethod - def validate_name(cls, value): - if "" == value: - raise ValueError("database.name cannot be empty.") - return value - - @field_validator("host") - @classmethod - def validate_host(cls, value): - if "" == value: - raise ValueError("database.host cannot be empty.") - return value - - @field_validator("port") - @classmethod - def validate_port(cls, value): - _validate_port(cls, value) - return value - def ensure_credentials_loaded(self): if self.username is None or self.password is None: raise ValueError("Credentials not loaded.") @@ -249,7 +216,7 @@ def get_clp_connection_params_and_type(self, disable_localhost_socket_connection connection_params_and_type = { # NOTE: clp-core does not distinguish between mysql and mariadb - "type": "mysql", + "type": DatabaseEngine.MYSQL.value, "host": host, "port": self.port, "username": self.username, @@ -264,7 +231,9 @@ def get_clp_connection_params_and_type(self, disable_localhost_socket_connection return connection_params_and_type def dump_to_primitive_dict(self): - return self.model_dump(exclude={"username", "password"}) + d = self.model_dump(exclude={"username", "password"}) + d["type"] = d["type"].value + return d def load_credentials_from_file(self, credentials_file_path: pathlib.Path): config = read_yaml_config_file(credentials_file_path) @@ -286,107 +255,35 @@ def load_credentials_from_env(self): self.password = _get_env_var(CLP_DB_PASS_ENV_VAR_NAME) -def _validate_logging_level(cls, value): - if not is_valid_logging_level(value): - raise ValueError( - f"{cls.__name__}: '{value}' is not a valid logging level. Use one of" - f" {get_valid_logging_level()}" - ) - - -def _validate_host(cls, value): - if "" == value: - raise ValueError(f"{cls.__name__}.host cannot be empty.") - - -def _validate_port(cls, value): - min_valid_port = 0 - max_valid_port = 2**16 - 1 - if min_valid_port > value or max_valid_port < value: - raise ValueError( - f"{cls.__name__}.port is not within valid range " f"{min_valid_port}-{max_valid_port}." - ) - - class CompressionScheduler(BaseModel): - jobs_poll_delay: float = 0.1 # seconds - logging_level: str = "INFO" - - @field_validator("logging_level") - @classmethod - def validate_logging_level(cls, value): - _validate_logging_level(cls, value) - return value + jobs_poll_delay: PositiveFloat = 0.1 # seconds + logging_level: LoggingLevel = "INFO" class QueryScheduler(BaseModel): - host: str = "localhost" - port: int = 7000 - jobs_poll_delay: float = 0.1 # seconds - num_archives_to_search_per_sub_job: int = 16 - logging_level: str = "INFO" - - @field_validator("logging_level") - @classmethod - def validate_logging_level(cls, value): - _validate_logging_level(cls, value) - return value - - @field_validator("host") - @classmethod - def validate_host(cls, value): - if "" == value: - raise ValueError(f"Cannot be empty.") - return value - - @field_validator("port") - @classmethod - def validate_port(cls, value): - _validate_port(cls, value) - return value + host: DomainStr = "localhost" + port: Port = 7000 + jobs_poll_delay: PositiveFloat = 0.1 # seconds + num_archives_to_search_per_sub_job: PositiveInt = 16 + logging_level: LoggingLevel = "INFO" class CompressionWorker(BaseModel): - logging_level: str = "INFO" - - @field_validator("logging_level") - @classmethod - def validate_logging_level(cls, value): - _validate_logging_level(cls, value) - return value + logging_level: LoggingLevel = "INFO" class QueryWorker(BaseModel): - logging_level: str = "INFO" - - @field_validator("logging_level") - @classmethod - def validate_logging_level(cls, value): - _validate_logging_level(cls, value) - return value + logging_level: LoggingLevel = "INFO" class Redis(BaseModel): - host: str = "localhost" - port: int = 6379 + host: DomainStr = "localhost" + port: Port = 6379 query_backend_database: int = 0 compression_backend_database: int = 1 # redis can perform authentication without a username password: Optional[str] = None - @field_validator("host") - @classmethod - def validate_host(cls, value): - if "" == value: - raise ValueError(f"{REDIS_COMPONENT_NAME}.host cannot be empty.") - return value - - @field_validator("port") - @classmethod - def validate_port(cls, value): - _validate_port(cls, value) - return value - def dump_to_primitive_dict(self): return self.model_dump(exclude={"password"}) @@ -409,105 +306,30 @@ def load_credentials_from_env(self): class Reducer(BaseModel): - host: str = "localhost" - base_port: int = 14009 - logging_level: str = "INFO" - upsert_interval: int = 100 # milliseconds - - @field_validator("host") - @classmethod - def validate_host(cls, value): - if "" == value: - raise ValueError(f"{value} cannot be empty") - return value - - @field_validator("logging_level") - @classmethod - def validate_logging_level(cls, value): - _validate_logging_level(cls, value) - return value - - @field_validator("base_port") - @classmethod - def validate_base_port(cls, value): - _validate_port(cls, value) - return value - - @field_validator("upsert_interval") - @classmethod - def validate_upsert_interval(cls, value): - if not value > 0: - raise ValueError(f"{value} is not greater than zero") - return value + host: DomainStr = "localhost" + base_port: Port = 14009 + logging_level: LoggingLevel = "INFO" + upsert_interval: PositiveInt = 100 # milliseconds class ResultsCache(BaseModel): - host: str = "localhost" - port: int = 27017 - db_name: str = "clp-query-results" - stream_collection_name: str = "stream-files" - retention_period: Optional[int] = 60 - - @field_validator("host") - @classmethod - def validate_host(cls, value): - if "" == value: - raise ValueError(f"{RESULTS_CACHE_COMPONENT_NAME}.host cannot be empty.") - return value - - @field_validator("port") - @classmethod - def validate_port(cls, value): - _validate_port(cls, value) - return value - - @field_validator("db_name") - @classmethod - def validate_db_name(cls, value): - if "" == value: - raise ValueError(f"{RESULTS_CACHE_COMPONENT_NAME}.db_name cannot be empty.") - return value - - @field_validator("stream_collection_name") - @classmethod - def validate_stream_collection_name(cls, value): - if "" == value: - raise ValueError( - f"{RESULTS_CACHE_COMPONENT_NAME}.stream_collection_name cannot be empty." - ) - return value - - @field_validator("retention_period") - @classmethod - def validate_retention_period(cls, value): - if value is not None and value <= 0: - raise ValueError("retention_period must be greater than 0") - return value + host: DomainStr = "localhost" + port: Port = 27017 + db_name: NonEmptyStr = "clp-query-results" + stream_collection_name: NonEmptyStr = "stream-files" + retention_period: Optional[PositiveInt] = 60 def get_uri(self): return f"mongodb://{self.host}:{self.port}/{self.db_name}" class Queue(BaseModel): - host: str = "localhost" - port: int = 5672 + host: DomainStr = "localhost" + port: Port = 5672 - username: Optional[str] = None + username: Optional[NonEmptyStr] = None password: Optional[str] = None - @field_validator("host") - @classmethod - def validate_host(cls, value): - if "" == value: - raise ValueError(f"{QUEUE_COMPONENT_NAME}.host cannot be empty.") - return value - - @field_validator("port") - @classmethod - def validate_port(cls, value): - _validate_port(cls, value) - return value - def dump_to_primitive_dict(self): return self.model_dump(exclude={"username", "password"}) @@ -532,23 +354,9 @@ def load_credentials_from_env(self): class S3Credentials(BaseModel): - access_key_id: str - secret_access_key: str - session_token: Optional[str] = None - - @field_validator("access_key_id") - @classmethod - def validate_access_key_id(cls, value): - if "" == value: - raise ValueError("access_key_id cannot be empty") - return value - - @field_validator("secret_access_key") - @classmethod - def validate_secret_access_key(cls, value): - if "" == value: - raise ValueError("secret_access_key cannot be empty") - return value + access_key_id: NonEmptyStr + secret_access_key: NonEmptyStr + session_token: Optional[NonEmptyStr] = None class AwsAuthentication(BaseModel): @@ -558,7 +366,7 @@ class AwsAuthentication(BaseModel): AwsAuthType.env_vars.value, AwsAuthType.ec2.value, ] - profile: Optional[str] = None + profile: Optional[NonEmptyStr] = None credentials: Optional[S3Credentials] = None @model_validator(mode="before") @@ -590,25 +398,11 @@ def validate_authentication(cls, data): class S3Config(BaseModel): - region_code: str - bucket: str + region_code: NonEmptyStr + bucket: NonEmptyStr key_prefix: str aws_authentication: AwsAuthentication - @field_validator("region_code") - @classmethod - def validate_region_code(cls, value): - if "" == value: - raise ValueError("region_code cannot be empty") - return value - - @field_validator("bucket") - @classmethod - def validate_bucket(cls, value): - if "" == value: - raise ValueError("bucket cannot be empty") - return value - class S3IngestionConfig(BaseModel): type: Literal[StorageType.S3.value] = StorageType.S3.value @@ -713,54 +507,12 @@ def _set_directory_for_storage_config( class ArchiveOutput(BaseModel): storage: Union[ArchiveFsStorage, ArchiveS3Storage] = ArchiveFsStorage() - target_archive_size: int = 256 * 1024 * 1024 # 256 MB - target_dictionaries_size: int = 32 * 1024 * 1024 # 32 MB - target_encoded_file_size: int = 256 * 1024 * 1024 # 256 MB - target_segment_size: int = 256 * 1024 * 1024 # 256 MB - compression_level: int = 3 - retention_period: Optional[int] = None - - @field_validator("target_archive_size") - @classmethod - def validate_target_archive_size(cls, value): - if value <= 0: - raise ValueError("target_archive_size must be greater than 0") - return value - - @field_validator("target_dictionaries_size") - @classmethod - def validate_target_dictionaries_size(cls, value): - if value <= 0: - raise ValueError("target_dictionaries_size must be greater than 0") - return value - - @field_validator("target_encoded_file_size") - @classmethod - def validate_target_encoded_file_size(cls, value): - if value <= 0: - raise ValueError("target_encoded_file_size must be greater than 0") - return value - - @field_validator("target_segment_size") - @classmethod - def validate_target_segment_size(cls, value): - if value <= 0: - raise ValueError("target_segment_size must be greater than 0") - return value - - @field_validator("compression_level") - @classmethod - def validate_compression_level(cls, value): - if value < 1 or value > 19: - raise ValueError("compression_level must be a value from 1 to 19") - return value - - @field_validator("retention_period") - @classmethod - def validate_retention_period(cls, value): - if value is not None and value <= 0: - raise ValueError("retention_period must be greater than 0") - return value + target_archive_size: PositiveInt = 256 * 1024 * 1024 # 256 MB + target_dictionaries_size: PositiveInt = 32 * 1024 * 1024 # 32 MB + target_encoded_file_size: PositiveInt = 256 * 1024 * 1024 # 256 MB + target_segment_size: PositiveInt = 256 * 1024 * 1024 # 256 MB + compression_level: ZstdCompressionLevel = 3 + retention_period: Optional[PositiveInt] = None def set_directory(self, directory: pathlib.Path): _set_directory_for_storage_config(self.storage, directory) @@ -776,14 +528,7 @@ def dump_to_primitive_dict(self): class StreamOutput(BaseModel): storage: Union[StreamFsStorage, StreamS3Storage] = StreamFsStorage() - target_uncompressed_size: int = 128 * 1024 * 1024 - - @field_validator("target_uncompressed_size") - @classmethod - def validate_target_uncompressed_size(cls, value): - if value <= 0: - raise ValueError("target_uncompressed_size must be greater than 0") - return value + target_uncompressed_size: PositiveInt = 128 * 1024 * 1024 def set_directory(self, directory: pathlib.Path): _set_directory_for_storage_config(self.storage, directory) @@ -798,73 +543,27 @@ def dump_to_primitive_dict(self): class WebUi(BaseModel): - host: str = "localhost" - port: int = 4000 - results_metadata_collection_name: str = "results-metadata" - rate_limit: int = 1000 - - @field_validator("host") - @classmethod - def validate_host(cls, value): - _validate_host(cls, value) - return value - - @field_validator("port") - @classmethod - def validate_port(cls, value): - _validate_port(cls, value) - return value - - @field_validator("results_metadata_collection_name") - @classmethod - def validate_results_metadata_collection_name(cls, value): - if "" == value: - raise ValueError( - f"{WEBUI_COMPONENT_NAME}.results_metadata_collection_name cannot be empty." - ) - return value - - @field_validator("rate_limit") - @classmethod - def validate_rate_limit(cls, value): - if value <= 0: - raise ValueError(f"rate_limit must be greater than 0") - return value + host: DomainStr = "localhost" + port: Port = 4000 + results_metadata_collection_name: NonEmptyStr = "results-metadata" + rate_limit: PositiveInt = 1000 class SweepInterval(BaseModel): model_config = ConfigDict(extra="forbid") - archive: int = Field(default=60, gt=0) - search_result: int = Field(default=30, gt=0) + archive: PositiveInt = 60 + search_result: PositiveInt = 30 class GarbageCollector(BaseModel): - logging_level: str = "INFO" + logging_level: LoggingLevel = "INFO" sweep_interval: SweepInterval = SweepInterval() - @field_validator("logging_level") - @classmethod - def validate_logging_level(cls, value): - _validate_logging_level(cls, value) - return value - class Presto(BaseModel): - host: str - port: int - - @field_validator("host") - @classmethod - def validate_host(cls, value): - _validate_host(cls, value) - return value - - @field_validator("port") - @classmethod - def validate_port(cls, value): - _validate_port(cls, value) - return value + host: DomainStr + port: Port def _get_env_var(name: str) -> str: @@ -875,7 +574,7 @@ def _get_env_var(name: str) -> str: class CLPConfig(BaseModel): - container_image_ref: Optional[str] = None + container_image_ref: Optional[NonEmptyStr] = None logs_input: Union[FsIngestionConfig, S3IngestionConfig] = FsIngestionConfig() @@ -1035,15 +734,16 @@ def get_runnable_components(self) -> Set[str]: return ALL_COMPONENTS def dump_to_primitive_dict(self): - custom_serialized_fields = ( + custom_serialized_fields = { + "package", "database", "queue", "redis", "logs_input", "archive_output", "stream_output", - ) - d = self.model_dump(exclude=set(custom_serialized_fields)) + } + d = self.model_dump(exclude=custom_serialized_fields) for key in custom_serialized_fields: d[key] = getattr(self, key).dump_to_primitive_dict() diff --git a/components/clp-py-utils/clp_py_utils/clp_logging.py b/components/clp-py-utils/clp_py_utils/clp_logging.py index dfe2ae4d8e..23d58602a2 100644 --- a/components/clp-py-utils/clp_py_utils/clp_logging.py +++ b/components/clp-py-utils/clp_py_utils/clp_logging.py @@ -1,13 +1,14 @@ import logging +from typing import get_args, Literal -LOGGING_LEVEL_MAPPING = { - "INFO": logging.INFO, - "DEBUG": logging.DEBUG, - "WARN": logging.WARNING, - "WARNING": logging.WARNING, - "ERROR": logging.ERROR, - "CRITICAL": logging.CRITICAL, -} +LoggingLevel = Literal[ + "INFO", + "DEBUG", + "WARN", + "WARNING", + "ERROR", + "CRITICAL", +] def get_logging_formatter(): @@ -25,17 +26,10 @@ def get_logger(name: str): return logger -def get_valid_logging_level(): - return [i for i in LOGGING_LEVEL_MAPPING.keys()] - - -def is_valid_logging_level(level: str): - return level in LOGGING_LEVEL_MAPPING - - def set_logging_level(logger: logging.Logger, level: str): - if not is_valid_logging_level(level): + if level not in get_args(LoggingLevel): logger.warning(f"Invalid logging level: {level}, using INFO as default") logger.setLevel(logging.INFO) return - logger.setLevel(LOGGING_LEVEL_MAPPING[level]) + + logger.setLevel(level)