1313KWARG_PREFIX = "taskbadger_"
1414TB_KWARGS_ARG = f"{ KWARG_PREFIX } kwargs"
1515IGNORE_ARGS = {TB_KWARGS_ARG , f"{ KWARG_PREFIX } task" , f"{ KWARG_PREFIX } task_id" }
16+ TB_TASK_ID = f"{ KWARG_PREFIX } task_id"
1617
1718TERMINAL_STATES = {StatusEnum .SUCCESS , StatusEnum .ERROR , StatusEnum .CANCELLED , StatusEnum .STALE }
1819
@@ -110,8 +111,8 @@ def apply_async(self, *args, **kwargs):
110111 headers [TB_KWARGS_ARG ] = tb_kwargs
111112 result = super ().apply_async (* args , ** kwargs )
112113
113- tb_task_id = result .info .get ("taskbadger_task_id" ) if result .info else None
114- setattr (result , "taskbadger_task_id" , tb_task_id )
114+ tb_task_id = result .info .get (TB_TASK_ID ) if result .info else None
115+ setattr (result , TB_TASK_ID , tb_task_id )
115116
116117 _get_task = functools .partial (get_task , tb_task_id ) if tb_task_id else lambda : None
117118 setattr (result , "get_taskbadger_task" , _get_task )
@@ -120,7 +121,7 @@ def apply_async(self, *args, **kwargs):
120121
121122 @property
122123 def taskbadger_task_id (self ):
123- return self .request and self . request . headers and self . request . headers . get ( "taskbadger_task_id" )
124+ return _get_taskbadger_task_id ( self .request )
124125
125126 @property
126127 def taskbadger_task (self ):
@@ -137,8 +138,9 @@ def taskbadger_task(self):
137138
138139
139140@before_task_publish .connect
140- def task_publish_handler (sender = None , headers = None , ** kwargs ):
141- if sender .startswith ("celery." ) or not headers or not Badger .is_configured ():
141+ def task_publish_handler (sender = None , headers = None , body = None , ** kwargs ):
142+ headers = headers if "task" in headers else body
143+ if sender .startswith ("celery." ) or not Badger .is_configured ():
142144 return
143145
144146 celery_system = Badger .current .settings .get_system_by_id ("celery" )
@@ -162,7 +164,7 @@ def task_publish_handler(sender=None, headers=None, **kwargs):
162164
163165 task = create_task_safe (name , ** kwargs )
164166 if task :
165- meta = {"taskbadger_task_id" : task .id }
167+ meta = {TB_TASK_ID : task .id }
166168 headers .update (meta )
167169 ctask .update_state (task_id = headers ["id" ], state = "PENDING" , meta = meta )
168170
@@ -191,11 +193,7 @@ def task_retry_handler(sender=None, einfo=None, **kwargs):
191193
192194
193195def _update_task (signal_sender , status , einfo = None ):
194- headers = signal_sender .request .headers
195- if not headers :
196- return
197-
198- task_id = headers .get ("taskbadger_task_id" )
196+ task_id = _get_taskbadger_task_id (signal_sender .request )
199197 if not task_id :
200198 return
201199
@@ -235,7 +233,7 @@ def exit_session(signal_sender):
235233 if not headers :
236234 return
237235
238- task_id = headers .get ("taskbadger_task_id" )
236+ task_id = headers .get (TB_TASK_ID )
239237 if not task_id or not Badger .is_configured ():
240238 return
241239
@@ -253,3 +251,15 @@ def safe_get_task(task_id: str):
253251 return get_task (task_id )
254252 except Exception :
255253 log .exception ("Error fetching task '%s'" , task_id )
254+
255+
256+ def _get_taskbadger_task_id (request ):
257+ if not request :
258+ return
259+
260+ task_id = request .get (TB_TASK_ID )
261+ if task_id :
262+ return task_id
263+
264+ if request .headers :
265+ return request .headers .get (TB_TASK_ID )
0 commit comments