Skip to content

Commit 5f4c178

Browse files
authored
Merge pull request #23 from taskbadger/celery-auto-track
auto track celery tasks
2 parents d1093db + 8303ca8 commit 5f4c178

File tree

10 files changed

+213
-19
lines changed

10 files changed

+213
-19
lines changed

integration_tests/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
import taskbadger as badger
7+
from taskbadger.systems.celery import CelerySystemIntegration
78

89

910
def _load_config():
@@ -30,5 +31,6 @@ def _load_config():
3031
os.environ.get("TASKBADGER_ORG", ""),
3132
os.environ.get("TASKBADGER_PROJECT", ""),
3233
os.environ.get("TASKBADGER_API_KEY", ""),
34+
systems=[CelerySystemIntegration()],
3335
)
3436
print(f"\nIntegration tests configuration:\n {badger.mug.Badger.current.settings}\n")

integration_tests/tasks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,9 @@ def add(self, x, y):
88
assert self.taskbadger_task is not None, "missing task on self"
99
self.taskbadger_task.update(value=100, data={"result": x + y})
1010
return x + y
11+
12+
13+
@shared_task(bind=True)
def add_auto_track(self, x, y):
    """Return ``x + y`` after checking that the request carries a Task Badger task ID."""
    tracked_task_id = self.request.taskbadger_task_id
    assert tracked_task_id is not None, "missing task ID on self.request"
    return x + y

integration_tests/test_celery.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
1+
import logging
12
import random
23

34
import pytest
45

56
from taskbadger import StatusEnum
67

7-
from .tasks import add
8+
from .tasks import add, add_auto_track
9+
10+
11+
@pytest.fixture(autouse=True)
def check_log_errors(caplog):
    """Fail the test if an ERROR-level record was captured in any test phase."""
    yield
    for when in ("call", "setup", "teardown"):
        errors = []
        for record in caplog.get_records(when):
            # Only records logged at exactly ERROR level are reported.
            if record.levelno == logging.ERROR:
                errors.append(record.getMessage())
        if errors:
            pytest.fail(f"log errors during '{when}': {errors}")
818

919

1020
@pytest.fixture(scope="session", autouse=True)
@@ -24,3 +34,9 @@ def test_celery(celery_worker):
2434
assert tb_task.status == StatusEnum.SUCCESS
2535
assert tb_task.value == 100
2636
assert tb_task.data == {"result": a + b}
37+
38+
39+
def test_celery_auto_track(celery_worker):
    """Run ``add_auto_track`` through the worker and check the returned sum."""
    a = random.randint(1, 1000)
    b = random.randint(1, 1000)
    async_result = add_auto_track.delay(a, b)
    total = async_result.get(timeout=10, propagate=True)
    assert total == a + b

taskbadger/celery.py

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import functools
23
import logging
34

@@ -18,6 +19,46 @@
1819
log = logging.getLogger("taskbadger")
1920

2021

22+
class Cache:
    """Small insertion-ordered key/value store with explicit, size-bounded pruning.

    Entries live in an ``OrderedDict``. The size limit is not enforced on
    ``set``; it is applied only when ``prune`` is called, which evicts the
    oldest insertions first.
    """

    def __init__(self, maxsize=128):
        """
        Args:
            maxsize: Maximum number of entries left in the cache after ``prune``.
        """
        self.cache = collections.OrderedDict()
        self.maxsize = maxsize

    def set(self, key, value):
        """Store ``value`` under ``key``, replacing any existing value."""
        self.cache[key] = value

    def unset(self, key):
        """Remove ``key`` if it is present; missing keys are ignored."""
        self.cache.pop(key, None)

    def get(self, key):
        """Return the value stored under ``key``, or ``None`` if it is absent."""
        return self.cache.get(key)

    def prune(self):
        """Evict the oldest entries until at most ``maxsize`` remain."""
        # Loop instead of evicting once: several ``set`` calls can happen
        # between two prunes, leaving the cache more than one entry oversized.
        # The emptiness check keeps a negative ``maxsize`` from raising KeyError.
        while self.cache and len(self.cache) > self.maxsize:
            self.cache.popitem(last=False)
