瀏覽代碼

Fix bugs from task refactor + add eta, countdown

Ask Solem 16 年之前
父節點
當前提交
121c33dbcf
共有 5 個文件被更改,包括 152 次插入98 次删除
  1. 105 0
      celery/execute.py
  2. 4 0
      celery/messaging.py
  3. 1 96
      celery/task/__init__.py
  4. 1 0
      celery/task/base.py
  5. 41 2
      celery/worker.py

+ 105 - 0
celery/execute.py

@@ -0,0 +1,105 @@
+from carrot.connection import DjangoAMQPConnection
+from celery.conf import AMQP_CONNECTION_TIMEOUT
+from celery.result import AsyncResult
+from celery.messaging import TaskPublisher
+from functools import partial as curry
+
+
+def apply_async(task, args=None, kwargs=None, routing_key=None,
+        immediate=None, mandatory=None, connection=None,
+        connect_timeout=AMQP_CONNECTION_TIMEOUT, priority=None,
+        eta=None, countdown=None, **opts):
+    """Run a task asynchronously by the celery daemon(s).
+
+    :param task: The task to run (a callable object, or a :class:`Task`
+        instance
+
+    :param args: The positional arguments to pass on to the task (a ``list``).
+
+    :param kwargs: The keyword arguments to pass on to the task (a ``dict``)
+
+    :keyword countdown:  Number of seconds in the future that the task
+        should execute.
+
+    :keyword eta: The task won't be run before this date
+        (a :class:`datetime.datetime` object).
+
+    :keyword routing_key: The routing key used to route the task to a worker
+        server.
+
+    :keyword immediate: Request immediate delivery. Will raise an exception
+        if the task cannot be routed to a worker immediately.
+
+    :keyword mandatory: Mandatory routing. Raises an exception if there's
+        no running workers able to take on this task.
+
+    :keyword connection: Re-use existing AMQP connection.
+        The ``connect_timeout`` argument is not respected if this is set.
+
+    :keyword connect_timeout: The timeout in seconds, before we give up
+        on establishing a connection to the AMQP server.
+
+    :keyword priority: The task priority, a number between ``0`` and ``9``.
+
+    """
+    args = args or []
+    kwargs = kwargs or {}
+    routing_key = routing_key or getattr(task, "routing_key", None)
+    immediate = immediate or getattr(task, "immediate", None)
+    mandatory = mandatory or getattr(task, "mandatory", None)
+    priority = priority or getattr(task, "priority", None)
+    taskset_id = opts.get("taskset_id")
+    publisher = opts.get("publisher")
+
+    need_to_close_connection = False
+    if not publisher:
+        if not connection:
+            connection = DjangoAMQPConnection(connect_timeout=connect_timeout)
+            need_to_close_connection = True
+        publisher = TaskPublisher(connection=connection)
+
+    delay_task = publisher.delay_task
+    if taskset_id:
+        delay_task = curry(publisher.delay_task_in_set, taskset_id)
+        
+    task_id = delay_task(task.name, args, kwargs,
+                         routing_key=routing_key, mandatory=mandatory,
+                         immediate=immediate, priority=priority,
+                         countdown=countdown, eta=eta)
+
+    if need_to_close_connection:
+        publisher.close()
+        connection.close()
+
+    return AsyncResult(task_id)
+
+
+def delay_task(task_name, *args, **kwargs):
+    """Delay a task for execution by the ``celery`` daemon.
+
+    :param task_name: the name of a task registered in the task registry.
+
+    :param \*args: positional arguments to pass on to the task.
+
+    :param \*\*kwargs: keyword arguments to pass on to the task.
+
+    :raises celery.registry.NotRegistered: exception if no such task
+        has been registered in the task registry.
+
+    :rtype: :class:`celery.result.AsyncResult`.
+
+    Example
+
+        >>> r = delay_task("update_record", name="George Constanza", age=32)
+        >>> r.ready()
+        True
+        >>> r.result
+        "Record was updated"
+
+    """
+    if task_name not in tasks:
+        raise tasks.NotRegistered(
+                "Task with name %s not registered in the task registry." % (
+                    task_name))
+    task = tasks[task_name]
+    return apply_async(task, args, kwargs)

