Explorar o código

Implemented deleting (revoking of tasks) using a broadcast message to the workers + refactored apply_async / celery.messaging

Ask Solem %!s(int64=15) %!d(string=hai) anos
pai
achega
0197805d65

+ 30 - 47
celery/execute.py

@@ -11,29 +11,33 @@ from celery.utils import gen_unique_id, noop, fun_takes_kwargs
 from celery.utils.functional import curry
 from celery.result import AsyncResult, EagerResult
 from celery.registry import tasks
-from celery.messaging import TaskPublisher
+from celery.messaging import TaskPublisher, with_connection
 from celery.exceptions import RetryTaskError
 from celery.datastructures import ExceptionInfo
 
+TASK_EXEC_OPTIONS = ("routing_key", "exchange",
+                     "immediate", "mandatory",
+                     "priority", "serializer")
+
 
 def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
-        routing_key=None, exchange=None, task_id=None,
-        immediate=None, mandatory=None, priority=None, connection=None,
-        connect_timeout=AMQP_CONNECTION_TIMEOUT, serializer=None, **opts):
+        task_id=None, publisher=None, connection=None, connect_timeout=None,
+        **options):
     """Run a task asynchronously by the celery daemon(s).
 
     :param task: The task to run (a callable object, or a :class:`Task`
         instance
 
-    :param args: The positional arguments to pass on to the task (a ``list``).
+    :keyword args: The positional arguments to pass on to the
+        task (a ``list``).
 
-    :param kwargs: The keyword arguments to pass on to the task (a ``dict``)
+    :keyword kwargs: The keyword arguments to pass on to the task (a ``dict``)
 
-    :param countdown: Number of seconds into the future that the task should
+    :keyword countdown: Number of seconds into the future that the task should
         execute. Defaults to immediate delivery (Do not confuse that with
         the ``immediate`` setting, they are unrelated).
 
-    :param eta: A :class:`datetime.datetime` object that describes the
+    :keyword eta: A :class:`datetime.datetime` object that describes the
         absolute time when the task should execute. May not be specified
         if ``countdown`` is also supplied. (Do not confuse this with the
         ``immediate`` setting, they are unrelated).
@@ -70,50 +74,29 @@ def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
     replaced by a local :func:`apply` call instead.
 
     """
-    args = args or []
-    kwargs = kwargs or {}
-    routing_key = routing_key or getattr(task, "routing_key", None)
-    exchange = exchange or getattr(task, "exchange", None)
-    if immediate is None:
-        immediate = getattr(task, "immediate", None)
-    if mandatory is None:
-        mandatory = getattr(task, "mandatory", None)
-    if priority is None:
-        priority = getattr(task, "priority", None)
-    serializer = serializer or getattr(task, "serializer", None)
-    taskset_id = opts.get("taskset_id")
-    publisher = opts.get("publisher")
-    retries = opts.get("retries", 0)
-    if countdown:
-        eta = datetime.now() + timedelta(seconds=countdown)
-
     from celery.conf import ALWAYS_EAGER
     if ALWAYS_EAGER:
         return apply(task, args, kwargs)
 
-    need_to_close_connection = False
-    if not publisher:
-        if not connection:
-            connection = DjangoBrokerConnection(
-                            connect_timeout=connect_timeout)
-            need_to_close_connection = True
-        publisher = TaskPublisher(connection=connection)
-
-    delay_task = publisher.delay_task
-    if taskset_id:
-        delay_task = curry(publisher.delay_task_in_set, taskset_id)
-
-    task_id = delay_task(task.name, args, kwargs,
-                         task_id=task_id, retries=retries,
-                         routing_key=routing_key, exchange=exchange,
-                         mandatory=mandatory, immediate=immediate,
-                         serializer=serializer, priority=priority,
-                         eta=eta)
-
-    if need_to_close_connection:
-        publisher.close()
-        connection.close()
+    for option_name in TASK_EXEC_OPTIONS:
+        if option_name not in options:
+            options[option_name] = getattr(task, option_name, None)
+
+    if countdown: # Convert countdown to ETA.
+        eta = datetime.now() + timedelta(seconds=countdown)
 
