Bläddra i källkod

Fix bugs from task refactor + add eta, countdown

Ask Solem 16 år sedan
förälder
incheckning
121c33dbcf
5 ändrade filer med 152 tillägg och 98 borttagningar
  1. 105 0
      celery/execute.py
  2. 4 0
      celery/messaging.py
  3. 1 96
      celery/task/__init__.py
  4. 1 0
      celery/task/base.py
  5. 41 2
      celery/worker.py

+ 105 - 0
celery/execute.py

@@ -0,0 +1,105 @@
+from carrot.connection import DjangoAMQPConnection
+from celery.conf import AMQP_CONNECTION_TIMEOUT
+from celery.result import AsyncResult
+from celery.messaging import TaskPublisher
+from functools import partial as curry
+
+
+def apply_async(task, args=None, kwargs=None, routing_key=None,
+        immediate=None, mandatory=None, connection=None,
+        connect_timeout=AMQP_CONNECTION_TIMEOUT, priority=None,
+        eta=None, countdown=None, **opts):
+    """Run a task asynchronously by the celery daemon(s).
+
+    :param task: The task to run (a callable object, or a :class:`Task`
+        instance).
+
+    :param args: The positional arguments to pass on to the task (a ``list``).
+
+    :param kwargs: The keyword arguments to pass on to the task (a ``dict``)
+
+    :keyword countdown: Number of seconds in the future that the task
+        should execute.
+
+    :keyword eta: The task won't be run before this date
+        (a :class:`datetime.datetime` object).
+
+    :keyword routing_key: The routing key used to route the task to a worker
+        server.
+
+    :keyword immediate: Request immediate delivery. Will raise an exception
+        if the task cannot be routed to a worker immediately.
+
+    :keyword mandatory: Mandatory routing. Raises an exception if there's
+        no running workers able to take on this task.
+
+    :keyword connection: Re-use existing AMQP connection.
+        The ``connect_timeout`` argument is not respected if this is set.
+        A connection passed in by the caller is never closed here.
+
+    :keyword connect_timeout: The timeout in seconds, before we give up
+        on establishing a connection to the AMQP server.
+
+    :keyword priority: The task priority, a number between ``0`` and ``9``.
+
+    :keyword taskset_id: (read from ``opts``) if set, the task is published
+        with ``publisher.delay_task_in_set`` using this id.
+
+    :keyword publisher: (read from ``opts``) re-use an existing publisher.
+        A publisher passed in by the caller is never closed here.
+
+    ``routing_key``, ``immediate``, ``mandatory`` and ``priority`` fall back
+    to the attribute of the same name on ``task`` only when they are left
+    as ``None``, so explicit values such as ``priority=0`` or
+    ``immediate=False`` are honoured.
+
+    :returns: an :class:`AsyncResult` for the id of the published task.
+
+    """
+    args = args or []
+    kwargs = kwargs or {}
+    # Only an omitted (None) option falls back to the task's own attribute;
+    # ``x or getattr(...)`` would silently discard explicit falsy values.
+    if routing_key is None:
+        routing_key = getattr(task, "routing_key", None)
+    if immediate is None:
+        immediate = getattr(task, "immediate", None)
+    if mandatory is None:
+        mandatory = getattr(task, "mandatory", None)
+    if priority is None:
+        priority = getattr(task, "priority", None)
+    taskset_id = opts.get("taskset_id")
+    publisher = opts.get("publisher")
+
+    # Track what was created here so that exactly those resources are
+    # released, even when publishing fails.
+    own_connection = False
+    own_publisher = False
+    try:
+        if not publisher:
+            if not connection:
+                connection = DjangoAMQPConnection(
+                                connect_timeout=connect_timeout)
+                own_connection = True
+            publisher = TaskPublisher(connection=connection)
+            own_publisher = True
+
+        # Named ``publish`` so the module-level ``delay_task`` function is
+        # not shadowed inside this function.
+        publish = publisher.delay_task
+        if taskset_id:
+            publish = curry(publisher.delay_task_in_set, taskset_id)
+
+        task_id = publish(task.name, args, kwargs,
+                          routing_key=routing_key, mandatory=mandatory,
+                          immediate=immediate, priority=priority,
+                          countdown=countdown, eta=eta)
+    finally:
+        try:
+            if own_publisher:
+                publisher.close()
+        finally:
+            if own_connection:
+                connection.close()
+
+    return AsyncResult(task_id)
+
+
+def delay_task(task_name, *args, **kwargs):
+    r"""Delay a task for execution by the ``celery`` daemon.
+
+    Looks ``task_name`` up in the task registry and hands the registered
+    task, together with the given arguments, to :func:`apply_async`.
+
+    :param task_name: the name of a task registered in the task registry.
+
+    :param \*args: positional arguments to pass on to the task.
+
+    :param \*\*kwargs: keyword arguments to pass on to the task.
+
+    :raises celery.registry.NotRegistered: exception if no such task
+        has been registered in the task registry.
+
+    :rtype: :class:`celery.result.AsyncResult`.
+
+    Example
+
+        >>> r = delay_task("update_record", name="George Constanza", age=32)
+        >>> r.ready()
+        True
+        >>> r.result
+        "Record was updated"
+
+    """
+    # This module does not import the registry at top level, so bring it
+    # into scope here; without it the lookup below raises NameError.
+    from celery.registry import tasks
+
+    if task_name not in tasks:
+        raise tasks.NotRegistered(
+                "Task with name %s not registered in the task registry." % (
+                    task_name))
+    return apply_async(tasks[task_name], args, kwargs)

