From 1810ac25c63a8f9d298301b1f89af79dddddeac7 Mon Sep 17 00:00:00 2001 From: Leonardo Parente <23251360+leoparente@users.noreply.github.com> Date: Mon, 29 Sep 2025 17:00:27 -0300 Subject: [PATCH 1/2] feat: add worker kwargs --- .../device-discovery-lint-tests.yaml | 3 +- .github/workflows/worker-lint-tests.yaml | 3 +- worker/tests/nbl-custom/nbl_custom/impl.py | 3 + worker/tests/policy/test_runner.py | 138 +++++++++++++++++- worker/tests/test_backend.py | 13 ++ worker/worker/backend.py | 3 +- worker/worker/policy/runner.py | 60 +++++++- 7 files changed, 211 insertions(+), 12 deletions(-) diff --git a/.github/workflows/device-discovery-lint-tests.yaml b/.github/workflows/device-discovery-lint-tests.yaml index 45e6776..c6de39b 100644 --- a/.github/workflows/device-discovery-lint-tests.yaml +++ b/.github/workflows/device-discovery-lint-tests.yaml @@ -57,5 +57,4 @@ jobs: - name: Lint with Ruff run: | - ruff check --output-format=github device_discovery/ tests/ - continue-on-error: true \ No newline at end of file + ruff check --output-format=github device_discovery/ tests/ \ No newline at end of file diff --git a/.github/workflows/worker-lint-tests.yaml b/.github/workflows/worker-lint-tests.yaml index cb5523d..798f840 100644 --- a/.github/workflows/worker-lint-tests.yaml +++ b/.github/workflows/worker-lint-tests.yaml @@ -57,5 +57,4 @@ jobs: - name: Lint with Ruff run: | - ruff check --output-format=github worker/ tests/ - continue-on-error: true \ No newline at end of file + ruff check --output-format=github worker/ tests/ \ No newline at end of file diff --git a/worker/tests/nbl-custom/nbl_custom/impl.py b/worker/tests/nbl-custom/nbl_custom/impl.py index 5ff919c..bf1628c 100644 --- a/worker/tests/nbl-custom/nbl_custom/impl.py +++ b/worker/tests/nbl-custom/nbl_custom/impl.py @@ -73,6 +73,9 @@ def run(self, policy_name: str, policy: Policy) -> Iterable[Entity]: ) entities.append(Entity(device=device)) + cache = kwargs.get("cache", {}) + logger.info(f"Policy '{policy_name}' config: {config}") logger.info(f"Policy '{policy_name}' scope: {scope}") + logger.info(f"Policy '{policy_name}' cache keys: {list(cache.keys())}") return entities diff --git a/worker/tests/policy/test_runner.py b/worker/tests/policy/test_runner.py index 244d61e..cc813c2 100644 --- a/worker/tests/policy/test_runner.py +++ b/worker/tests/policy/test_runner.py @@ -111,6 +111,14 @@ def test_setup_policy_runner_with_cron( mock_start.assert_called_once() mock_add_job.assert_called_once() mock_load_class.assert_called_once() + mock_load_class.return_value.assert_called_once_with() + job_kwargs = mock_add_job.call_args[1]["kwargs"] + assert job_kwargs["schedule"] == "0 * * * *" + assert "cache" in job_kwargs + assert job_kwargs["cache"] is policy_runner.cache + schedule_now_callable = job_kwargs["schedule_now"] + assert schedule_now_callable.__self__ is policy_runner + assert schedule_now_callable.__func__ is PolicyRunner.schedule_now mock_diode_client.assert_called_once() assert policy_runner.status == Status.RUNNING @@ -133,6 +141,13 @@ def test_setup_policy_runner_with_one_time_run( # Verify that DateTrigger is used for one-time scheduling trigger = mock_add_job.call_args[1]["trigger"] mock_load_class.assert_called_once() + mock_load_class.return_value.assert_called_once_with() + job_kwargs = mock_add_job.call_args[1]["kwargs"] + assert set(job_kwargs.keys()) == {"cache", "schedule_now"} + assert job_kwargs["cache"] is policy_runner.cache + schedule_now_callable = job_kwargs["schedule_now"] + assert schedule_now_callable.__self__ is policy_runner + assert schedule_now_callable.__func__ is PolicyRunner.schedule_now mock_diode_client.assert_called_once() assert isinstance(trigger, DateTrigger) assert mock_start.called @@ -156,12 +171,88 @@ def test_setup_policy_runner_dry_run( mock_start.assert_called_once() mock_add_job.assert_called_once() mock_load_class.assert_called_once() + mock_load_class.return_value.assert_called_once_with() + job_kwargs = mock_add_job.call_args[1]["kwargs"] + assert job_kwargs["schedule"] == "0 * * * *" + assert job_kwargs["cache"] is policy_runner.cache + schedule_now_callable = job_kwargs["schedule_now"] + assert schedule_now_callable.__self__ is policy_runner + assert schedule_now_callable.__func__ is PolicyRunner.schedule_now mock_diode_dry_run_client.assert_called_once() assert policy_runner.status == Status.RUNNING + +def test_setup_uses_backend_cache_factory( + policy_runner, sample_diode_config, sample_policy, mock_diode_client +): + """Ensure PolicyRunner uses backend cache factory when available.""" + backend_instance = MagicMock() + backend_instance.setup.return_value = Metadata( + name="my_backend", + app_name="app", + app_version="1.0", + ) + backend_instance.create_cache.return_value = {"token": "value"} + + backend_class = MagicMock(return_value=backend_instance) + + with patch("worker.policy.runner.load_class", return_value=backend_class), patch.object( + policy_runner.scheduler, "start" + ), patch.object(policy_runner.scheduler, "add_job") as mock_add_job: + policy_runner.setup("policy1", sample_diode_config, sample_policy) + + backend_instance.create_cache.assert_called_once() + assert policy_runner.cache == {"token": "value"} + job_kwargs = mock_add_job.call_args[1]["kwargs"] + assert job_kwargs["cache"] == {"token": "value"} + schedule_now_callable = job_kwargs["schedule_now"] + assert schedule_now_callable.__self__ is policy_runner + assert schedule_now_callable.__func__ is PolicyRunner.schedule_now + + +def test_schedule_now_adds_job( + policy_runner, + sample_diode_config, + sample_policy, + mock_diode_client, +): + """Ensure schedule_now queues an immediate job with merged kwargs.""" + backend_instance = MagicMock() + backend_instance.setup.return_value = Metadata( + name="backend", + app_name="app", + app_version="1.0", + ) + + backend_class = MagicMock(return_value=backend_instance) + + with patch("worker.policy.runner.load_class", return_value=backend_class), patch.object( + policy_runner.scheduler, "start" + ), patch.object(policy_runner.scheduler, "add_job"): + policy_runner.setup("policy1", sample_diode_config, sample_policy) + + with patch.object(policy_runner.scheduler, "add_job") as mock_add_job: + policy_runner.schedule_now({"custom": "value"}) + + mock_add_job.assert_called_once() + call_args, call_kwargs = mock_add_job.call_args + run_callable = call_args[0] + assert run_callable.__self__ is policy_runner + assert run_callable.__func__ is PolicyRunner.run + scheduled_trigger = call_kwargs["trigger"] + assert isinstance(scheduled_trigger, DateTrigger) + scheduled_kwargs = call_kwargs["kwargs"] + assert scheduled_kwargs["custom"] == "value" + assert scheduled_kwargs["cache"] is policy_runner.cache + schedule_now_callable = scheduled_kwargs["schedule_now"] + assert schedule_now_callable.__self__ is policy_runner + assert schedule_now_callable.__func__ is PolicyRunner.schedule_now + + def test_run_success(policy_runner, sample_policy, mock_diode_client, mock_backend): """Test the run function for a successful execution.""" policy_runner.name = "test_policy" + policy_runner.cache = {} # Create mock entities entities = [] @@ -177,7 +268,9 @@ def test_run_success(policy_runner, sample_policy, mock_diode_client, mock_backe policy_runner.run(mock_diode_client, mock_backend, sample_policy) # Assertions - mock_backend.run.assert_called_once_with(policy_runner.name, sample_policy) + mock_backend.run.assert_called_once_with( + policy_runner.name, sample_policy, cache=policy_runner.cache + ) # Should call ingest once for the single chunk mock_diode_client.ingest.assert_called_once() # Check that entities were passed correctly @@ -185,6 +278,31 @@ def test_run_success(policy_runner, sample_policy, mock_diode_client, mock_backe assert len(call_args) == 3 +def test_run_forwards_backend_kwargs(policy_runner, sample_policy, mock_diode_client, mock_backend): + """Ensure PolicyRunner forwards keyword arguments to the backend run method.""" + policy_runner.name = "test_policy" + mock_backend.run.return_value = [] + mock_diode_client.ingest.return_value.errors = [] + mock_cache = {} + + policy_runner.run( + mock_diode_client, + mock_backend, + sample_policy, + schedule="0 * * * *", + custom="value", + cache=mock_cache, + ) + + mock_backend.run.assert_called_once_with( + policy_runner.name, + sample_policy, + schedule="0 * * * *", + custom="value", + cache=mock_cache, + ) + + def test_run_ingestion_errors( policy_runner, sample_policy, @@ -194,6 +312,7 @@ def test_run_ingestion_errors( ): """Test the run function when ingestion has errors.""" policy_runner.name = "test_policy" + policy_runner.cache = {} # Create mock entities entities = [] @@ -212,7 +331,9 @@ def test_run_ingestion_errors( policy_runner.run(mock_diode_client, mock_backend, sample_policy) # Assertions - mock_backend.run.assert_called_once_with(policy_runner.name, sample_policy) + mock_backend.run.assert_called_once_with( + policy_runner.name, sample_policy, cache=policy_runner.cache + ) mock_diode_client.ingest.assert_called_once() assert ( "Policy test_policy: Chunk 1 ingestion failed: ['error1', 'error2']" @@ -229,6 +350,7 @@ def test_run_backend_exception( ): """Test the run function when an exception is raised by the backend.""" policy_runner.name = "test_policy" + policy_runner.cache = {} # Simulate backend throwing an exception mock_backend.run.side_effect = Exception("Backend error") @@ -238,19 +360,23 @@ def test_run_backend_exception( policy_runner.run(mock_diode_client, mock_backend, sample_policy) # Assertions - mock_backend.run.assert_called_once_with(policy_runner.name, sample_policy) + mock_backend.run.assert_called_once_with( + policy_runner.name, sample_policy, cache=policy_runner.cache + ) mock_diode_client.ingest.assert_not_called() # Client ingestion should not be called assert "Policy test_policy: Backend error" in caplog.text def test_stop_policy_runner(policy_runner): """Test stopping the PolicyRunner.""" + policy_runner.cache = {} with patch.object(policy_runner.scheduler, "shutdown") as mock_shutdown: policy_runner.stop() # Ensure scheduler shutdown is called and status is updated mock_shutdown.assert_called_once() assert policy_runner.status == Status.FINISHED + assert policy_runner.cache is None def test_metrics_during_policy_lifecycle( @@ -279,6 +405,8 @@ def test_metrics_during_policy_lifecycle( app_name="test_app", app_version="1.0", ) + policy_runner.cache = {} + policy_runner.cache = {} # Create mock entities entities = [] @@ -299,7 +427,9 @@ def mock_get_metric(name): policy_runner.run(mock_diode_client, mock_backend, sample_policy) - mock_backend.run.assert_called_once_with(policy_runner.name, sample_policy) + mock_backend.run.assert_called_once_with( + policy_runner.name, sample_policy, cache=policy_runner.cache + ) mock_diode_client.ingest.assert_called_once() mock_policy_executions.add.assert_called_once_with(1, {"policy": "test_policy"}) diff --git a/worker/tests/test_backend.py b/worker/tests/test_backend.py index 8c2a18b..ccb1ce4 100644 --- a/worker/tests/test_backend.py +++ b/worker/tests/test_backend.py @@ -36,6 +36,19 @@ def test_backend_run_not_implemented(): list(backend.run("mock", mock_policy)) +def test_backend_run_accepts_kwargs(): + """Test that Backend subclasses can accept keyword arguments in run.""" + + class DummyBackend(Backend): + def run(self, policy_name, policy, **kwargs): + return kwargs + + backend = DummyBackend() + mock_policy = MagicMock(spec=Policy) + result = backend.run("mock", mock_policy, foo="bar", answer=42) + assert result == {"foo": "bar", "answer": 42} + + def test_load_class_valid_backend_class(mock_import_module): """Test that load_class successfully loads a valid Backend class.""" mock_module_name = "worker.test_module" diff --git a/worker/worker/backend.py b/worker/worker/backend.py index 3571c20..1f0e327 100644 --- a/worker/worker/backend.py +++ b/worker/worker/backend.py @@ -25,7 +25,7 @@ def setup(self) -> Metadata: """ raise NotImplementedError("The 'setup' method must be implemented.") - def run(self, policy_name: str, policy: Policy) -> Iterable[Entity]: + def run(self, policy_name: str, policy: Policy, **kwargs) -> Iterable[Entity]: """ Run the backend. @@ -33,6 +33,7 @@ def run(self, policy_name: str, policy: Policy) -> Iterable[Entity]: ---- policy_name (str): The name of the policy. policy (Policy): The policy to run. + **kwargs: Additional parameters provided by the runner. Returns: ------- diff --git a/worker/worker/policy/runner.py b/worker/worker/policy/runner.py index 2c9d34d..6c7ccaf 100644 --- a/worker/worker/policy/runner.py +++ b/worker/worker/policy/runner.py @@ -5,6 +5,7 @@ import logging import time from datetime import datetime, timedelta +from typing import Any from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.triggers.cron import CronTrigger @@ -33,6 +34,9 @@ def __init__(self): self.policy = None self.status = Status.NEW self.scheduler = BackgroundScheduler() + self.cache: Any = None + self._job_args: tuple[Any, ...] = () + self._job_kwargs: dict[str, Any] = {} def setup(self, name: str, diode_config: DiodeConfig, policy: Policy): """ @@ -51,6 +55,10 @@ def setup(self, name: str, diode_config: DiodeConfig, policy: Policy): ) backend_class = load_class(policy.config.package) backend = backend_class() + backend_kwargs = policy.config.model_dump( + exclude={"package"}, exclude_none=True + ) + cache_config = backend_kwargs.pop("cache", None) metadata = backend.setup() app_name = ( @@ -72,9 +80,16 @@ def setup(self, name: str, diode_config: DiodeConfig, policy: Policy): client_secret=diode_config.client_secret, ) + self.cache = self._create_cache(backend, cache_config) + backend_kwargs["cache"] = self.cache + backend_kwargs["schedule_now"] = self.schedule_now + self.metadata = metadata self.policy = policy + self._job_args = (client, backend, self.policy) + self._job_kwargs = dict(backend_kwargs) + self.scheduler.start() if self.policy.config.schedule is not None: @@ -91,7 +106,8 @@ def setup(self, name: str, diode_config: DiodeConfig, policy: Policy): self.scheduler.add_job( self.run, trigger=trigger, - args=[client, backend, self.policy], + args=self._job_args, + kwargs=self._job_kwargs, ) self.status = Status.RUNNING @@ -101,7 +117,11 @@ def setup(self, name: str, diode_config: DiodeConfig, policy: Policy): active_policies.add(1, {"policy": self.name}) def run( - self, client: DiodeClient | DiodeDryRunClient, backend: Backend, policy: Policy + self, + client: DiodeClient | DiodeDryRunClient, + backend: Backend, + policy: Policy, + **backend_kwargs, ): """ Run the custom backend code for the specified scope. @@ -118,8 +138,11 @@ def run( policy_executions.add(1, {"policy": self.name}) exec_start_time = time.perf_counter() + if "cache" not in backend_kwargs and self.cache is not None: + backend_kwargs["cache"] = self.cache + try: - entities = backend.run(self.name, policy) + entities = backend.run(self.name, policy, **backend_kwargs) for chunk_num, entity_chunk in enumerate(self._create_message_chunks(entities), 1): chunk_size_mb = self._estimate_message_size(entity_chunk) / (1024 * 1024) @@ -174,10 +197,30 @@ def stop(self): """Stop the policy runner.""" self.scheduler.shutdown() self.status = Status.FINISHED + self.cache = None + self._job_args = () + self._job_kwargs = {} active_policies = get_metric("active_policies") if active_policies: active_policies.add(-1, {"policy": self.name}) + def schedule_now(self, extra_kwargs: dict[str, Any] | None = None) -> None: + """Schedule the backend to run immediately, merging any extra kwargs.""" + if not self._job_args: + raise RuntimeError("PolicyRunner is not initialized; cannot schedule now") + + job_kwargs = dict(self._job_kwargs) + if extra_kwargs: + job_kwargs.update(extra_kwargs) + + immediate_trigger = DateTrigger(run_date=datetime.now()) + self.scheduler.add_job( + self.run, + trigger=immediate_trigger, + args=self._job_args, + kwargs=job_kwargs, + ) + def _create_message_chunks(self, entities: list[ingester_pb2.Entity]) -> list[list[ingester_pb2.Entity]]: """Create 3.5MB chunks from entities, always returning at least one chunk.""" total_entities = len(entities) @@ -208,3 +251,14 @@ def _estimate_message_size(self, entities: list[ingester_pb2.Entity]) -> int: request = ingester_pb2.IngestRequest() request.entities.extend(entities) return request.ByteSize() + + def _create_cache(self, backend: Backend, cache_config: Any) -> Any: + """Create a cache object that backends can reuse across runs.""" + cache_factory = getattr(backend, "create_cache", None) + if callable(cache_factory): + cache = cache_factory(cache_config) + if cache is not None: + return cache + if cache_config is not None: + return cache_config + return {} From fe2953128c79f26e18dd8101d2c09dbe2ef6fa3e Mon Sep 17 00:00:00 2001 From: Leonardo Parente <23251360+leoparente@users.noreply.github.com> Date: Mon, 29 Sep 2025 17:08:28 -0300 Subject: [PATCH 2/2] Small fix --- worker/tests/nbl-custom/nbl_custom/impl.py | 4 ++-- worker/worker/policy/runner.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/worker/tests/nbl-custom/nbl_custom/impl.py b/worker/tests/nbl-custom/nbl_custom/impl.py index bf1628c..ed8b3b5 100644 --- a/worker/tests/nbl-custom/nbl_custom/impl.py +++ b/worker/tests/nbl-custom/nbl_custom/impl.py @@ -73,9 +73,9 @@ def run(self, policy_name: str, policy: Policy) -> Iterable[Entity]: ) entities.append(Entity(device=device)) - cache = kwargs.get("cache", {}) + #cache = kwargs.get("cache", {}) logger.info(f"Policy '{policy_name}' config: {config}") logger.info(f"Policy '{policy_name}' scope: {scope}") - logger.info(f"Policy '{policy_name}' cache keys: {list(cache.keys())}") + #logger.info(f"Policy '{policy_name}' cache keys: {list(cache.keys())}") return entities diff --git a/worker/worker/policy/runner.py b/worker/worker/policy/runner.py index 6c7ccaf..01cdbf9 100644 --- a/worker/worker/policy/runner.py +++ b/worker/worker/policy/runner.py @@ -131,6 +131,7 @@ def run( client: Diode client. backend: Backend class. policy: Policy configuration. + **backend_kwargs: Additional keyword arguments for the backend. """ policy_executions = get_metric("policy_executions")