+    def _delay_task(connection):
+        publish = publisher or TaskPublisher(connection)
+        try:
+            return publish.delay_task(task.name, args or [], kwargs or {},
+                                      task_id=task_id,
+                                      eta=eta,
+                                      **options)
+        finally:
+            publisher or publish.close()
+
+    task_id = with_connection(_delay_task, connection=connection,
+                                           connect_timeout=connect_timeout)
     return AsyncResult(task_id)
 
 

+ 37 - 10
celery/messaging.py

@@ -3,6 +3,7 @@
 Sending and Receiving Messages
 
 """
+from carrot.connection import DjangoBrokerConnection
 from carrot.messaging import Publisher, Consumer, ConsumerSet
 
 from celery import conf
@@ -34,14 +35,7 @@ class TaskPublisher(Publisher):
         return self._delay_task(task_name=task_name, task_args=task_args,
                                 task_kwargs=task_kwargs, **kwargs)
 
-    def delay_task_in_set(self, taskset_id, task_name, task_args, task_kwargs,
-            **kwargs):
-        """Delay a task which part of a task set."""
-        return self._delay_task(task_name=task_name, part_of_set=taskset_id,
-                                task_args=task_args, task_kwargs=task_kwargs,
-                                **kwargs)
-
-    def _delay_task(self, task_name, task_id=None, part_of_set=None,
+    def _delay_task(self, task_name, task_id=None, taskset_id=None,
             task_args=None, task_kwargs=None, **kwargs):
         """INTERNAL"""
 
@@ -58,8 +52,8 @@ class TaskPublisher(Publisher):
             "eta": eta,
         }
 
-        if part_of_set:
-            message_data["taskset"] = part_of_set
+        if taskset_id:
+            message_data["taskset"] = taskset_id
 
         self.send(message_data, **extract_msg_options(kwargs))
         signals.task_sent.send(sender=task_name, **message_data)
@@ -84,6 +78,7 @@ class TaskConsumer(Consumer):
 
 class StatsPublisher(Publisher):
     exchange = "celerygraph"
+    exchange_type = "direct"
     routing_key = "stats"
     encoder = pickle.dumps
 
@@ -95,3 +90,35 @@ class StatsConsumer(Consumer):
     exchange_type = "direct"
     decoder = pickle.loads
     no_ack=True
+
+
+class BroadcastPublisher(Publisher):
+    exchange = "celerycast"
+    exchange_type = "fanout"
+    routing_key = ""
+
+    def revoke(self, task_id):
+        self.send(dict(revoke=task_id))
+
+
+class BroadcastConsumer(Consumer):
+    queue = "celerycast"
+    exchange = "celerycast"
+    routing_key = ""
+    exchange_type = "fanout"
+    no_ack=True
+
+
+def establish_connection(connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
+    return DjangoBrokerConnection(connect_timeout=connect_timeout)
+
+
+def with_connection(fun, connection=None,
+        connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
+    conn = connection or establish_connection()
+    close_connection = not connection and conn.close or noop
+
+    try:
+        return fun(conn)
+    finally:
+        close_connection()

+ 19 - 1
celery/task/__init__.py

@@ -9,7 +9,7 @@ from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.execute import apply_async
 from celery.registry import tasks
 from celery.backends import default_backend
-from celery.messaging import TaskConsumer
+from celery.messaging import TaskConsumer, with_connection
 from celery.task.base import Task, TaskSet, PeriodicTask
 from celery.task.base import ExecuteRemoteTask, AsynchronousMapTask
 from celery.task.rest import RESTProxyTask
@@ -35,6 +35,24 @@ def discard_all(connect_timeout=AMQP_CONNECTION_TIMEOUT):
     return discarded_count
 
 
+def revoke(task_id, connection=None, connect_timeout=None):
+    """Revoke a task by id.
+
+    Revoked tasks will not be executed after all.
+
+    """
+
+    def _revoke(connection):
+        broadcast = BroadcastPublisher(conn)
+        try:
+            broadcast.revoke(uuid)
+        finally:
+            broadcast.close()
+
+    return with_connection(_revoke, connection=connection,
+                           connect_timeout=connect_timeout)
+
+
 def is_successful(task_id):
     """Returns ``True`` if task with ``task_id`` has been executed.
 