+ 4 - 0
celery/messaging.py

@@ -42,6 +42,8 @@ class TaskPublisher(Publisher):
     def _delay_task(self, task_name, task_id=None, part_of_set=None,
     def _delay_task(self, task_name, task_id=None, part_of_set=None,
             task_args=None, task_kwargs=None, **kwargs):
             task_args=None, task_kwargs=None, **kwargs):
         """INTERNAL"""
         """INTERNAL"""
+        eta = kwargs.get("eta")
+        countdown = kwargs.get("countdown")
         priority = kwargs.get("priority")
         priority = kwargs.get("priority")
         immediate = kwargs.get("immediate")
         immediate = kwargs.get("immediate")
         mandatory = kwargs.get("mandatory")
         mandatory = kwargs.get("mandatory")
@@ -55,6 +57,8 @@ class TaskPublisher(Publisher):
             "task": task_name,
             "task": task_name,
             "args": task_args,
             "args": task_args,
             "kwargs": task_kwargs,
             "kwargs": task_kwargs,
+            "countdown": countdown,
+            "eta": eta,
         }
         }
         if part_of_set:
         if part_of_set:
             message_data["taskset"] = part_of_set
             message_data["taskset"] = part_of_set

+ 1 - 96
celery/task/__init__.py

@@ -5,113 +5,18 @@ Working with tasks and task sets.
 """
 """
 from carrot.connection import DjangoAMQPConnection
 from carrot.connection import DjangoAMQPConnection
 from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.conf import AMQP_CONNECTION_TIMEOUT
-from celery.messaging import TaskPublisher
 from celery.registry import tasks
 from celery.registry import tasks
 from celery.backends import default_backend
 from celery.backends import default_backend
-from celery.result import AsyncResult
 from celery.task.base import Task, TaskSet, PeriodicTask
 from celery.task.base import Task, TaskSet, PeriodicTask
 from celery.task.builtins import AsynchronousMapTask, ExecuteRemoteTask
 from celery.task.builtins import AsynchronousMapTask, ExecuteRemoteTask
 from celery.task.builtins import DeleteExpiredTaskMetaTask, PingTask
 from celery.task.builtins import DeleteExpiredTaskMetaTask, PingTask
-from functools import partial as curry
+from celery.execute import apply_async, delay_task
 try:
 try:
     import cPickle as pickle
     import cPickle as pickle
 except ImportError:
 except ImportError:
     import pickle
     import pickle
 
 
 
 
