Jelajahi Sumber

Implemented deleting (revoking of tasks) using a broadcast message to the workers + refactored apply_async / celery.messaging

Ask Solem 15 tahun lalu
induk
melakukan
0197805d65

+ 30 - 47
celery/execute.py

@@ -11,29 +11,33 @@ from celery.utils import gen_unique_id, noop, fun_takes_kwargs
 from celery.utils.functional import curry
 from celery.utils.functional import curry
 from celery.result import AsyncResult, EagerResult
 from celery.result import AsyncResult, EagerResult
 from celery.registry import tasks
 from celery.registry import tasks
-from celery.messaging import TaskPublisher
+from celery.messaging import TaskPublisher, with_connection
 from celery.exceptions import RetryTaskError
 from celery.exceptions import RetryTaskError
 from celery.datastructures import ExceptionInfo
 from celery.datastructures import ExceptionInfo
 
 
+TASK_EXEC_OPTIONS = ("routing_key", "exchange",
+                     "immediate", "mandatory",
+                     "priority", "serializer")
+
 
 
 def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
 def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
-        routing_key=None, exchange=None, task_id=None,
-        immediate=None, mandatory=None, priority=None, connection=None,
-        connect_timeout=AMQP_CONNECTION_TIMEOUT, serializer=None, **opts):
+        task_id=None, publisher=None, connection=None, connect_timeout=None,
+        **options):
     """Run a task asynchronously by the celery daemon(s).
     """Run a task asynchronously by the celery daemon(s).
 
 
     :param task: The task to run (a callable object, or a :class:`Task`
     :param task: The task to run (a callable object, or a :class:`Task`
         instance
         instance
 
 
-    :param args: The positional arguments to pass on to the task (a ``list``).
+    :keyword args: The positional arguments to pass on to the
+        task (a ``list``).
 
 
-    :param kwargs: The keyword arguments to pass on to the task (a ``dict``)
+    :keyword kwargs: The keyword arguments to pass on to the task (a ``dict``)
 
 
-    :param countdown: Number of seconds into the future that the task should
+    :keyword countdown: Number of seconds into the future that the task should
         execute. Defaults to immediate delivery (Do not confuse that with
         execute. Defaults to immediate delivery (Do not confuse that with
         the ``immediate`` setting, they are unrelated).
         the ``immediate`` setting, they are unrelated).
 
 
-    :param eta: A :class:`datetime.datetime` object that describes the
+    :keyword eta: A :class:`datetime.datetime` object that describes the
         absolute time when the task should execute. May not be specified
         absolute time when the task should execute. May not be specified
         if ``countdown`` is also supplied. (Do not confuse this with the
         if ``countdown`` is also supplied. (Do not confuse this with the
         ``immediate`` setting, they are unrelated).
         ``immediate`` setting, they are unrelated).
@@ -70,50 +74,29 @@ def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
     replaced by a local :func:`apply` call instead.
     replaced by a local :func:`apply` call instead.
 
 
     """
     """
-    args = args or []
-    kwargs = kwargs or {}
-    routing_key = routing_key or getattr(task, "routing_key", None)
-    exchange = exchange or getattr(task, "exchange", None)
-    if immediate is None:
-        immediate = getattr(task, "immediate", None)
-    if mandatory is None:
-        mandatory = getattr(task, "mandatory", None)
-    if priority is None:
-        priority = getattr(task, "priority", None)
-    serializer = serializer or getattr(task, "serializer", None)
-    taskset_id = opts.get("taskset_id")
-    publisher = opts.get("publisher")
-    retries = opts.get("retries", 0)
-    if countdown:
-        eta = datetime.now() + timedelta(seconds=countdown)
-
     from celery.conf import ALWAYS_EAGER
     from celery.conf import ALWAYS_EAGER
     if ALWAYS_EAGER:
     if ALWAYS_EAGER:
         return apply(task, args, kwargs)
         return apply(task, args, kwargs)
 
 
-    need_to_close_connection = False
-    if not publisher:
-        if not connection:
-            connection = DjangoBrokerConnection(
-                            connect_timeout=connect_timeout)
-            need_to_close_connection = True
-        publisher = TaskPublisher(connection=connection)
-
-    delay_task = publisher.delay_task
-    if taskset_id:
-        delay_task = curry(publisher.delay_task_in_set, taskset_id)
-
-    task_id = delay_task(task.name, args, kwargs,
-                         task_id=task_id, retries=retries,
-                         routing_key=routing_key, exchange=exchange,
-                         mandatory=mandatory, immediate=immediate,
-                         serializer=serializer, priority=priority,
-                         eta=eta)
-
-    if need_to_close_connection:
-        publisher.close()
-        connection.close()
+    for option_name in TASK_EXEC_OPTIONS:
+        if option_name not in options:
+            options[option_name] = getattr(task, option_name, None)
+
+    if countdown: # Convert countdown to ETA.
+        eta = datetime.now() + timedelta(seconds=countdown)
 
 
+    def _delay_task(connection):
+        publish = publisher or TaskPublisher(connection)
+        try:
+            return publish.delay_task(task.name, args or [], kwargs or {},
+                                      task_id=task_id,
+                                      eta=eta,
+                                      **options)
+        finally:
+            publisher or publish.close()
+
+    task_id = with_connection(_delay_task, connection=connection,
+                                           connect_timeout=connect_timeout)
     return AsyncResult(task_id)
     return AsyncResult(task_id)
 
 
 
 

+ 37 - 10
celery/messaging.py

@@ -3,6 +3,7 @@
 Sending and Receiving Messages
 Sending and Receiving Messages
 
 
 """
 """
+from carrot.connection import DjangoBrokerConnection
 from carrot.messaging import Publisher, Consumer, ConsumerSet
 from carrot.messaging import Publisher, Consumer, ConsumerSet
 
 
 from celery import conf
 from celery import conf
@@ -34,14 +35,7 @@ class TaskPublisher(Publisher):
         return self._delay_task(task_name=task_name, task_args=task_args,
         return self._delay_task(task_name=task_name, task_args=task_args,
                                 task_kwargs=task_kwargs, **kwargs)
                                 task_kwargs=task_kwargs, **kwargs)
 
 
-    def delay_task_in_set(self, taskset_id, task_name, task_args, task_kwargs,
-            **kwargs):
-        """Delay a task which part of a task set."""
-        return self._delay_task(task_name=task_name, part_of_set=taskset_id,
-                                task_args=task_args, task_kwargs=task_kwargs,
-                                **kwargs)
-
-    def _delay_task(self, task_name, task_id=None, part_of_set=None,
+    def _delay_task(self, task_name, task_id=None, taskset_id=None,
             task_args=None, task_kwargs=None, **kwargs):
             task_args=None, task_kwargs=None, **kwargs):
         """INTERNAL"""
         """INTERNAL"""
 
 
@@ -58,8 +52,8 @@ class TaskPublisher(Publisher):
             "eta": eta,
             "eta": eta,
         }
         }
 
 
