瀏覽代碼

amqrpc backend: one queue per client implementation (reply-to style)

Ask Solem 13 年之前
父節點
當前提交
9244b6265f
共有 6 個文件被更改,包括 80 次插入10 次删除
  1. 4 2
      celery/app/amqp.py
  2. 3 1
      celery/app/task.py
  3. 1 0
      celery/backends/__init__.py
  4. 16 7
      celery/backends/amqp.py
  5. 53 0
      celery/backends/amqrpc.py
  6. 3 0
      celery/backends/base.py

+ 4 - 2
celery/app/amqp.py

@@ -156,8 +156,9 @@ class TaskProducer(Producer):
             queue=None, now=None, retries=0, chord=None, callbacks=None,
             queue=None, now=None, retries=0, chord=None, callbacks=None,
             errbacks=None, mandatory=None, priority=None, immediate=None,
             errbacks=None, mandatory=None, priority=None, immediate=None,
             routing_key=None, serializer=None, delivery_mode=None,
             routing_key=None, serializer=None, delivery_mode=None,
-            compression=None, **kwargs):
+            compression=None, reply_to=None, **kwargs):
         """Send task message."""
         """Send task message."""
+        retry = self.retry if retry is None else retry
         # merge default and custom policy
         # merge default and custom policy
         _rp = (dict(self.retry_policy, **retry_policy) if retry_policy
         _rp = (dict(self.retry_policy, **retry_policy) if retry_policy
                                                        else self.retry_policy)
                                                        else self.retry_policy)
@@ -186,7 +187,8 @@ class TaskProducer(Producer):
                 'expires': expires,
                 'expires': expires,
                 'utc': self.utc,
                 'utc': self.utc,
                 'callbacks': callbacks,
                 'callbacks': callbacks,
-                'errbacks': errbacks}
+                'errbacks': errbacks,
+                'reply_to': reply_to}
         group_id = group_id or taskset_id
         group_id = group_id or taskset_id
         if group_id:
         if group_id:
             body['taskset'] = group_id
             body['taskset'] = group_id

+ 3 - 1
celery/app/task.py

@@ -466,6 +466,7 @@ class Task(object):
             be replaced by a local :func:`apply` call instead.
             be replaced by a local :func:`apply` call instead.
 
 
         """
         """
+        task_id = task_id or uuid()
         producer = producer or publisher
         producer = producer or publisher
         app = self._get_app()
         app = self._get_app()
         router = router or self.app.amqp.router
         router = router or self.app.amqp.router
@@ -488,12 +489,13 @@ class Task(object):
                 evd = app.events.Dispatcher(channel=P.channel,
                 evd = app.events.Dispatcher(channel=P.channel,
                                             buffer_while_offline=False)
                                             buffer_while_offline=False)
 
 
+            extra_properties = self.backend.on_task_apply(task_id)
             task_id = P.delay_task(self.name, args, kwargs,
             task_id = P.delay_task(self.name, args, kwargs,
                                    task_id=task_id,
                                    task_id=task_id,
                                    event_dispatcher=evd,
                                    event_dispatcher=evd,
                                    callbacks=maybe_list(link),
                                    callbacks=maybe_list(link),
                                    errbacks=maybe_list(link_error),
                                    errbacks=maybe_list(link_error),
-                                   **options)
+                                   **dict(options, **extra_properties))
         result = self.AsyncResult(task_id)
         result = self.AsyncResult(task_id)
         if add_to_parent:
         if add_to_parent:
             parent = get_current_worker_task()
             parent = get_current_worker_task()

+ 1 - 0
celery/backends/__init__.py