-def apply_async(task, args=None, kwargs=None, routing_key=None,
-        immediate=None, mandatory=None, connection=None,
-        connect_timeout=AMQP_CONNECTION_TIMEOUT, priority=None, **opts):
-    """Run a task asynchronously by the celery daemon(s).
-
-    :param task: The task to run (a callable object, or a :class:`Task`
-        instance
-
-    :param args: The positional arguments to pass on to the task (a ``list``).
-
-    :param kwargs: The keyword arguments to pass on to the task (a ``dict``)
-
-
-    :keyword routing_key: The routing key used to route the task to a worker
-        server.
-
-    :keyword immediate: Request immediate delivery. Will raise an exception
-        if the task cannot be routed to a worker immediately.
-
-    :keyword mandatory: Mandatory routing. Raises an exception if there's
-        no running workers able to take on this task.
-
-    :keyword connection: Re-use existing AMQP connection.
-        The ``connect_timeout`` argument is not respected if this is set.
-
-    :keyword connect_timeout: The timeout in seconds, before we give up
-        on establishing a connection to the AMQP server.
-
-    :keyword priority: The task priority, a number between ``0`` and ``9``.
-
-    """
-    args = args or []
-    kwargs = kwargs or {}
-    routing_key = routing_key or getattr(task, "routing_key", None)
-    immediate = immediate or getattr(task, "immediate", None)
-    mandatory = mandatory or getattr(task, "mandatory", None)
-    priority = priority or getattr(task, "priority", None)
-    taskset_id = opts.get("taskset_id")
-    publisher = opts.get("publisher")
-
-    need_to_close_connection = False
-    if not publisher:
-        if not connection:
-            connection = DjangoAMQPConnection(connect_timeout=connect_timeout)
-            need_to_close_connection = True
-        publisher = TaskPublisher(connection=connection)
-
-    delay_task = publisher.delay_task
-    if taskset_id:
-        delay_task = curry(publisher.delay_task_in_set, taskset_id)
-        
-    task_id = delay_task(task.name, args, kwargs,
-                         routing_key=routing_key, mandatory=mandatory,
-                         immediate=immediate, priority=priority)
-
-    if need_to_close_connection:
-        publisher.close()
-        connection.close()
-
-    return AsyncResult(task_id)
-
-
-def delay_task(task_name, *args, **kwargs):
-    """Delay a task for execution by the ``celery`` daemon.
-
-    :param task_name: the name of a task registered in the task registry.
-
-    :param \*args: positional arguments to pass on to the task.
-
-    :param \*\*kwargs: keyword arguments to pass on to the task.
-
-    :raises celery.registry.NotRegistered: exception if no such task
-        has been registered in the task registry.
-
-    :rtype: :class:`celery.result.AsyncResult`.
-
-    Example
-
-        >>> r = delay_task("update_record", name="George Constanza", age=32)
-        >>> r.ready()
-        True
-        >>> r.result
-        "Record was updated"
-
-    """
-    if task_name not in tasks:
-        raise tasks.NotRegistered(
-                "Task with name %s not registered in the task registry." % (
-                    task_name))
-    task = tasks[task_name]
-    return apply_async(task, args, kwargs)
-
-
 def discard_all(connect_timeout=AMQP_CONNECTION_TIMEOUT):
 def discard_all(connect_timeout=AMQP_CONNECTION_TIMEOUT):
     """Discard all waiting tasks.
     """Discard all waiting tasks.
 
 

+ 1 - 0
celery/task/base.py

@@ -3,6 +3,7 @@ from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.messaging import TaskPublisher, TaskConsumer
 from celery.messaging import TaskPublisher, TaskConsumer
 from celery.log import setup_logger
 from celery.log import setup_logger
 from celery.result import TaskSetResult
 from celery.result import TaskSetResult
+from celery.execute import apply_async, delay_task
 from datetime import timedelta
 from datetime import timedelta
 import uuid
 import uuid
 try:
 try:

+ 41 - 2
celery/worker.py

@@ -11,6 +11,10 @@ from celery.backends import default_backend, default_periodic_status_backend
 from celery.timer import EventTimer
 from celery.timer import EventTimer
 from django.core.mail import mail_admins
 from django.core.mail import mail_admins
 from celery.monitoring import TaskTimerStats
 from celery.monitoring import TaskTimerStats
+from datetime import datetime, timedelta
+from Queue import Queue
+from Queue import Empty as QueueEmpty
+from multiprocessing import TimeoutError
 import multiprocessing
 import multiprocessing
 import traceback
 import traceback
 import threading
 import threading
