@worker_process_init.connect
def setup_azure_redis_auth(**kwargs):
    """Install the Azure Redis authentication patch when a worker process starts.

    Connected to Celery's ``worker_process_init`` signal. The signal's keyword
    arguments are accepted and ignored; the handler only calls
    ``patch_redis_for_azure()``.
    """
    # Imported lazily, inside the handler, so nothing from db_repo_module is
    # loaded when this module itself is imported.
    # NOTE(review): registered on worker_process_init only; confirm this signal
    # fires in the process that opens the broker connection for the pool type
    # in use (the worker image launches with --pool=solo).
    from db_repo_module.cache import azure_redis_auth

    azure_redis_auth.patch_redis_for_azure()
'{celery}.unacked_index', + 'unacked_mutex_key': '{celery}.unacked_mutex', + }, ) diff --git a/wavefront/server/background_jobs/celery_worker/celery_worker/worker_setup.py b/wavefront/server/background_jobs/celery_worker/celery_worker/worker_setup.py index 6fd19d66..b6697452 100644 --- a/wavefront/server/background_jobs/celery_worker/celery_worker/worker_setup.py +++ b/wavefront/server/background_jobs/celery_worker/celery_worker/worker_setup.py @@ -15,6 +15,7 @@ create_api_services_container, ) from agents_module.agents_container import AgentsContainer +from llm_inference_config_module.container import LlmInferenceConfigContainer from agents_module.services.agent_inference_service import AgentInferenceService from agents_module.services.workflow_inference_service import WorkflowInferenceService from common_module.common_container import CommonContainer @@ -124,6 +125,11 @@ def get_services() -> WorkerServices: message_processor_bucket_name=bucket_name, ) + llm_inference_config_container = LlmInferenceConfigContainer( + db_client=db_repo_container.db_client, + cache_manager=db_repo_container.cache_manager, + ) + agents_container = AgentsContainer( db_client=db_repo_container.db_client, cloud_storage_manager=common_container.cloud_storage_manager, @@ -139,6 +145,7 @@ def get_services() -> WorkerServices: api_services_manager=api_services_container.api_service_manager, async_agentic_execution_repository=db_repo_container.async_agentic_execution_repository, executions_bucket=AGENTIC_EXECUTIONS_BUCKET, + llm_inference_config_service=llm_inference_config_container.llm_inference_config_service, ) # Inject config values from env vars so services like AgentCrudService @@ -162,6 +169,7 @@ def get_services() -> WorkerServices: agent_repository=db_repo_container.agent_repository, workflow_repository=db_repo_container.workflow_repository, async_agentic_execution_service=agents_container.async_agentic_execution_service, + cache_manager=db_repo_container.cache_manager, ) 
triggers_container.config.from_dict( { diff --git a/wavefront/server/background_jobs/celery_worker/pyproject.toml b/wavefront/server/background_jobs/celery_worker/pyproject.toml index 8915f899..64847004 100644 --- a/wavefront/server/background_jobs/celery_worker/pyproject.toml +++ b/wavefront/server/background_jobs/celery_worker/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "api-services-module", "common-module", "triggers-module", + "llm-inference-config-module", "celery[redis]>=5.4.0,<6.0.0", "python-dotenv>=1.1.0,<2.0.0", ] @@ -28,7 +29,8 @@ flo-utils = { workspace = true } tools-module = { workspace = true } api-services-module = { workspace = true } common-module = { workspace = true } -triggers-module = { workspace = true } +triggers-module = { workspace = true } +llm-inference-config-module = { workspace = true } [tool.uv] package = true diff --git a/wavefront/server/docker/celery_worker.Dockerfile b/wavefront/server/docker/celery_worker.Dockerfile index 96cbb7d0..bd0e7b0f 100644 --- a/wavefront/server/docker/celery_worker.Dockerfile +++ b/wavefront/server/docker/celery_worker.Dockerfile @@ -42,4 +42,4 @@ USER celery WORKDIR /app/background_jobs/celery_worker -CMD ["uv", "run", "celery", "-A", "celery_worker.celery_app", "worker", "--loglevel=info", "--pool=solo"] +CMD ["uv", "run", "celery", "-A", "celery_worker.celery_app", "worker", "--loglevel=info", "--pool=solo", "--without-mingle", "--without-gossip"] diff --git a/wavefront/server/modules/agents_module/agents_module/services/agent_inference_service.py b/wavefront/server/modules/agents_module/agents_module/services/agent_inference_service.py index 44fa192e..27f42c5f 100644 --- a/wavefront/server/modules/agents_module/agents_module/services/agent_inference_service.py +++ b/wavefront/server/modules/agents_module/agents_module/services/agent_inference_service.py @@ -373,6 +373,7 @@ async def perform_inference_v2( output_json_enabled: bool = True, access_token: Optional[str] = None, app_key: 
# Set to True once redis.ConnectionPool.__init__ has been wrapped.
_patched = False

# Serialises the check-and-patch sequence so concurrent callers cannot wrap
# redis.ConnectionPool.__init__ more than once.
_patch_lock = threading.Lock()


def patch_redis_for_azure(token_timeout: float = 10) -> None:
    """Route redis connection pools through Entra ID authentication on Azure.

    Does nothing unless the ``CLOUD_PROVIDER`` environment variable equals
    ``azure`` (case-insensitive), and does nothing if the patch has already
    been installed. Otherwise it builds a credential provider with
    ``create_from_default_azure_credential`` for the scope
    ``https://redis.azure.com/.default`` and wraps
    ``redis.ConnectionPool.__init__``. Every pool constructed through that
    initializer then drops any ``password``/``username`` keyword arguments
    and uses a credential provider instead.

    Args:
        token_timeout: Maximum number of seconds a single credential fetch
            may take before ``TimeoutError`` is raised. Defaults to 10.

    Raises:
        Nothing at patch time beyond what the imports or
        ``create_from_default_azure_credential`` raise. The installed
        provider raises ``TimeoutError`` when a fetch exceeds
        ``token_timeout`` and re-raises any exception from the underlying
        provider.
    """
    global _patched
    if _patched or os.getenv('CLOUD_PROVIDER', '').lower() != 'azure':
        return

    import functools

    from redis_entraid.cred_provider import create_from_default_azure_credential
    from redis.credentials import CredentialProvider

    with _patch_lock:
        # Re-check under the lock: another thread may have finished patching
        # between the unlocked check above and acquiring the lock.
        if _patched:
            return

        _inner = create_from_default_azure_credential(
            ('https://redis.azure.com/.default',)
        )

        class _TimedProvider(CredentialProvider):
            """Credential provider that bounds each fetch with a timeout."""

            def get_credentials(self):
                outcome = {}

                def _fetch():
                    try:
                        outcome['value'] = _inner.get_credentials()
                    except Exception as err:
                        # Carried back to the calling thread and re-raised there.
                        outcome['error'] = err

                # Daemon thread: a fetch that never returns must not keep the
                # interpreter alive at shutdown.
                worker = threading.Thread(target=_fetch, daemon=True)
                worker.start()
                worker.join(timeout=token_timeout)
                if worker.is_alive():
                    raise TimeoutError(
                        f'Azure Redis token fetch timed out after {token_timeout:g}s'
                    )
                if 'error' in outcome:
                    raise outcome['error']
                return outcome['value']

        provider = _TimedProvider()
        original_init = redis.ConnectionPool.__init__

        @functools.wraps(original_init)
        def patched_init(self, *args, **kw):
            # NOTE(review): static credentials are discarded even when they were
            # configured explicitly; confirm no Azure deployment still relies
            # on password-based Redis auth.
            kw.pop('password', None)
            kw.pop('username', None)
            # Keep a provider the caller supplied on purpose; otherwise inject
            # the shared Entra ID provider.
            if kw.get('credential_provider') is None:
                kw['credential_provider'] = provider
            original_init(self, *args, **kw)

        redis.ConnectionPool.__init__ = patched_init
        _patched = True
- """ - Adapter to bridge Azure Identity with Redis CredentialProvider. - Azure Managed Redis requires 'default' as the username and the - Entra ID access token as the password. - """ - - def __init__(self): - self.credential = DefaultAzureCredential() - self.scope = 'https://redis.azure.com/.default' - self.username = os.getenv('REDIS_USERNAME', 'default') - - def get_credentials(self): - try: - token = self.credential.get_token(self.scope) - return (self.username, token.token) - except ClientAuthenticationError as e: - logger.error(f'Azure authentication failed: {e}') - raise - - class CacheManager(CommonCache): def __init__( self, @@ -88,7 +64,6 @@ def _create_connection_pool( port = int(os.getenv('REDIS_PORT', 6379)) protocol = os.getenv('REDIS_PROTOCOL', 'redis') password = os.getenv('REDIS_PASSWORD') - cloud_provider = os.getenv('CLOUD_PROVIDER', '').lower() connection_class = Connection if protocol == 'rediss' or port == 10000: @@ -110,12 +85,7 @@ def _create_connection_pool( 'decode_responses': True, } - if cloud_provider == 'azure' and not password: - logger.info( - 'Configuring Azure Entra ID (Workload Identity) authentication' - ) - pool_kwargs['credential_provider'] = AzureManagedRedisProvider() - elif password: + if password: pool_kwargs['password'] = password return ConnectionPool(**pool_kwargs) diff --git a/wavefront/server/modules/db_repo_module/pyproject.toml b/wavefront/server/modules/db_repo_module/pyproject.toml index cd083132..7a083da8 100644 --- a/wavefront/server/modules/db_repo_module/pyproject.toml +++ b/wavefront/server/modules/db_repo_module/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "sqlalchemy>=2.0.36,<3.0.0", "alembic>=1.14.1,<2.0.0", "redis>=5.2.1,<6.0.0", + "redis-entraid>=1.2.1,<2.0.0", "azure-identity>=1.17.0,<2.0.0", "pgvector>=0.4.1", "tenacity>=8.1.0,<9.0.0", diff --git a/wavefront/server/uv.lock b/wavefront/server/uv.lock index 93f054ca..501a2357 100644 --- a/wavefront/server/uv.lock +++ b/wavefront/server/uv.lock @@ 
-759,6 +759,7 @@ dependencies = [ { name = "db-repo-module" }, { name = "flo-cloud" }, { name = "flo-utils" }, + { name = "llm-inference-config-module" }, { name = "python-dotenv" }, { name = "tools-module" }, { name = "triggers-module" }, @@ -773,6 +774,7 @@ requires-dist = [ { name = "db-repo-module", editable = "modules/db_repo_module" }, { name = "flo-cloud", editable = "packages/flo_cloud" }, { name = "flo-utils", editable = "packages/flo_utils" }, + { name = "llm-inference-config-module", editable = "modules/llm_inference_config_module" }, { name = "python-dotenv", specifier = ">=1.1.0,<2.0.0" }, { name = "tools-module", editable = "modules/tools_module" }, { name = "triggers-module", editable = "modules/triggers_module" }, @@ -1131,6 +1133,7 @@ dependencies = [ { name = "pgvector" }, { name = "psycopg", extra = ["binary", "pool"] }, { name = "redis" }, + { name = "redis-entraid" }, { name = "sqlalchemy" }, { name = "tenacity" }, ] @@ -1144,6 +1147,7 @@ requires-dist = [ { name = "pgvector", specifier = ">=0.4.1" }, { name = "psycopg", extras = ["binary", "pool"], specifier = ">=3.2.3,<4.0.0" }, { name = "redis", specifier = ">=5.2.1,<6.0.0" }, + { name = "redis-entraid", specifier = ">=1.2.1,<2.0.0" }, { name = "sqlalchemy", specifier = ">=2.0.36,<3.0.0" }, { name = "tenacity", specifier = ">=8.1.0,<9.0.0" }, ] @@ -4654,11 +4658,11 @@ wheels = [ [[package]] name = "pyjwt" -version = "2.10.1" +version = "2.13.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3b/81/58d0ac84e1ef3a3843791d6954d94c0b33d526c75eeb1efbce9d0a4c4077/pyjwt-2.13.0.tar.gz", hash = 
"sha256:41571c89ca91598c79e8ef18a2d07367d4810fbbd6f637794879baf1b7703423", size = 107515, upload-time = "2026-05-21T19:54:36.618Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, + { url = "https://files.pythonhosted.org/packages/a3/5e/ecf12fdb62546d64385c158514e9b2b671f7832108ef2ecd2020ce0af2d1/pyjwt-2.13.0-py3-none-any.whl", hash = "sha256:66adcc2aff09b3f1bbd95fc1e1577df8ac8723c978552fd43304c8a290ac5728", size = 31274, upload-time = "2026-05-21T19:54:35.362Z" }, ] [package.optional-dependencies] @@ -4917,6 +4921,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/26/5c5fa0e83c3621db835cfc1f1d789b37e7fa99ed54423b5f519beb931aa7/redis-5.3.1-py3-none-any.whl", hash = "sha256:dc1909bd24669cc31b5f67a039700b16ec30571096c5f1f0d9d2324bff31af97", size = 272833, upload-time = "2025-07-25T08:06:26.317Z" }, ] +[[package]] +name = "redis-entraid" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "azure-identity" }, + { name = "msal" }, + { name = "pyjwt" }, + { name = "redis" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a2/a7/0ddaeb27b33c76709e05a12b3bbeefce893c82a3a830146608d6fe620000/redis_entraid-1.2.1.tar.gz", hash = "sha256:a7c479ce46e6edb35bce9dd804d1cad7be99a3330815cfe028a648b486a10b41", size = 9792, upload-time = "2026-06-03T11:38:55.613Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cd/ca/01b8607102de756b270d3f6befeee700bd82dace4303d6f47ce0f53c11b0/redis_entraid-1.2.1-py3-none-any.whl", hash = "sha256:9de7e4a716b156d966a2d6bb5b5ccd64a692db30ae21fe3987f57d233793d558", size = 7967, upload-time = "2026-06-03T11:38:54.497Z" }, +] + [[package]] name = "redshift-connector" version = "2.1.8"