@@ -23,6 +23,7 @@ Unknown result backend: %r.  Did you spell that correctly? (%r)\
 
 
 BACKEND_ALIASES = {
 BACKEND_ALIASES = {
     'amqp': 'celery.backends.amqp:AMQPBackend',
     'amqp': 'celery.backends.amqp:AMQPBackend',
+    'amqrpc': 'celery.backends.amqrpc.AMQRPCBackend',
     'cache': 'celery.backends.cache:CacheBackend',
     'cache': 'celery.backends.cache:CacheBackend',
     'redis': 'celery.backends.redis:RedisBackend',
     'redis': 'celery.backends.redis:RedisBackend',
     'mongodb': 'celery.backends.mongodb:MongoBackend',
     'mongodb': 'celery.backends.mongodb:MongoBackend',

+ 16 - 7
celery/backends/amqp.py

@@ -65,14 +65,10 @@ class AMQPBackend(BaseDictBackend):
         self.queue_arguments = {}
         self.queue_arguments = {}
         self.persistent = (conf.CELERY_RESULT_PERSISTENT if persistent is None
         self.persistent = (conf.CELERY_RESULT_PERSISTENT if persistent is None
                                                          else persistent)
                                                          else persistent)
-        delivery_mode = persistent and 'persistent' or 'transient'
         exchange = exchange or conf.CELERY_RESULT_EXCHANGE
         exchange = exchange or conf.CELERY_RESULT_EXCHANGE
         exchange_type = exchange_type or conf.CELERY_RESULT_EXCHANGE_TYPE
         exchange_type = exchange_type or conf.CELERY_RESULT_EXCHANGE_TYPE
-        self.exchange = self.Exchange(name=exchange,
-                                      type=exchange_type,
-                                      delivery_mode=delivery_mode,
-                                      durable=self.persistent,
-                                      auto_delete=False)
+        self.exchange = self._create_exchange(exchange, exchange_type,
+                                              self.persistent)
         self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
         self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
         self.auto_delete = auto_delete
         self.auto_delete = auto_delete
 
 
@@ -91,6 +87,14 @@ class AMQPBackend(BaseDictBackend):
             self.queue_arguments['x-expires'] = int(self.expires * 1000)
             self.queue_arguments['x-expires'] = int(self.expires * 1000)
         self.mutex = threading.Lock()
         self.mutex = threading.Lock()
 
 
+    def _create_exchange(self, name, type='direct', persistent=True):
+        delivery_mode = persistent and 'persistent' or 'transient'
+        return self.Exchange(name=name,
+                             type=type,
+                             delivery_mode=delivery_mode,
+                             durable=self.persistent,
+                             auto_delete=False)
+
     def _create_binding(self, task_id):
     def _create_binding(self, task_id):
         name = task_id.replace('-', '')
         name = task_id.replace('-', '')
         return self.Queue(name=name,
         return self.Queue(name=name,
@@ -103,16 +107,21 @@ class AMQPBackend(BaseDictBackend):
     def revive(self, channel):
     def revive(self, channel):
         pass
         pass
 
 
+    def _routing_key(self, task_id):
+        return task_id.replace('-', '')
+
     def _store_result(self, task_id, result, status, traceback=None):
     def _store_result(self, task_id, result, status, traceback=None):
         """Send task return value and status."""
         """Send task return value and status."""
         with self.mutex:
         with self.mutex:
             with self.app.amqp.producer_pool.acquire(block=True) as pub:
             with self.app.amqp.producer_pool.acquire(block=True) as pub:
+                print("PUBLISH TO exchange=%r rkey=%r" % (self.exchange,
+                    self._routing_key(task_id)))
                 pub.publish({'task_id': task_id, 'status': status,
                 pub.publish({'task_id': task_id, 'status': status,
                              'result': self.encode_result(result, status),
                              'result': self.encode_result(result, status),
                              'traceback': traceback,
                              'traceback': traceback,
                              'children': self.current_task_children()},
                              'children': self.current_task_children()},
                             exchange=self.exchange,
                             exchange=self.exchange,
-                            routing_key=task_id.replace('-', ''),
+                            routing_key=self._routing_key(task_id),
                             serializer=self.serializer,
                             serializer=self.serializer,
                             retry=True, retry_policy=self.retry_policy,
                             retry=True, retry_policy=self.retry_policy,
                             declare=[self._create_binding(task_id)])
                             declare=[self._create_binding(task_id)])

+ 53 - 0
celery/backends/amqrpc.py

@@ -0,0 +1,53 @@
+from __future__ import absolute_import
+
+import os
+import uuid
+
+from threading import local
+
+from celery.backends import amqp
+
+try:
+    from thread import get_ident            # noqa
+except ImportError:                         # pragma: no cover
+    try:
+        from dummy_thread import get_ident  # noqa
+    except ImportError:                     # pragma: no cover
+        from _thread import get_ident       # noqa
+
+_nodeid = uuid.getnode()
+
+
+class AMQRPCBackend(amqp.AMQPBackend):
+    _tls = local()
+
+    def _create_exchange(self, name, type='direct', persistent=False):
+        return self.Exchange('c.amqrpc', type=type, delivery_mode=1,
+                durable=False, auto_delete=False)
+
+    def on_task_apply(self, task_id):
+        with self.app.pool.acquire_channel(block=True) as (conn, channel):
+            self.binding(channel).declare()
+            return {'reply_to': self.oid}
+
+    def _create_binding(self, task_id):
+        print("BINDING: %r" % (self.binding, ))
+        return self.binding
+
+    def _routing_key(self, task_id):
+        from celery import current_task
+        return current_task.request.reply_to
+
+    @property
+    def binding(self):
+        return self.Queue(self.oid, self.exchange, self.oid,
+                          durable=False, auto_delete=True)
+
+    @property
+    def oid(self):
+        try:
+            return self._tls.OID
+        except AttributeError:
+            ent = '%x-%x-%x' % (_nodeid, os.getpid(), get_ident())
+            oid = self._tls.OID = str(uuid.uuid3(uuid.NAMESPACE_OID, ent))
+            return oid

+ 3 - 0
celery/backends/base.py

@@ -224,6 +224,9 @@ class BaseBackend(object):
         raise NotImplementedError(
         raise NotImplementedError(
                 'reload_group_result is not supported by this backend.')
                 'reload_group_result is not supported by this backend.')
 
 
+    def on_task_apply(self, task_id):
+        pass
+
     def on_chord_part_return(self, task, propagate=False):
     def on_chord_part_return(self, task, propagate=False):
         pass
         pass