39+
40+
41+
def cached(cache_none=True, maxsize=128):
    """Build a decorator that memoizes a function's results in a ``Cache``.

    The cache is exposed on the decorated function as ``.cache`` so callers can
    set, unset and prune entries directly. The decorator itself never prunes.

    Cache keys are the tuple of positional arguments. When keyword arguments
    are passed, a private marker and the sorted keyword items are appended, so
    a positional-only call ``f(x)`` is always keyed as ``(x,)``.

    Args:
        cache_none: When false, ``None`` results are not stored, so the wrapped
            function is called again on the next lookup.
        maxsize: Passed to ``Cache``; the limit applied by ``Cache.prune``.

    Returns:
        A decorator that wraps a function with the caching behaviour.
    """
    cache = Cache(maxsize=maxsize)
    # Separates positional from keyword arguments inside the key so that
    # f(("a", 1)) and f(a=1) cannot map to the same entry.
    kwargs_marker = object()

    def _wrapper(func):
        @functools.wraps(func)
        def _inner(*args, **kwargs):
            key = args
            if kwargs:
                key += (kwargs_marker,) + tuple(sorted(kwargs.items()))
            # Membership test (not ``get``) so cached ``None`` values count as hits.
            if key in cache.cache:
                return cache.get(key)

            result = func(*args, **kwargs)
            if result is not None or cache_none:
                cache.set(key, result)
            return result

        _inner.cache = cache
        return _inner

    return _wrapper
