diff --git a/taskbadger/exceptions.py b/taskbadger/exceptions.py index 946d6c1..b7b9457 100644 --- a/taskbadger/exceptions.py +++ b/taskbadger/exceptions.py @@ -1,4 +1,8 @@ class ConfigurationError(Exception): + pass + + +class MissingConfiguration(ConfigurationError): def __init__(self, **kwargs): self.missing = [name for name, arg in kwargs.items() if arg is None] diff --git a/taskbadger/mug.py b/taskbadger/mug.py index 1dacf76..7193099 100644 --- a/taskbadger/mug.py +++ b/taskbadger/mug.py @@ -2,7 +2,7 @@ from contextlib import ContextDecorator from contextvars import ContextVar from copy import deepcopy -from typing import Union +from typing import Callable, Optional, Union from taskbadger.internal import AuthenticatedClient from taskbadger.systems import System @@ -10,6 +10,9 @@ _local = ContextVar("taskbadger_client") +Callback = Union[str, Callable[[dict], Optional[dict]]] + + @dataclasses.dataclass class Settings: base_url: str @@ -17,6 +20,7 @@ class Settings: organization_slug: str project_slug: str systems: dict[str, System] = dataclasses.field(default_factory=dict) + before_create: Callback = None def get_client(self): return AuthenticatedClient(self.base_url, self.token) @@ -140,6 +144,11 @@ def client(self) -> AuthenticatedClient: def scope(self) -> Scope: return self._scope + def call_before_create(self, task: dict) -> Optional[dict]: + if self.settings and self.settings.before_create: + return self.settings.before_create(task) + return task + @classmethod def is_configured(cls): return cls.current.settings is not None diff --git a/taskbadger/sdk.py b/taskbadger/sdk.py index 885c17c..a610672 100644 --- a/taskbadger/sdk.py +++ b/taskbadger/sdk.py @@ -4,6 +4,7 @@ from taskbadger.exceptions import ( ConfigurationError, + MissingConfiguration, ServerError, TaskbadgerException, Unauthorized, @@ -21,11 +22,11 @@ PatchedTaskRequestTags, StatusEnum, TaskRequest, - TaskRequestTags, ) from taskbadger.internal.types import UNSET -from taskbadger.mug import Badger, Session, Settings +from taskbadger.mug import Badger, Callback, Session, Settings from taskbadger.systems import System +from taskbadger.utils import import_string log = logging.getLogger("taskbadger") @@ -38,12 +39,13 @@ def init( token: str = None, systems: list[System] = None, tags: dict[str, str] = None, + before_create: Callback = None, ): """Initialize Task Badger client Call this function once per thread """ - _init(_TB_HOST, organization_slug, project_slug, token, systems, tags) + _init(_TB_HOST, organization_slug, project_slug, token, systems, tags, before_create) def _init( @@ -53,12 +55,19 @@ def _init( token: str = None, systems: list[System] = None, tags: dict[str, str] = None, + before_create: Callback = None, ): host = host or os.environ.get("TASKBADGER_HOST", "https://taskbadger.net") organization_slug = organization_slug or os.environ.get("TASKBADGER_ORG") project_slug = project_slug or os.environ.get("TASKBADGER_PROJECT") token = token or os.environ.get("TASKBADGER_API_KEY") + if before_create and isinstance(before_create, str): + try: + before_create = import_string(before_create) + except ImportError as e: + raise ConfigurationError(f"Could not import module: {before_create}") from e + if host and organization_slug and project_slug and token: systems = systems or [] settings = Settings( @@ -67,10 +76,11 @@ def _init( organization_slug, project_slug, systems={system.identifier: system for system in systems}, + before_create=before_create, ) Badger.current.bind(settings, tags) else: - raise ConfigurationError( + raise MissingConfiguration( host=host, organization_slug=organization_slug, project_slug=project_slug, @@ -118,29 +128,33 @@ def create_task( Returns: Task: The created Task object. """ - value = _none_to_unset(value) - value_max = _none_to_unset(value_max) - data = _none_to_unset(data) - max_runtime = _none_to_unset(max_runtime) - stale_timeout = _none_to_unset(stale_timeout) - - task = TaskRequest( - name=name, - status=status, - value=value, - value_max=value_max, - max_runtime=max_runtime, - stale_timeout=stale_timeout, - ) + task_dict = { + "name": name, + "status": status, + } + if value is not None: + task_dict["value"] = value + if value_max is not None: + task_dict["value_max"] = value_max + if max_runtime is not None: + task_dict["max_runtime"] = max_runtime + if stale_timeout is not None: + task_dict["stale_timeout"] = stale_timeout scope = Badger.current.scope() if scope.context or data: data = data or {} - task.data = {**scope.context, **data} + task_dict["data"] = {**scope.context, **data} if actions: - task.additional_properties = {"actions": [a.to_dict() for a in actions]} + task_dict["actions"] = [a.to_dict() for a in actions] if scope.tags or tags: tags = tags or {} - task.tags = TaskRequestTags.from_dict({**scope.tags, **tags}) + task_dict["tags"] = {**scope.tags, **tags} + + task_dict = Badger.current.call_before_create(task_dict) + if not task_dict: + raise TaskbadgerException("before_create callback returned None") + + task = TaskRequest.from_dict(task_dict) kwargs = _make_args(body=task) if monitor_id: kwargs["x_taskbadger_monitor"] = monitor_id diff --git a/taskbadger/utils.py b/taskbadger/utils.py new file mode 100644 index 0000000..9dfd89c --- /dev/null +++ b/taskbadger/utils.py @@ -0,0 +1,15 @@ +from importlib import import_module + + +def import_string(dotted_path): + try: + module_path, class_name = dotted_path.rsplit(".", 1) + except ValueError as err: + raise ImportError("%s doesn't look like a module path" % dotted_path) from err + + module = import_module(module_path) + + try: + return getattr(module, class_name) + except AttributeError as err: + raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') from err diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 0000000..07acd31 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,30 @@ +import pytest + +from taskbadger import Badger, init +from taskbadger.exceptions import ConfigurationError +from taskbadger.mug import _local + + +@pytest.fixture(autouse=True) +def _reset(): + b_global = Badger.current + _local.set(Badger()) + yield + _local.set(b_global) + + +def test_init(): + init("org", "project", "token", before_create=lambda x: x) + + +def test_init_import_before_create(): + init("org", "project", "token", before_create="tests.test_init._before_create") + + +def test_init_import_before_create_fail(): + with pytest.raises(ConfigurationError): + init("org", "project", "token", before_create="missing") + + +def _before_create(_): + pass diff --git a/tests/test_sdk.py b/tests/test_sdk.py index bd0023b..b96bbfe 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -3,7 +3,7 @@ import pytest -from taskbadger import Action, EmailIntegration, StatusEnum, WebhookIntegration +from taskbadger import Action, EmailIntegration, StatusEnum, WebhookIntegration, create_task from taskbadger.exceptions import TaskbadgerException from taskbadger.internal.models import ( PatchedTaskRequest, @@ -95,6 +95,42 @@ def test_create(settings, patched_create): ) +def test_before_create_update_task(settings, patched_create): + def before_create(task): + tags = task.setdefault("tags", {}) + tags["new"] = "tag" + return task + + settings.before_create = before_create + + api_task = task_for_test() + patched_create.return_value = Response(HTTPStatus.OK, b"", {}, api_task) + + task = create_task(name="task name") + assert task.id == api_task.id + + request = TaskRequest.from_dict( + { + "name": "task name", + "status": StatusEnum.PENDING, + "tags": {"new": "tag"}, + } + ) + assert patched_create.call_args[1]["body"] == request + + +def test_before_create_filter(settings, patched_create): + def before_create(_): + return None + + settings.before_create = before_create + + with pytest.raises(TaskbadgerException): + create_task(name="task name") + + patched_create.assert_not_called() + + def test_update_status(settings, patched_update): api_task = task_for_test() task = Task(api_task)