1+ import collections
12import functools
23import logging
34
1819log = logging .getLogger ("taskbadger" )
1920
2021
22+ class Cache :
23+ def __init__ (self , maxsize = 128 ):
24+ self .cache = collections .OrderedDict ()
25+ self .maxsize = maxsize
26+
27+ def set (self , key , value ):
28+ self .cache [key ] = value
29+
30+ def unset (self , key ):
31+ self .cache .pop (key , None )
32+
33+ def get (self , key ):
34+ return self .cache .get (key )
35+
36+ def prune (self ):
37+ if len (self .cache ) > self .maxsize :
38+ self .cache .popitem (last = False )
39+
40+
41+ def cached (cache_none = True , maxsize = 128 ):
42+ cache = Cache (maxsize = maxsize )
43+
44+ def _wrapper (func ):
45+ @functools .wraps (func )
46+ def _inner (* args , ** kwargs ):
47+ key = args + tuple (sorted (kwargs .items ()))
48+ if key in cache .cache :
49+ return cache .get (key )
50+
51+ result = func (* args , ** kwargs )
52+ if result is not None or cache_none :
53+ cache .set (key , result )
54+ return result
55+
56+ _inner .cache = cache
57+ return _inner
58+
59+ return _wrapper
60+
61+
2162class Task (celery .Task ):
2263 """A Celery Task that tracks itself with TaskBadger.
2364
@@ -89,18 +130,21 @@ def taskbadger_task(self):
89130 task = self .request .get ("taskbadger_task" )
90131 if not task :
91132 log .debug ("Fetching task '%s'" , self .taskbadger_task_id )
92- try :
93- task = get_task ( self . taskbadger_task_id )
133+ task = safe_get_task ( self . taskbadger_task_id )
134+ if task :
94135 self .request .update ({"taskbadger_task" : task })
95- except Exception :
96- log .exception ("Error fetching task '%s'" , self .taskbadger_task_id )
97- task = None
98136 return task
99137
100138
101139@before_task_publish .connect
102140def task_publish_handler (sender = None , headers = None , ** kwargs ):
103- if not headers .get ("taskbadger_track" ) or not Badger .is_configured ():
141+ if sender .startswith ("celery." ) or not headers or not Badger .is_configured ():
142+ return
143+
144+ celery_system = Badger .current .settings .get_system_by_id ("celery" )
145+ auto_track = celery_system and celery_system .auto_track_tasks
146+ manual_track = headers .get ("taskbadger_track" )
147+ if not manual_track and not auto_track :
104148 return
105149
106150 ctask = celery .current_app .tasks .get (sender )
@@ -112,7 +156,7 @@ def task_publish_handler(sender=None, headers=None, **kwargs):
112156 kwargs [attr .removeprefix (KWARG_PREFIX )] = getattr (ctask , attr )
113157
114158 # get kwargs from the task headers (set via apply_async)
115- kwargs .update (headers [ TB_KWARGS_ARG ] )
159+ kwargs .update (headers . get ( TB_KWARGS_ARG , {}) )
116160 kwargs ["status" ] = StatusEnum .PENDING
117161 name = kwargs .pop ("name" , headers ["task" ])
118162
@@ -147,11 +191,20 @@ def task_retry_handler(sender=None, einfo=None, **kwargs):
147191
148192
149193def _update_task (signal_sender , status , einfo = None ):
150- log .debug ("celery_task_update %s %s" , signal_sender , status )
151- if not hasattr (signal_sender , "taskbadger_task" ):
194+ headers = signal_sender .request .headers
195+ if not headers :
196+ return
197+
198+ task_id = headers .get ("taskbadger_task_id" )
199+ if not task_id :
152200 return
153201
154- task = signal_sender .taskbadger_task
202+ log .debug ("celery_task_update %s %s" , signal_sender , status )
203+ if hasattr (signal_sender , "taskbadger_task" ):
204+ task = signal_sender .taskbadger_task
205+ else :
206+ task = safe_get_task (task_id )
207+
155208 if task is None :
156209 return
157210
@@ -164,7 +217,9 @@ def _update_task(signal_sender, status, einfo=None):
164217 data = None
165218 if einfo :
166219 data = DefaultMergeStrategy ().merge (task .data , {"exception" : str (einfo )})
167- update_task_safe (task .id , status = status , data = data )
220+ task = update_task_safe (task .id , status = status , data = data )
221+ if task :
222+ safe_get_task .cache .set ((task_id ,), task )
168223
169224
170225def enter_session ():
@@ -176,8 +231,25 @@ def enter_session():
176231
177232
178233def exit_session (signal_sender ):
179- if not hasattr (signal_sender , "taskbadger_task" ) or not Badger .is_configured ():
234+ headers = signal_sender .request .headers
235+ if not headers :
180236 return
237+
238+ task_id = headers .get ("taskbadger_task_id" )
239+ if not task_id or not Badger .is_configured ():
240+ return
241+
242+ safe_get_task .cache .unset ((task_id ,))
243+ safe_get_task .cache .prune ()
244+
181245 session = Badger .current .session ()
182246 if session .client :
183247 session .__exit__ ()
248+
249+
250+ @cached (cache_none = False )
251+ def safe_get_task (task_id : str ):
252+ try :
253+ return get_task (task_id )
254+ except Exception :
255+ log .exception ("Error fetching task '%s'" , task_id )
0 commit comments