60+
61+
2162
class Task(celery.Task):
2263
"""A Celery Task that tracks itself with TaskBadger.
2364
@@ -89,18 +130,21 @@ def taskbadger_task(self):
89130
task = self.request.get("taskbadger_task")
90131
if not task:
91132
log.debug("Fetching task '%s'", self.taskbadger_task_id)
92-
try:
93-
task = get_task(self.taskbadger_task_id)
133+
task = safe_get_task(self.taskbadger_task_id)
134+
if task:
94135
self.request.update({"taskbadger_task": task})
95-
except Exception:
96-
log.exception("Error fetching task '%s'", self.taskbadger_task_id)
97-
task = None
98136
return task
99137

100138

101139
@before_task_publish.connect
102140
def task_publish_handler(sender=None, headers=None, **kwargs):
103-
if not headers.get("taskbadger_track") or not Badger.is_configured():
141+
if sender.startswith("celery.") or not headers or not Badger.is_configured():
142+
return
143+
144+
celery_system = Badger.current.settings.get_system_by_id("celery")
145+
auto_track = celery_system and celery_system.auto_track_tasks
146+
manual_track = headers.get("taskbadger_track")
147+
if not manual_track and not auto_track:
104148
return
105149

106150
ctask = celery.current_app.tasks.get(sender)
@@ -112,7 +156,7 @@ def task_publish_handler(sender=None, headers=None, **kwargs):
112156
kwargs[attr.removeprefix(KWARG_PREFIX)] = getattr(ctask, attr)
113157

114158
# get kwargs from the task headers (set via apply_async)
115-
kwargs.update(headers[TB_KWARGS_ARG])
159+
kwargs.update(headers.get(TB_KWARGS_ARG, {}))
116160
kwargs["status"] = StatusEnum.PENDING
117161
name = kwargs.pop("name", headers["task"])
118162

@@ -147,11 +191,20 @@ def task_retry_handler(sender=None, einfo=None, **kwargs):
147191

148192

149193
def _update_task(signal_sender, status, einfo=None):
150-
log.debug("celery_task_update %s %s", signal_sender, status)
151-
if not hasattr(signal_sender, "taskbadger_task"):
194+
headers = signal_sender.request.headers
195+
if not headers:
196+
return
197+
198+
task_id = headers.get("taskbadger_task_id")
199+
if not task_id:
152200
return
153201

154-
task = signal_sender.taskbadger_task
202+
log.debug("celery_task_update %s %s", signal_sender, status)
203+
if hasattr(signal_sender, "taskbadger_task"):
204+
task = signal_sender.taskbadger_task
205+
else:
206+
task = safe_get_task(task_id)
207+
155208
if task is None:
156209
return
157210

@@ -164,7 +217,9 @@ def _update_task(signal_sender, status, einfo=None):
164217
data = None
165218
if einfo:
166219
data = DefaultMergeStrategy().merge(task.data, {"exception": str(einfo)})
167-
update_task_safe(task.id, status=status, data=data)
220+
task = update_task_safe(task.id, status=status, data=data)
221+
if task:
222+
safe_get_task.cache.set((task_id,), task)
168223

169224

170225
def enter_session():
@@ -176,8 +231,25 @@ def enter_session():
176231

177232

178233
def exit_session(signal_sender):
179-
if not hasattr(signal_sender, "taskbadger_task") or not Badger.is_configured():
234+
headers = signal_sender.request.headers
235+
if not headers:
180236
return
237+
238+
task_id = headers.get("taskbadger_task_id")
239+
if not task_id or not Badger.is_configured():
240+
return
241+
242+
safe_get_task.cache.unset((task_id,))
243+
safe_get_task.cache.prune()
244+
181245
session = Badger.current.session()
182246
if session.client:
183247
session.__exit__()
248+
249+
250+
@cached(cache_none=False)
def safe_get_task(task_id: str):
    """Fetch the task with ID ``task_id``, returning ``None`` if the lookup raises.

    Any exception is logged rather than propagated. Because the decorator is
    applied with ``cache_none=False``, failed lookups are not cached.
    """
    try:
        task = get_task(task_id)
    except Exception:
        log.exception("Error fetching task '%s'", task_id)
        return None
    return task

taskbadger/mug.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import dataclasses
22
from contextlib import ContextDecorator
33
from contextvars import ContextVar
4-
from typing import Union
4+
from typing import Dict, Union
55

66
from taskbadger.internal import AuthenticatedClient
7+
from taskbadger.systems import System
78

89
_local = ContextVar("taskbadger_client")
910

@@ -14,6 +15,7 @@ class Settings:
1415
token: str
1516
organization_slug: str
1617
project_slug: str
18+
systems: Dict[str, System] = dataclasses.field(default_factory=dict)
1719

1820
def get_client(self):
1921
return AuthenticatedClient(self.base_url, self.token)
@@ -24,6 +26,9 @@ def as_kwargs(self):
2426
"project_slug": self.project_slug,
2527
}
2628

29+
def get_system_by_id(self, identifier: str) -> Union[System, None]:
    """Return the system registered under ``identifier``.

    Returns:
        The matching system, or ``None`` when no system with that identifier
        is present in ``self.systems``.
    """
    return self.systems.get(identifier)
31+
2732
def __str__(self):
2833
return (
2934
f"Settings(base_url='{self.base_url}',"

taskbadger/sdk.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,36 @@
1313
)
1414
from taskbadger.internal.types import UNSET
1515
from taskbadger.mug import Badger, Session, Settings
16+
from taskbadger.systems import System
1617

1718
_TB_HOST = "https://taskbadger.net"
1819

1920

20-
def init(organization_slug: str = None, project_slug: str = None, token: str = None):
21+
def init(organization_slug: str = None, project_slug: str = None, token: str = None, systems: List[System] = None):
2122
"""Initialize Task Badger client
2223
2324
Call this function once per thread
2425
"""
25-
_init(_TB_HOST, organization_slug, project_slug, token)
26+
_init(_TB_HOST, organization_slug, project_slug, token, systems)
2627

2728

28-
def _init(host: str = None, organization_slug: str = None, project_slug: str = None, token: str = None):
29+
def _init(
30+
host: str = None,
31+
organization_slug: str = None,
32+
project_slug: str = None,
33+
token: str = None,
34+
systems: List[System] = None,
35+
):
2936
host = host or os.environ.get("TASKBADGER_HOST", "https://taskbadger.net")
3037
organization_slug = organization_slug or os.environ.get("TASKBADGER_ORG")
3138
project_slug = project_slug or os.environ.get("TASKBADGER_PROJECT")
3239
token = token or os.environ.get("TASKBADGER_API_KEY")
3340

3441
if host and organization_slug and project_slug and token:
35-
settings = Settings(host, token, organization_slug, project_slug)
42+
systems = systems or []
43+
settings = Settings(
44+
host, token, organization_slug, project_slug, systems={system.identifier: system for system in systems}
45+
)
3646
Badger.current.bind(settings)
3747
else:
3848
raise ConfigurationError(

taskbadger/systems/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
class System:
    """Common parent class for every system integration."""

    # Identifier of the system; ``None`` on the base class.
    identifier: str = None

taskbadger/systems/celery.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from taskbadger.systems import System
2+
3+
4+
class CelerySystemIntegration(System):
    """Configuration for the Celery system integration."""

    identifier = "celery"

    def __init__(self, auto_track_tasks=True):
        """Store the integration options.

        Args:
            auto_track_tasks: Automatically track all Celery tasks, including
                those that do not use the `taskbadger.celery.Task` base class.
        """
        self.auto_track_tasks = auto_track_tasks

tests/test_celery_error.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ def add_error(self, a, b):
2020
with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch(
2121
"taskbadger.celery.update_task_safe"
2222
) as update, mock.patch("taskbadger.celery.get_task") as get_task:
23-
get_task.return_value = task_for_test()
23+
task = task_for_test()
24+
get_task.return_value = task
25+
update.return_value = task
2426
result = add_error.delay(2, 2)
2527
with pytest.raises(Exception):
2628
result.get(timeout=10, propagate=True)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""
2+
Note
3+
====
4+
5+
As part of the Celery fixture setup a 'ping' task is run which executes
6+
before the `bind_settings` fixture is executed. This means that if any code
7+
calls `Badger.is_configured()` (or similar), the `_local` ContextVar in the
8+
Celery runner thread will not have the configuration set.
9+
"""
10+
import logging
11+
from unittest import mock
12+
13+
import pytest
14+
15+
from taskbadger.mug import Badger, Settings
16+
from taskbadger.systems.celery import CelerySystemIntegration
17+
from tests.utils import task_for_test
18+
19+
20+
@pytest.fixture
def bind_settings_with_system():
    """Bind settings that include the Celery system integration, then unbind them."""
    celery_integration = CelerySystemIntegration()
    settings = Settings(
        "https://taskbadger.net",
        "token",
        "org",
        "proj",
        systems={celery_integration.identifier: celery_integration},
    )
    Badger.current.bind(settings)
    yield
    Badger.current.bind(None)
