Browse Source

Backends: store_result/mark_as* now takes request kwargs (needed for RPC backend)

Backends cannot rely on current_task as the state may be updated
in the parent process after task.request is gone.
Ask Solem 11 years ago
parent
commit
9c38139d5b

+ 13 - 5
celery/app/trace.py

@@ -93,7 +93,9 @@ class TraceInfo(object):
             reason = self.retval
             einfo = ExceptionInfo((type_, reason, tb))
             if store_errors:
-                task.backend.mark_as_retry(req.id, reason.exc, einfo.traceback)
+                task.backend.mark_as_retry(
+                    req.id, reason.exc, einfo.traceback, request=req,
+                )
             task.on_retry(reason.exc, req.id, req.args, req.kwargs, einfo)
             signals.task_retry.send(sender=task, request=req,
                                     reason=reason, einfo=einfo)
@@ -111,7 +113,9 @@ class TraceInfo(object):
             einfo.exception = get_pickleable_exception(einfo.exception)
             einfo.type = get_pickleable_etype(einfo.type)
             if store_errors:
-                task.backend.mark_as_failure(req.id, exc, einfo.traceback)
+                task.backend.mark_as_failure(
+                    req.id, exc, einfo.traceback, request=req,
+                )
             task.on_failure(exc, req.id, req.args, req.kwargs, einfo)
             signals.task_failure.send(sender=task, task_id=req.id,
                                       exception=exc, args=req.args,
@@ -204,8 +208,10 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                                 args=args, kwargs=kwargs)
                 loader_task_init(uuid, task)
                 if track_started:
-                    store_result(uuid, {'pid': pid,
-                                        'hostname': hostname}, STARTED)
+                    store_result(
+                        uuid, {'pid': pid, 'hostname': hostname}, STARTED,
+                        request=task_request,
+                    )
 
                 # -*- TRACE -*-
                 try:
@@ -237,7 +243,9 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     [subtask(callback).apply_async((retval, ))
                         for callback in task_request.callbacks or []]
                     if publish_result:
-                        store_result(uuid, retval, SUCCESS)
+                        store_result(
+                            uuid, retval, SUCCESS, request=task_request,
+                        )
                     if task_on_success:
                         task_on_success(retval, uuid, args, kwargs)
                     if success_receivers:

+ 14 - 11
celery/backends/amqp.py

@@ -106,21 +106,24 @@ class AMQPBackend(BaseBackend):
     def revive(self, channel):
         pass
 
-    def _routing_key(self, task_id):
+    def _routing_key(self, task_id, request):
         return task_id.replace('-', '')
 
-    def _store_result(self, task_id, result, status, traceback=None):
+    def _store_result(self, task_id, result, status,
+                      traceback=None, request=None, **kwargs):
         """Send task return value and status."""
         with self.app.amqp.producer_pool.acquire(block=True) as producer:
-            producer.publish({'task_id': task_id, 'status': status,
-                              'result': self.encode_result(result, status),
-                              'traceback': traceback,
-                              'children': self.current_task_children()},
-                             exchange=self.exchange,
-                             routing_key=self._routing_key(task_id),
-                             serializer=self.serializer,
-                             retry=True, retry_policy=self.retry_policy,
-                             declare=self.on_reply_declare(task_id))
+            producer.publish(
+                {'task_id': task_id, 'status': status,
+                 'result': self.encode_result(result, status),
+                 'traceback': traceback,
+                 'children': self.current_task_children(request)},
+                exchange=self.exchange,
+                routing_key=self._routing_key(task_id, request),
+                serializer=self.serializer,
+                retry=True, retry_policy=self.retry_policy,
+                declare=self.on_reply_declare(task_id),
+            )
         return result
 
     def on_reply_declare(self, task_id):

+ 21 - 16
celery/backends/base.py

@@ -89,14 +89,15 @@ class BaseBackend(object):
         """Mark a task as started"""
         return self.store_result(task_id, meta, status=states.STARTED)
 
-    def mark_as_done(self, task_id, result):
+    def mark_as_done(self, task_id, result, request=None):
         """Mark task as successfully executed."""
-        return self.store_result(task_id, result, status=states.SUCCESS)
+        return self.store_result(task_id, result,
+                                 status=states.SUCCESS, request=request)
 