+ 4 - 0
celery/messaging.py

@@ -42,6 +42,8 @@ class TaskPublisher(Publisher):
     def _delay_task(self, task_name, task_id=None, part_of_set=None,
             task_args=None, task_kwargs=None, **kwargs):
         """INTERNAL"""
+        eta = kwargs.get("eta")
+        countdown = kwargs.get("countdown")
         priority = kwargs.get("priority")
         immediate = kwargs.get("immediate")
         mandatory = kwargs.get("mandatory")
@@ -55,6 +57,8 @@ class TaskPublisher(Publisher):
             "task": task_name,
             "args": task_args,
             "kwargs": task_kwargs,
+            "countdown": countdown,
+            "eta": eta,
         }
         if part_of_set:
             message_data["taskset"] = part_of_set

+ 1 - 96
celery/task/__init__.py

@@ -5,113 +5,18 @@ Working with tasks and task sets.
 """
 from carrot.connection import DjangoAMQPConnection
 from celery.conf import AMQP_CONNECTION_TIMEOUT
-from celery.messaging import TaskPublisher
 from celery.registry import tasks
 from celery.backends import default_backend
-from celery.result import AsyncResult
 from celery.task.base import Task, TaskSet, PeriodicTask
 from celery.task.builtins import AsynchronousMapTask, ExecuteRemoteTask
 from celery.task.builtins import DeleteExpiredTaskMetaTask, PingTask
-from functools import partial as curry
+from celery.execute import apply_async, delay_task
 try:
     import cPickle as pickle
 except ImportError:
     import pickle
 
 
-def apply_async(task, args=None, kwargs=None, routing_key=None,
-        immediate=None, mandatory=None, connection=None,
-        connect_timeout=AMQP_CONNECTION_TIMEOUT, priority=None, **opts):
-    """Run a task asynchronously by the celery daemon(s).
-
-    :param task: The task to run (a callable object, or a :class:`Task`
-        instance
-
-    :param args: The positional arguments to pass on to the task (a ``list``).
-
-    :param kwargs: The keyword arguments to pass on to the task (a ``dict``)
-
-
-    :keyword routing_key: The routing key used to route the task to a worker
-        server.
-
-    :keyword immediate: Request immediate delivery. Will raise an exception
-        if the task cannot be routed to a worker immediately.
-
-    :keyword mandatory: Mandatory routing. Raises an exception if there's
-        no running workers able to take on this task.
-
-    :keyword connection: Re-use existing AMQP connection.
-        The ``connect_timeout`` argument is not respected if this is set.
-
-    :keyword connect_timeout: The timeout in seconds, before we give up
-        on establishing a connection to the AMQP server.
-
-    :keyword priority: The task priority, a number between ``0`` and ``9``.
-
-    """
-    args = args or []
-    kwargs = kwargs or {}
-    routing_key = routing_key or getattr(task, "routing_key", None)
-    immediate = immediate or getattr(task, "immediate", None)
-    mandatory = mandatory or getattr(task, "mandatory", None)
-    priority = priority or getattr(task, "priority", None)
-    taskset_id = opts.get("taskset_id")
-    publisher = opts.get("publisher")
-
-    need_to_close_connection = False
-    if not publisher:
-        if not connection:
-            connection = DjangoAMQPConnection(connect_timeout=connect_timeout)
-            need_to_close_connection = True
-        publisher = TaskPublisher(connection=connection)
-
-    delay_task = publisher.delay_task
-    if taskset_id:
-        delay_task = curry(publisher.delay_task_in_set, taskset_id)
-        
-    task_id = delay_task(task.name, args, kwargs,
-                         routing_key=routing_key, mandatory=mandatory,
-                         immediate=immediate, priority=priority)
-
-    if need_to_close_connection:
-        publisher.close()
-        connection.close()
-
-    return AsyncResult(task_id)
-
-
-def delay_task(task_name, *args, **kwargs):
-    """Delay a task for execution by the ``celery`` daemon.
-
-    :param task_name: the name of a task registered in the task registry.
-
-    :param \*args: positional arguments to pass on to the task.
-
-    :param \*\*kwargs: keyword arguments to pass on to the task.
-
-    :raises celery.registry.NotRegistered: exception if no such task
-        has been registered in the task registry.
-
-    :rtype: :class:`celery.result.AsyncResult`.
-
-    Example
-
-        >>> r = delay_task("update_record", name="George Constanza", age=32)
-        >>> r.ready()
-        True
-        >>> r.result
-        "Record was updated"
-
-    """
-    if task_name not in tasks:
-        raise tasks.NotRegistered(
-                "Task with name %s not registered in the task registry." % (
-                    task_name))
-    task = tasks[task_name]
-    return apply_async(task, args, kwargs)
-
-
 def discard_all(connect_timeout=AMQP_CONNECTION_TIMEOUT):
     """Discard all waiting tasks.
 

+ 1 - 0
celery/task/base.py

@@ -3,6 +3,7 @@ from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.messaging import TaskPublisher, TaskConsumer
 from celery.log import setup_logger
 from celery.result import TaskSetResult
+from celery.execute import apply_async, delay_task
 from datetime import timedelta
 import uuid
 try:

+ 41 - 2
celery/worker.py

@@ -11,6 +11,10 @@ from celery.backends import default_backend, default_periodic_status_backend
 from celery.timer import EventTimer
 from django.core.mail import mail_admins
 from celery.monitoring import TaskTimerStats
+from datetime import datetime, timedelta
+from Queue import Queue
+from Queue import Empty as QueueEmpty
+from multiprocessing import TimeoutError
 import multiprocessing
 import traceback
 import threading
@@ -345,6 +349,7 @@ class WorkController(object):
         self.is_detached = is_detached
         self.amqp_connection = None
         self.task_consumer = None
+        self.bucket_queue = Queue()
 
     def close_connection(self):
         """Close the AMQP connection."""
@@ -395,6 +400,15 @@ class WorkController(object):
 
     def process_task(self, message_data, message):
         """Process task message by passing it to the pool of workers."""
+        
+        countdown = message_data.get("countdown")
+        eta = message_data.get("eta")
+        if countdown:
+            eta = datetime.now() + timedelta(seconds=int(countdown))
+        if eta:
+            self.bucket_queue.put((message, message_data, eta))
+            return
+
         task = TaskWrapper.from_message(message, message_data,
                                         logger=self.logger)
         self.logger.info("Got task from broker: %s[%s]" % (
@@ -406,7 +420,7 @@ class WorkController(object):
 
         self.logger.debug("Task %s has been executed asynchronously." % task)
 
-        return result
+        return
 
     def shutdown(self):
         """Make sure ``celeryd`` exits cleanly."""
@@ -437,6 +451,31 @@ class WorkController(object):
 
         try:
             while True:
-                it.next()
+                try:
+                    self.process_bucket()
+                    self.process_next(it, timeout=1)
+                except TimeoutError:
+                    pass
         except (SystemExit, KeyboardInterrupt):
             self.shutdown()
+
+    def process_next(self, it, timeout=1):
+        def on_timeout():
+            raise TimeoutError()
+        timer = threading.Timer(timeout, on_timeout)
+        timer.start()
+        try:
+            it.next()
+        finally:
+            timer.cancel()
+
+    def process_bucket(self):
+        try:
+            message, msg_data, eta = self.bucket_queue.get_nowait()
+        except QueueEmpty:
+            pass
+        else:
+            if datetime.now() >= eta:
+                self.process_task(message, msg_data)
+            else:
+                self.bucket_queue.put((message, msg_data, eta))