30+
31+
32+
@pytest.fixture(autouse=True)
def check_log_errors(caplog):
    """Fail the test if an ERROR-level record was captured during the call phase."""
    yield
    errors = []
    for record in caplog.get_records("call"):
        if record.levelno == logging.ERROR:
            errors.append(record.getMessage())
    if errors:
        pytest.fail(f"log errors during tests: {errors}")
38+
39+
40+
def test_celery_auto_track_task(celery_session_app, celery_session_worker, bind_settings_with_system):
    """A task declared without the Task Badger base class still gets a task ID in its request."""

    @celery_session_app.task(bind=True)
    def add_normal(self, a, b):
        request_task_id = self.request.get("taskbadger_task_id")
        assert request_task_id is not None, "missing task in request"
        assert not hasattr(self, "taskbadger_task")
        assert Badger.current.session().client is not None, "missing client"
        return a + b

    celery_session_worker.reload()

    # Patching only starts when the ``with`` statement enters each patcher.
    create_patch = mock.patch("taskbadger.celery.create_task_safe")
    update_patch = mock.patch("taskbadger.celery.update_task_safe")
    get_task_patch = mock.patch("taskbadger.celery.get_task")
    with create_patch as create, update_patch as update, get_task_patch as get_task:
        tb_task = task_for_test()
        create.return_value = tb_task
        result = add_normal.delay(2, 2)
        assert result.info.get("taskbadger_task_id") == tb_task.id
        assert result.get(timeout=10, propagate=True) == 4

    # The mocks keep their call records after the patches are undone.
    create.assert_called_once()
    assert get_task.call_count == 1
    assert update.call_count == 2
    assert Badger.current.session().client is None

0 commit comments

Comments
 (0)