-        if part_of_set:
-            message_data["taskset"] = part_of_set
+        if taskset_id:
+            message_data["taskset"] = taskset_id
 
 
         self.send(message_data, **extract_msg_options(kwargs))
         self.send(message_data, **extract_msg_options(kwargs))
         signals.task_sent.send(sender=task_name, **message_data)
         signals.task_sent.send(sender=task_name, **message_data)
@@ -84,6 +78,7 @@ class TaskConsumer(Consumer):
 
 
 class StatsPublisher(Publisher):
 class StatsPublisher(Publisher):
     exchange = "celerygraph"
     exchange = "celerygraph"
+    exchange_type = "direct"
     routing_key = "stats"
     routing_key = "stats"
     encoder = pickle.dumps
     encoder = pickle.dumps
 
 
@@ -95,3 +90,35 @@ class StatsConsumer(Consumer):
     exchange_type = "direct"
     exchange_type = "direct"
     decoder = pickle.loads
     decoder = pickle.loads
     no_ack=True
     no_ack=True
+
+
+class BroadcastPublisher(Publisher):
+    exchange = "celerycast"
+    exchange_type = "fanout"
+    routing_key = ""
+
+    def revoke(self, task_id):
+        self.send(dict(revoke=task_id))
+
+
+class BroadcastConsumer(Consumer):
+    queue = "celerycast"
+    exchange = "celerycast"
+    routing_key = ""
+    exchange_type = "fanout"
+    no_ack=True
+
+
+def establish_connection(connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
+    return DjangoBrokerConnection(connect_timeout=connect_timeout)
+
+
+def with_connection(fun, connection=None,
+        connect_timeout=conf.AMQP_CONNECTION_TIMEOUT):
+    conn = connection or establish_connection()
+    close_connection = not connection and conn.close or noop
+
+    try:
+        return fun(conn)
+    finally:
+        close_connection()

+ 19 - 1
celery/task/__init__.py

@@ -9,7 +9,7 @@ from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.execute import apply_async
 from celery.execute import apply_async
 from celery.registry import tasks
 from celery.registry import tasks
 from celery.backends import default_backend
 from celery.backends import default_backend
-from celery.messaging import TaskConsumer
+from celery.messaging import TaskConsumer, with_connection
 from celery.task.base import Task, TaskSet, PeriodicTask
 from celery.task.base import Task, TaskSet, PeriodicTask
 from celery.task.base import ExecuteRemoteTask, AsynchronousMapTask
 from celery.task.base import ExecuteRemoteTask, AsynchronousMapTask
 from celery.task.rest import RESTProxyTask
 from celery.task.rest import RESTProxyTask
@@ -35,6 +35,24 @@ def discard_all(connect_timeout=AMQP_CONNECTION_TIMEOUT):
     return discarded_count
     return discarded_count
 
 
 
 
+def revoke(task_id, connection=None, connect_timeout=None):
+    """Revoke a task by id.
+
+    Revoked tasks will not be executed after all.
+
+    """
+
+    def _revoke(connection):
+        broadcast = BroadcastPublisher(conn)
+        try:
+            broadcast.revoke(uuid)
+        finally:
+            broadcast.close()
+
+    return with_connection(_revoke, connection=connection,
+                           connect_timeout=connect_timeout)
+
+
 def is_successful(task_id):
 def is_successful(task_id):
     """Returns ``True`` if task with ``task_id`` has been executed.
     """Returns ``True`` if task with ``task_id`` has been executed.
 
 

+ 16 - 1
celery/worker/__init__.py

@@ -22,9 +22,10 @@ from celery.worker.job import TaskWrapper
 from celery.worker.scheduler import Scheduler
 from celery.worker.scheduler import Scheduler
 from celery.worker.controllers import Mediator, ScheduleController
 from celery.worker.controllers import Mediator, ScheduleController
 from celery.worker.buckets import TaskBucket
 from celery.worker.buckets import TaskBucket
-from celery.messaging import get_consumer_set
+from celery.messaging import get_consumer_set, BroadcastConsumer
 from celery.exceptions import NotRegistered
 from celery.exceptions import NotRegistered
 from celery.datastructures import SharedCounter
 from celery.datastructures import SharedCounter
+from celery.worker.revoke import revoked
 
 
 
 
 class CarrotListener(object):
 class CarrotListener(object):
@@ -99,6 +100,13 @@ class CarrotListener(object):
         otherwise we move it the bucket queue for immediate processing.
         otherwise we move it the bucket queue for immediate processing.
 
 
         """
         """
+
+        revoke_uuid = message_data.get("revoke", None)
+        if revoke_uuid:
+            revoked.add(revoke_uuid)
+            self.logger.warn("Task %s marked as revoked." % revoke_uuid)
+            return
+
         try:
         try:
             task = TaskWrapper.from_message(message, message_data,
             task = TaskWrapper.from_message(message, message_data,
                                             logger=self.logger)
                                             logger=self.logger)
@@ -106,6 +114,11 @@ class CarrotListener(object):
             self.logger.error("Unknown task ignored: %s" % (exc))
             self.logger.error("Unknown task ignored: %s" % (exc))
             return
             return
 
 
+        if task.task_id in revoked:
+            self.logger.warn("Got revoked task from broker: %s[%s]" % (
+                task.task_name, task.task_id))
+            return
+
         eta = message_data.get("eta")
         eta = message_data.get("eta")
         if eta:
         if eta:
             if not isinstance(eta, datetime):
             if not isinstance(eta, datetime):
@@ -144,6 +157,8 @@ class CarrotListener(object):
         self.close_connection()
         self.close_connection()
         self.amqp_connection = self._open_connection()
         self.amqp_connection = self._open_connection()
         self.task_consumer = get_consumer_set(connection=self.amqp_connection)
         self.task_consumer = get_consumer_set(connection=self.amqp_connection)
+        self.broadcast_consumer = BroadcastConsumer(self.amqp_connection)
+        self.task_consumer.add_consumer(self.broadcast_consumer)
         self.task_consumer.register_callback(self.receive_message)
         self.task_consumer.register_callback(self.receive_message)
 
 
     def _open_connection(self):
     def _open_connection(self):

+ 7 - 1
celery/worker/controllers.py

@@ -9,6 +9,7 @@ from Queue import Empty as QueueEmpty
 from datetime import datetime
 from datetime import datetime
 
 
 from celery.log import get_default_logger
 from celery.log import get_default_logger
+from celery.worker.revoke import revoked
 
 
 
 
 class BackgroundThread(threading.Thread):
 class BackgroundThread(threading.Thread):
@@ -91,9 +92,14 @@ class Mediator(BackgroundThread):
         except QueueEmpty:
         except QueueEmpty:
             time.sleep(1)
             time.sleep(1)
         else:
         else:
+            if task.task_id in revoked: # task revoked
+                logger.warn("Mediator: Skipping revoked task: %s[%s]" % (
+                    task.task_name, task.task_id))
+                return
+
             logger.debug("Mediator: Running callback for task: %s[%s]" % (
             logger.debug("Mediator: Running callback for task: %s[%s]" % (
                 task.task_name, task.task_id))
                 task.task_name, task.task_id))
-            self.callback(task)
+            self.callback(task) # execute
 
 
 
 
 class ScheduleController(BackgroundThread):
 class ScheduleController(BackgroundThread):

+ 53 - 0
celery/worker/revoke.py

@@ -0,0 +1,53 @@
+import time
+from UserDict import UserDict
+
+from carrot.connection import DjangoBrokerConnection
+
+from celery.messaging import BroadcastPublisher
+from celery.utils import noop
+
+REVOKES_MAX = 1000
+REVOKE_EXPIRES = 60 * 60 # one hour.
+
+
+class RevokeRegistry(UserDict):
+
+    def __init__(self, maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES):
+        self.maxlen = maxlen
+        self.expires = expires
+        self.data = {}
+
+    def add(self, uuid):
+        self._expire_item()
+        self[uuid] = time.time()
+
+    def _expire_item(self):
+        while 1:
+            if len(self) > self.maxlen:
+                uuid, when = self.oldest
+                if time.time() > when + self.expires:
+                    try:
+                        self.pop(uuid, None)
+                    except TypeError:
+                        continue
+            break
+
+    @property
+    def oldest(self):
+        return sorted(self.items(), key=lambda (uuid, when): when)[0]
+
+
+def revoke(uuid, connection=None):
+    conn = connection or DjangoBrokerConnection()
+    close_connection = not connection and conn.close or noop
+
+    broadcast = BroadcastPublisher(conn)
+    try:
+        broadcast.send({"revoke": uuid})
+    finally:
+        broadcast.close()
+        close_connection()
+
+revoked = RevokeRegistry()
+
+