-    def mark_as_failure(self, task_id, exc, traceback=None):
+    def mark_as_failure(self, task_id, exc, traceback=None, request=None):
         """Mark task as executed with failure. Stores the execption."""
         return self.store_result(task_id, exc, status=states.FAILURE,
-                                 traceback=traceback)
+                                 traceback=traceback, request=request)
 
     def fail_from_current_stack(self, task_id, exc=None):
         type_, real_exc, tb = sys.exc_info()
@@ -108,15 +109,16 @@ class BaseBackend(object):
         finally:
             del(tb)
 
-    def mark_as_retry(self, task_id, exc, traceback=None):
+    def mark_as_retry(self, task_id, exc, traceback=None, request=None):
         """Mark task as being retries. Stores the current
         exception (if any)."""
         return self.store_result(task_id, exc, status=states.RETRY,
-                                 traceback=traceback)
+                                 traceback=traceback, request=request)
 
-    def mark_as_revoked(self, task_id, reason=''):
+    def mark_as_revoked(self, task_id, reason='', request=None):
         return self.store_result(task_id, TaskRevokedError(reason),
-                                 status=states.REVOKED, traceback=None)
+                                 status=states.REVOKED, traceback=None,
+                                 request=request)
 
     def prepare_exception(self, exc):
         """Prepare exception for serialization."""
@@ -195,10 +197,12 @@ class BaseBackend(object):
     def is_cached(self, task_id):
         return task_id in self._cache
 
-    def store_result(self, task_id, result, status, traceback=None, **kwargs):
+    def store_result(self, task_id, result, status,
+                     traceback=None, request=None, **kwargs):
         """Update task state and result."""
         result = self.encode_result(result, status)
-        self._store_result(task_id, result, status, traceback, **kwargs)
+        self._store_result(task_id, result, status, traceback,
+                           request=request, **kwargs)
         return result
 
     def forget(self, task_id):
@@ -300,10 +304,10 @@ class BaseBackend(object):
         )
     on_chord_apply = fallback_chord_unlock
 
-    def current_task_children(self):
-        current = current_task()
-        if current:
-            return [r.serializable() for r in current.request.children]
+    def current_task_children(self, request=None):
+        request = request or getattr(current_task(), 'request', None)
+        if request:
+            return [r.serializable() for r in getattr(request, 'children', [])]
 
     def __reduce__(self, args=(), kwargs={}):
         return (unpickle_backend, (self.__class__, args, kwargs))
@@ -398,9 +402,10 @@ class KeyValueStoreBackend(BaseBackend):
     def _forget(self, task_id):
         self.delete(self.get_key_for_task(task_id))
 
-    def _store_result(self, task_id, result, status, traceback=None):
+    def _store_result(self, task_id, result, status,
+                      traceback=None, request=None, **kwargs):
         meta = {'status': status, 'result': result, 'traceback': traceback,
-                'children': self.current_task_children()}
+                'children': self.current_task_children(request)}
         self.set(self.get_key_for_task(task_id), self.encode(meta))
         return result
 

+ 5 - 2
celery/backends/cassandra.py

@@ -135,7 +135,8 @@ class CassandraBackend(BaseBackend):
         if self._column_family is not None:
             self._column_family = None
 
-    def _store_result(self, task_id, result, status, traceback=None):
+    def _store_result(self, task_id, result, status,
+                      traceback=None, request=None, **kwargs):
         """Store return value and status of an executed task."""
 
         def _do_store():
@@ -144,7 +145,9 @@ class CassandraBackend(BaseBackend):
             meta = {'status': status,
                     'date_done': date_done.strftime('%Y-%m-%dT%H:%M:%SZ'),
                     'traceback': self.encode(traceback),
-                    'children': self.encode(self.current_task_children())}
+                    'children': self.encode(
+                        self.current_task_children(request),
+                    )}
             if self.detailed_mode:
                 meta['result'] = result
                 cf.insert(task_id, {date_done: self.encode(meta)},

+ 1 - 1
celery/backends/database/__init__.py

@@ -92,7 +92,7 @@ class DatabaseBackend(BaseBackend):
 
     @retry
     def _store_result(self, task_id, result, status,
-                      traceback=None, max_retries=3):
+                      traceback=None, max_retries=3, **kwargs):
         """Store return value and status of an executed task."""
         session = self.ResultSession()
         try:

+ 5 - 2
celery/backends/mongodb.py

