Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/device-discovery-lint-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,4 @@ jobs:

- name: Lint with Ruff
run: |
ruff check --output-format=github device_discovery/ tests/
continue-on-error: true
ruff check --output-format=github device_discovery/ tests/
3 changes: 1 addition & 2 deletions .github/workflows/worker-lint-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,4 @@ jobs:

- name: Lint with Ruff
run: |
ruff check --output-format=github worker/ tests/
continue-on-error: true
ruff check --output-format=github worker/ tests/
3 changes: 3 additions & 0 deletions worker/tests/nbl-custom/nbl_custom/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def run(self, policy_name: str, policy: Policy) -> Iterable[Entity]:
)

entities.append(Entity(device=device))
#cache = kwargs.get("cache", {})

logger.info(f"Policy '{policy_name}' config: {config}")
logger.info(f"Policy '{policy_name}' scope: {scope}")
#logger.info(f"Policy '{policy_name}' cache keys: {list(cache.keys())}")
return entities
138 changes: 134 additions & 4 deletions worker/tests/policy/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ def test_setup_policy_runner_with_cron(
mock_start.assert_called_once()
mock_add_job.assert_called_once()
mock_load_class.assert_called_once()
mock_load_class.return_value.assert_called_once_with()
job_kwargs = mock_add_job.call_args[1]["kwargs"]
assert job_kwargs["schedule"] == "0 * * * *"
assert "cache" in job_kwargs
assert job_kwargs["cache"] is policy_runner.cache
schedule_now_callable = job_kwargs["schedule_now"]
assert schedule_now_callable.__self__ is policy_runner
assert schedule_now_callable.__func__ is PolicyRunner.schedule_now
mock_diode_client.assert_called_once()
assert policy_runner.status == Status.RUNNING

Expand All @@ -133,6 +141,13 @@ def test_setup_policy_runner_with_one_time_run(
# Verify that DateTrigger is used for one-time scheduling
trigger = mock_add_job.call_args[1]["trigger"]
mock_load_class.assert_called_once()
mock_load_class.return_value.assert_called_once_with()
job_kwargs = mock_add_job.call_args[1]["kwargs"]
assert set(job_kwargs.keys()) == {"cache", "schedule_now"}
assert job_kwargs["cache"] is policy_runner.cache
schedule_now_callable = job_kwargs["schedule_now"]
assert schedule_now_callable.__self__ is policy_runner
assert schedule_now_callable.__func__ is PolicyRunner.schedule_now
mock_diode_client.assert_called_once()
assert isinstance(trigger, DateTrigger)
assert mock_start.called
Expand All @@ -156,12 +171,88 @@ def test_setup_policy_runner_dry_run(
mock_start.assert_called_once()
mock_add_job.assert_called_once()
mock_load_class.assert_called_once()
mock_load_class.return_value.assert_called_once_with()
job_kwargs = mock_add_job.call_args[1]["kwargs"]
assert job_kwargs["schedule"] == "0 * * * *"
assert job_kwargs["cache"] is policy_runner.cache
schedule_now_callable = job_kwargs["schedule_now"]
assert schedule_now_callable.__self__ is policy_runner
assert schedule_now_callable.__func__ is PolicyRunner.schedule_now
mock_diode_dry_run_client.assert_called_once()
assert policy_runner.status == Status.RUNNING


def test_setup_uses_backend_cache_factory(
policy_runner, sample_diode_config, sample_policy, mock_diode_client
):
"""Ensure PolicyRunner uses backend cache factory when available."""
backend_instance = MagicMock()
backend_instance.setup.return_value = Metadata(
name="my_backend",
app_name="app",
app_version="1.0",
)
backend_instance.create_cache.return_value = {"token": "value"}

backend_class = MagicMock(return_value=backend_instance)

with patch("worker.policy.runner.load_class", return_value=backend_class), patch.object(
policy_runner.scheduler, "start"
), patch.object(policy_runner.scheduler, "add_job") as mock_add_job:
policy_runner.setup("policy1", sample_diode_config, sample_policy)

backend_instance.create_cache.assert_called_once()
assert policy_runner.cache == {"token": "value"}
job_kwargs = mock_add_job.call_args[1]["kwargs"]
assert job_kwargs["cache"] == {"token": "value"}
schedule_now_callable = job_kwargs["schedule_now"]
assert schedule_now_callable.__self__ is policy_runner
assert schedule_now_callable.__func__ is PolicyRunner.schedule_now


def test_schedule_now_adds_job(
policy_runner,
sample_diode_config,
sample_policy,
mock_diode_client,
):
"""Ensure schedule_now queues an immediate job with merged kwargs."""
backend_instance = MagicMock()
backend_instance.setup.return_value = Metadata(
name="backend",
app_name="app",
app_version="1.0",
)

backend_class = MagicMock(return_value=backend_instance)

with patch("worker.policy.runner.load_class", return_value=backend_class), patch.object(
policy_runner.scheduler, "start"
), patch.object(policy_runner.scheduler, "add_job"):
policy_runner.setup("policy1", sample_diode_config, sample_policy)

with patch.object(policy_runner.scheduler, "add_job") as mock_add_job:
policy_runner.schedule_now({"custom": "value"})

mock_add_job.assert_called_once()
call_args, call_kwargs = mock_add_job.call_args
run_callable = call_args[0]
assert run_callable.__self__ is policy_runner
assert run_callable.__func__ is PolicyRunner.run
scheduled_trigger = call_kwargs["trigger"]
assert isinstance(scheduled_trigger, DateTrigger)
scheduled_kwargs = call_kwargs["kwargs"]
assert scheduled_kwargs["custom"] == "value"
assert scheduled_kwargs["cache"] is policy_runner.cache
schedule_now_callable = scheduled_kwargs["schedule_now"]
assert schedule_now_callable.__self__ is policy_runner
assert schedule_now_callable.__func__ is PolicyRunner.schedule_now


def test_run_success(policy_runner, sample_policy, mock_diode_client, mock_backend):
"""Test the run function for a successful execution."""
policy_runner.name = "test_policy"
policy_runner.cache = {}

# Create mock entities
entities = []
Expand All @@ -177,14 +268,41 @@ def test_run_success(policy_runner, sample_policy, mock_diode_client, mock_backe
policy_runner.run(mock_diode_client, mock_backend, sample_policy)

# Assertions
mock_backend.run.assert_called_once_with(policy_runner.name, sample_policy)
mock_backend.run.assert_called_once_with(
policy_runner.name, sample_policy, cache=policy_runner.cache
)
# Should call ingest once for the single chunk
mock_diode_client.ingest.assert_called_once()
# Check that entities were passed correctly
call_args = mock_diode_client.ingest.call_args[1]['entities']
assert len(call_args) == 3


def test_run_forwards_backend_kwargs(policy_runner, sample_policy, mock_diode_client, mock_backend):
"""Ensure PolicyRunner forwards keyword arguments to the backend run method."""
policy_runner.name = "test_policy"
mock_backend.run.return_value = []
mock_diode_client.ingest.return_value.errors = []
mock_cache = {}

policy_runner.run(
mock_diode_client,
mock_backend,
sample_policy,
schedule="0 * * * *",
custom="value",
cache=mock_cache,
)

mock_backend.run.assert_called_once_with(
policy_runner.name,
sample_policy,
schedule="0 * * * *",
custom="value",
cache=mock_cache,
)


def test_run_ingestion_errors(
policy_runner,
sample_policy,
Expand All @@ -194,6 +312,7 @@ def test_run_ingestion_errors(
):
"""Test the run function when ingestion has errors."""
policy_runner.name = "test_policy"
policy_runner.cache = {}

# Create mock entities
entities = []
Expand All @@ -212,7 +331,9 @@ def test_run_ingestion_errors(
policy_runner.run(mock_diode_client, mock_backend, sample_policy)

# Assertions
mock_backend.run.assert_called_once_with(policy_runner.name, sample_policy)
mock_backend.run.assert_called_once_with(
policy_runner.name, sample_policy, cache=policy_runner.cache
)
mock_diode_client.ingest.assert_called_once()
assert (
"Policy test_policy: Chunk 1 ingestion failed: ['error1', 'error2']"
Expand All @@ -229,6 +350,7 @@ def test_run_backend_exception(
):
"""Test the run function when an exception is raised by the backend."""
policy_runner.name = "test_policy"
policy_runner.cache = {}

# Simulate backend throwing an exception
mock_backend.run.side_effect = Exception("Backend error")
Expand All @@ -238,19 +360,23 @@ def test_run_backend_exception(
policy_runner.run(mock_diode_client, mock_backend, sample_policy)

# Assertions
mock_backend.run.assert_called_once_with(policy_runner.name, sample_policy)
mock_backend.run.assert_called_once_with(
policy_runner.name, sample_policy, cache=policy_runner.cache
)
mock_diode_client.ingest.assert_not_called() # Client ingestion should not be called
assert "Policy test_policy: Backend error" in caplog.text


def test_stop_policy_runner(policy_runner):
"""Test stopping the PolicyRunner."""
policy_runner.cache = {}
with patch.object(policy_runner.scheduler, "shutdown") as mock_shutdown:
policy_runner.stop()

# Ensure scheduler shutdown is called and status is updated
mock_shutdown.assert_called_once()
assert policy_runner.status == Status.FINISHED
assert policy_runner.cache is None


def test_metrics_during_policy_lifecycle(
Expand Down Expand Up @@ -279,6 +405,8 @@ def test_metrics_during_policy_lifecycle(
app_name="test_app",
app_version="1.0",
)
policy_runner.cache = {}
policy_runner.cache = {}

# Create mock entities
entities = []
Expand All @@ -299,7 +427,9 @@ def mock_get_metric(name):

policy_runner.run(mock_diode_client, mock_backend, sample_policy)

mock_backend.run.assert_called_once_with(policy_runner.name, sample_policy)
mock_backend.run.assert_called_once_with(
policy_runner.name, sample_policy, cache=policy_runner.cache
)
mock_diode_client.ingest.assert_called_once()

mock_policy_executions.add.assert_called_once_with(1, {"policy": "test_policy"})
Expand Down
13 changes: 13 additions & 0 deletions worker/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ def test_backend_run_not_implemented():
list(backend.run("mock", mock_policy))


def test_backend_run_accepts_kwargs():
"""Test that Backend subclasses can accept keyword arguments in run."""

class DummyBackend(Backend):
def run(self, policy_name, policy, **kwargs):
return kwargs

backend = DummyBackend()
mock_policy = MagicMock(spec=Policy)
result = backend.run("mock", mock_policy, foo="bar", answer=42)
assert result == {"foo": "bar", "answer": 42}


def test_load_class_valid_backend_class(mock_import_module):
"""Test that load_class successfully loads a valid Backend class."""
mock_module_name = "worker.test_module"
Expand Down
3 changes: 2 additions & 1 deletion worker/worker/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ def setup(self) -> Metadata:
"""
raise NotImplementedError("The 'setup' method must be implemented.")

def run(self, policy_name: str, policy: Policy) -> Iterable[Entity]:
def run(self, policy_name: str, policy: Policy, **kwargs) -> Iterable[Entity]:
"""
Run the backend.

Args:
----
policy_name (str): The name of the policy.
policy (Policy): The policy to run.
**kwargs: Additional parameters provided by the runner.

Returns:
-------
Expand Down
Loading
Loading