瀏覽代碼

Retry sending task-sent event if connection failed.

Ask Solem 12 年之前
父節點
當前提交
bf340a03e1
共有 3 個文件被更改,包括 57 次插入29 次删除
  1. 28 12
      celery/app/amqp.py
  2. 0 6
      celery/app/task.py
  3. 29 11
      celery/events/__init__.py

+ 28 - 12
celery/app/amqp.py

@@ -167,11 +167,15 @@ class TaskProducer(Producer):
     retry = False
     retry_policy = None
     utc = True
+    event_dispatcher = None
+    send_sent_event = False
 
     def __init__(self, channel=None, exchange=None, *args, **kwargs):
         self.retry = kwargs.pop('retry', self.retry)
         self.retry_policy = kwargs.pop('retry_policy',
                                        self.retry_policy or {})
+        self.send_sent_event = kwargs.pop('send_sent_event',
+                                          self.send_sent_event)
         exchange = exchange or self.exchange
         self.queues = self.app.amqp.queues  # shortcut
         self.default_queue = self.app.amqp.default_queue
@@ -246,25 +250,36 @@ class TaskProducer(Producer):
         )
 
         signals.task_sent.send(sender=task_name, **body)
-        if event_dispatcher:
+        if self.send_sent_event:
+            evd = event_dispatcher or self.event_dispatcher
             exname = exchange or self.exchange
             if isinstance(exname, Exchange):
                 exname = exname.name
-            event_dispatcher.send(
-                'task-sent', uuid=task_id,
-                name=task_name,
-                args=safe_repr(task_args),
-                kwargs=safe_repr(task_kwargs),
-                retries=retries,
-                eta=eta,
-                expires=expires,
-                queue=qname,
-                exchange=exname,
-                routing_key=routing_key,
+            evd.publish(
+                'task-sent',
+                {
+                    'uuid': task_id,
+                    'name': task_name,
+                    'args': safe_repr(task_args),
+                    'kwargs': safe_repr(task_kwargs),
+                    'retries': retries,
+                    'eta': eta,
+                    'expires': expires,
+                    'queue': qname,
+                    'exchange': exname,
+                    'routing_key': routing_key,
+                },
+                self, retry=retry, retry_policy=retry_policy,
             )
         return task_id
     delay_task = publish_task   # XXX Compat
 
+    @cached_property
+    def event_dispatcher(self):
+        # We call Dispatcher.publish with a custom producer
+        # so don't need the dispatcher to be "enabled".
+        return self.app.events.Dispatcher(enabled=False)
+
 
 class TaskPublisher(TaskProducer):
     """Deprecated version of :class:`TaskProducer`."""
@@ -358,6 +373,7 @@ class AMQP(object):
             compression=conf.CELERY_MESSAGE_COMPRESSION,
             retry=conf.CELERY_TASK_PUBLISH_RETRY,
             retry_policy=conf.CELERY_TASK_PUBLISH_RETRY_POLICY,
+            send_sent_event=conf.CELERY_SEND_TASK_SENT_EVENT,
             utc=conf.CELERY_ENABLE_UTC,
         )
     TaskPublisher = TaskProducer  # compat

+ 0 - 6
celery/app/task.py

@@ -465,14 +465,8 @@ class Task(object):
         if connection:
             producer = app.amqp.TaskProducer(connection)
         with app.producer_or_acquire(producer) as P:
-            evd = None
-            if conf.CELERY_SEND_TASK_SENT_EVENT:
-                evd = app.events.Dispatcher(channel=P.channel,
-                                            buffer_while_offline=False)
-
             task_id = P.publish_task(self.name, args, kwargs,
                                      task_id=task_id,
-                                     event_dispatcher=evd,
                                      callbacks=maybe_list(link),
                                      errbacks=maybe_list(link_error),
                                      **options)

+ 29 - 11
celery/events/__init__.py

@@ -118,6 +118,21 @@ class EventDispatcher(object):
             for callback in self.on_disabled:
                 callback()
 
+    def publish(self, type, fields, producer, retry=False, retry_policy=None):
+        with self.mutex:
+            event = Event(type, hostname=self.hostname,
+                          clock=self.app.clock.forward(), **fields)
+            exchange = get_exchange(producer.connection)
+            producer.publish(
+                event,
+                routing_key=type.replace('-', '.'),
+                exchange=exchange.name,
+                retry=retry,
+                retry_policy=retry_policy,
+                declare=[exchange],
+                serializer=self.serializer,
+            )
+
     def send(self, type, **fields):
         """Send event.
 
@@ -126,16 +141,12 @@ class EventDispatcher(object):
 
         """
         if self.enabled:
-            with self.mutex:
-                event = Event(type, hostname=self.hostname,
-                              clock=self.app.clock.forward(), **fields)
-                try:
-                    self.publisher.publish(event,
-                                           routing_key=type.replace('-', '.'))
-                except Exception, exc:
-                    if not self.buffer_while_offline:
-                        raise
-                    self._outbound_buffer.append((type, fields, exc))
+            try:
+                self._send(type, fields, self.producer)
+            except Exception, exc:
+                if not self.buffer_while_offline:
+                    raise
+                self._outbound_buffer.append((type, fields, exc))
 
     def flush(self):
         while self._outbound_buffer:
@@ -203,9 +214,16 @@ class EventReceiver(object):
             yield consumer
 
     def itercapture(self, limit=None, timeout=None, wakeup=True):
-        with self.consumer(wakeup=wakeup) as consumer:
+        consumer = self.consumer(wakeup=wakeup)
+        consumer.consume()
+        try:
             yield consumer
             self.drain_events(limit=limit, timeout=timeout)
+        finally:
+            try:
+                consumer.cancel()
+            except self.connection.connection_errors:
+                pass
 
     def capture(self, limit=None, timeout=None, wakeup=True):
         """Open up a consumer capturing events.