@@ -128,14 +128,17 @@ class MongoBackend(BaseBackend):
             del(self.database)
             self._connection = None
 
-    def _store_result(self, task_id, result, status, traceback=None):
+    def _store_result(self, task_id, result, status,
+                      traceback=None, request=None, **kwargs):
         """Store return value and status of an executed task."""
         meta = {'_id': task_id,
                 'status': status,
                 'result': Binary(self.encode(result)),
                 'date_done': datetime.utcnow(),
                 'traceback': Binary(self.encode(traceback)),
-                'children': Binary(self.encode(self.current_task_children()))}
+                'children': Binary(self.encode(
+                    self.current_task_children(request),
+                ))}
         self.collection.save(meta)
 
         return result

+ 3 - 1
celery/backends/rpc.py

@@ -37,7 +37,9 @@ class RPCBackend(amqp.AMQPBackend):
     def _many_bindings(self, ids):
         return [self.binding]
 
-    def _routing_key(self, task_id):
+    def _routing_key(self, task_id, request):
+        if request:
+            return request.reply_to
         task = current_task._get_current_object()
         if task is not None:
             return task.request.reply_to

+ 4 - 1
celery/tests/backends/test_rpc.py

@@ -23,11 +23,14 @@ class test_RPCBackend(AppCase):
         self.b.on_reply_declare('task_id')
 
     def test_current_routing_key(self):
+        req = Mock(name='request')
+        req.reply_to = 'reply_to'
+        self.assertEqual(self.b._routing_key('task_id', req), 'reply_to')
         task = Mock()
         _task_stack.push(task)
         try:
             task.request.reply_to = 'reply_to'
-            self.assertEqual(self.b._routing_key('task_id'), 'reply_to')
+            self.assertEqual(self.b._routing_key('task_id', None), 'reply_to')
         finally:
             _task_stack.pop()
 

+ 2 - 4
celery/tests/worker/test_request.py

@@ -144,9 +144,7 @@ class test_trace_task(AppCase):
         tid = uuid()
         ret = jail(self.app, tid, self.mytask.name, [2], {})
         self.assertEqual(ret, 4)
-        self.mytask.backend.store_result.assert_called_with(
-            tid, 4, states.SUCCESS,
-        )
+        self.assertTrue(self.mytask.backend.store_result.called)
         self.assertIn('Process cleanup failed', _logger.error.call_args[0][0])
 
     def test_process_cleanup_BaseException(self):
@@ -162,7 +160,7 @@ class test_trace_task(AppCase):
     def test_marked_as_started(self):
         _started = []
 
-        def store_result(tid, meta, state):
+        def store_result(tid, meta, state, **kwars):
             if state == states.STARTED:
                 _started.append(tid)
         self.mytask.backend.store_result = Mock(name='store_result')

+ 10 - 3
celery/worker/job.py

@@ -285,7 +285,7 @@ class Request(object):
         self.send_event('task-revoked',
                         terminated=terminated, signum=signum, expired=expired)
         if self.store_errors:
-            self.task.backend.mark_as_revoked(self.id, reason)
+            self.task.backend.mark_as_revoked(self.id, reason, request=self)
         self.acknowledge()
         self._already_revoked = True
         send_revoked(self.task, request=self,
@@ -336,7 +336,7 @@ class Request(object):
             exc = TimeLimitExceeded(timeout)
 
         if self.store_errors:
-            self.task.backend.mark_as_failure(self.id, exc)
+            self.task.backend.mark_as_failure(self.id, exc, request=self)
 
     def on_success(self, ret_value, now=None, nowfun=monotonic):
         """Handler called if the task was successfully processed."""
@@ -393,7 +393,9 @@ class Request(object):
             # time to write the result.
             if self.store_errors:
                 if isinstance(exc, WorkerLostError):
-                    self.task.backend.mark_as_failure(self.id, exc)
+                    self.task.backend.mark_as_failure(
+                        self.id, exc, request=self,
+                    )
                 elif isinstance(exc, Terminated):
                     self._announce_revoked('terminated', True, str(exc), False)
                     send_failed_event = False  # already sent revoked event
@@ -529,6 +531,11 @@ class Request(object):
     def task_name(self, value):
         self.name = value
 
+    @property
+    def reply_to(self):
+        # used by rpc backend when failures reported by parent process
+        return self.request_dict['reply_to']
+
 
 class TaskRequest(Request):