@@ -345,6 +349,7 @@ class WorkController(object):
         self.is_detached = is_detached
         self.is_detached = is_detached
         self.amqp_connection = None
         self.amqp_connection = None
         self.task_consumer = None
         self.task_consumer = None
+        self.bucket_queue = Queue()
 
 
     def close_connection(self):
     def close_connection(self):
         """Close the AMQP connection."""
         """Close the AMQP connection."""
@@ -395,6 +400,15 @@ class WorkController(object):
 
 
     def process_task(self, message_data, message):
     def process_task(self, message_data, message):
         """Process task message by passing it to the pool of workers."""
         """Process task message by passing it to the pool of workers."""
+        
+        countdown = message_data.get("countdown")
+        eta = message_data.get("eta")
+        if countdown:
+            eta = datetime.now() + timedelta(seconds=int(countdown))
+        if eta:
+            self.bucket_queue.put((message, message_data, eta))
+            return
+
         task = TaskWrapper.from_message(message, message_data,
         task = TaskWrapper.from_message(message, message_data,
                                         logger=self.logger)
                                         logger=self.logger)
         self.logger.info("Got task from broker: %s[%s]" % (
         self.logger.info("Got task from broker: %s[%s]" % (
@@ -406,7 +420,7 @@ class WorkController(object):
 
 
         self.logger.debug("Task %s has been executed asynchronously." % task)
         self.logger.debug("Task %s has been executed asynchronously." % task)
 
 
-        return result
+        return
 
 
     def shutdown(self):
     def shutdown(self):
         """Make sure ``celeryd`` exits cleanly."""
         """Make sure ``celeryd`` exits cleanly."""
@@ -437,6 +451,31 @@ class WorkController(object):
 
 
         try:
         try:
             while True:
             while True:
-                it.next()
+                try:
+                    self.process_bucket()
+                    self.process_next(it, timeout=1)
+                except TimeoutError:
+                    pass
         except (SystemExit, KeyboardInterrupt):
         except (SystemExit, KeyboardInterrupt):
             self.shutdown()
             self.shutdown()
+
+    def process_next(self, it, timeout=1):
+        """Advance ``it`` by one step while a ``timeout`` second timer runs.
+
+        :param it: an object with a ``next()`` method; it is called once.
+
+        :keyword timeout: seconds after which the timer calls ``on_timeout``.
+
+        The timer is always cancelled once ``it.next()`` returns or raises.
+
+        NOTE(review): :class:`threading.Timer` runs its callback in the
+        timer's own thread, so the ``TimeoutError`` raised by ``on_timeout``
+        is raised in that thread and does not interrupt a blocking
+        ``it.next()`` in the calling thread. Confirm whether the caller's
+        ``except TimeoutError`` can ever be reached through this path.
+
+        """
+        def on_timeout():
+            # Runs in the timer thread, not in the caller of process_next.
+            raise TimeoutError()
+        timer = threading.Timer(timeout, on_timeout)
+        timer.start()
+        try:
+            it.next()
+        finally:
+            # Disarm the timer whether next() returned or raised.
+            timer.cancel()
+
+    def process_bucket(self):
+        """Dispatch one delayed task from the bucket queue if it is due.
+
+        Takes at most one ``(message, message_data, eta)`` entry off
+        ``self.bucket_queue`` without blocking. An entry whose ``eta`` has
+        not been reached yet is put back on the queue; a due entry is
+        handed to :meth:`process_task`. Does nothing when the queue is
+        empty.
+
+        """
+        try:
+            message, msg_data, eta = self.bucket_queue.get_nowait()
+        except QueueEmpty:
+            return
+
+        if datetime.now() < eta:
+            # Not due yet, keep it waiting.
+            self.bucket_queue.put((message, msg_data, eta))
+            return
+
+        # process_task() puts any message whose data still carries
+        # "countdown" or "eta" back in the bucket, so drop those keys from
+        # a copy before dispatching; otherwise a due task would be requeued
+        # forever instead of being executed.
+        due_data = dict(msg_data)
+        due_data.pop("countdown", None)
+        due_data.pop("eta", None)
+
+        # Signature is process_task(message_data, message).
+        self.process_task(due_data, message)