+ 16 - 1
celery/worker/__init__.py

@@ -22,9 +22,10 @@ from celery.worker.job import TaskWrapper
 from celery.worker.scheduler import Scheduler
 from celery.worker.controllers import Mediator, ScheduleController
 from celery.worker.buckets import TaskBucket
-from celery.messaging import get_consumer_set
+from celery.messaging import get_consumer_set, BroadcastConsumer
 from celery.exceptions import NotRegistered
 from celery.datastructures import SharedCounter
+from celery.worker.revoke import revoked
 
 
 class CarrotListener(object):
@@ -99,6 +100,13 @@ class CarrotListener(object):
         otherwise we move it the bucket queue for immediate processing.
 
         """
+
+        revoke_uuid = message_data.get("revoke", None)
+        if revoke_uuid:
+            revoked.add(revoke_uuid)
+            self.logger.warn("Task %s marked as revoked." % revoke_uuid)
+            return
+
         try:
             task = TaskWrapper.from_message(message, message_data,
                                             logger=self.logger)
@@ -106,6 +114,11 @@ class CarrotListener(object):
             self.logger.error("Unknown task ignored: %s" % (exc))
             return
 
+        if task.task_id in revoked:
+            self.logger.warn("Got revoked task from broker: %s[%s]" % (
+                task.task_name, task.task_id))
+            return
+
         eta = message_data.get("eta")
         if eta:
             if not isinstance(eta, datetime):
@@ -144,6 +157,8 @@ class CarrotListener(object):
         self.close_connection()
         self.amqp_connection = self._open_connection()
         self.task_consumer = get_consumer_set(connection=self.amqp_connection)
+        self.broadcast_consumer = BroadcastConsumer(self.amqp_connection)
+        self.task_consumer.add_consumer(self.broadcast_consumer)
         self.task_consumer.register_callback(self.receive_message)
 
     def _open_connection(self):

+ 7 - 1
celery/worker/controllers.py

@@ -9,6 +9,7 @@ from Queue import Empty as QueueEmpty
 from datetime import datetime
 
 from celery.log import get_default_logger
+from celery.worker.revoke import revoked
 
 
 class BackgroundThread(threading.Thread):
@@ -91,9 +92,14 @@ class Mediator(BackgroundThread):
         except QueueEmpty:
             time.sleep(1)
         else:
+            if task.task_id in revoked: # task revoked
+                logger.warn("Mediator: Skipping revoked task: %s[%s]" % (
+                    task.task_name, task.task_id))
+                return
+
             logger.debug("Mediator: Running callback for task: %s[%s]" % (
                 task.task_name, task.task_id))
-            self.callback(task)
+            self.callback(task) # execute
 
 
 class ScheduleController(BackgroundThread):

+ 53 - 0
celery/worker/revoke.py

@@ -0,0 +1,53 @@
+import time
+from UserDict import UserDict
+
+from carrot.connection import DjangoBrokerConnection
+
+from celery.messaging import BroadcastPublisher
+from celery.utils import noop
+
+REVOKES_MAX = 1000
+REVOKE_EXPIRES = 60 * 60 # one hour.
+
+
+class RevokeRegistry(UserDict):
+
+    def __init__(self, maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES):
+        self.maxlen = maxlen
+        self.expires = expires
+        self.data = {}
+
+    def add(self, uuid):
+        self._expire_item()
+        self[uuid] = time.time()
+
+    def _expire_item(self):
+        while 1:
+            if len(self) > self.maxlen:
+                uuid, when = self.oldest
+                if time.time() > when + self.expires:
+                    try:
+                        self.pop(uuid, None)
+                    except TypeError:
+                        continue
+            break
+
+    @property
+    def oldest(self):
+        return sorted(self.items(), key=lambda (uuid, when): when)[0]
+
+
+def revoke(uuid, connection=None):
+    conn = connection or DjangoBrokerConnection()
+    close_connection = not connection and conn.close or noop
+
+    broadcast = BroadcastPublisher(conn)
+    try:
+        broadcast.send({"revoke": uuid})
+    finally:
+        broadcast.close()
+        close_connection()
+
+revoked = RevokeRegistry()
+
+