Browse Source

Added the ability to set an expiry date and time for tasks.

Example:

    # Task expires after one minute from now.
    task.apply_async(args, kwargs,
                     expires=datetime.now() + timedelta(minutes=1)
Ask Solem 14 years ago
parent
commit
07c589fa27

+ 13 - 7
celery/execute/__init__.py

@@ -17,7 +17,7 @@ extract_exec_options = mattrgetter("queue", "routing_key", "exchange",
 @with_connection
 def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
         task_id=None, publisher=None, connection=None, connect_timeout=None,
-        router=None, **options):
+        router=None, expires=None, **options):
     """Run a task asynchronously by the celery daemon(s).
 
     :param task: The :class:`~celery.task.base.Task` to run.
@@ -33,9 +33,13 @@ def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
       the ``immediate`` setting, they are unrelated).
 
     :keyword eta: A :class:`~datetime.datetime` object that describes the
-      absolute time when the task should execute. May not be specified
-      if ``countdown`` is also supplied. (Do not confuse this with the
-      ``immediate`` setting, they are unrelated).
+      absolute time and date of when the task should execute. May not be
+      specified if ``countdown`` is also supplied. (Do not confuse this
+      with the ``immediate`` setting, they are unrelated).
+
+    :keyword expires: A :class:`~datetime.datetime` object that describes
+      the absolute time and date of when the task should expire.
+      The task will not be executed after the expiration time.
 
     :keyword connection: Re-use existing broker connection instead
       of establishing a new one. The ``connect_timeout`` argument is
@@ -96,7 +100,8 @@ def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
                                               exchange_type=exchange_type)
     try:
         task_id = publish.delay_task(task.name, args, kwargs, task_id=task_id,
-                                     countdown=countdown, eta=eta, **options)
+                                     countdown=countdown, eta=eta,
+                                     expires=expires, **options)
     finally:
         publisher or publish.close()
 
@@ -106,7 +111,7 @@ def apply_async(task, args=None, kwargs=None, countdown=None, eta=None,
 @with_connection
 def send_task(name, args=None, kwargs=None, countdown=None, eta=None,
         task_id=None, publisher=None, connection=None, connect_timeout=None,
-        result_cls=AsyncResult, **options):
+        result_cls=AsyncResult, expires=None, **options):
     """Send task by name.
 
     Useful if you don't have access to the :class:`~celery.task.base.Task`
@@ -124,7 +129,8 @@ def send_task(name, args=None, kwargs=None, countdown=None, eta=None,
                                          exchange_type=exchange_type)
     try:
         task_id = publish.delay_task(name, args, kwargs, task_id=task_id,
-                                     countdown=countdown, eta=eta, **options)
+                                     countdown=countdown, eta=eta,
+                                     expires=expires, **options)
     finally:
         publisher or publish.close()
 

+ 3 - 1
celery/messaging.py

@@ -52,7 +52,8 @@ class TaskPublisher(Publisher):
             _exchanges_declared.add(self.exchange)
 
     def delay_task(self, task_name, task_args=None, task_kwargs=None,
-            countdown=None, eta=None, task_id=None, taskset_id=None, **kwargs):
+            countdown=None, eta=None, task_id=None, taskset_id=None,
+            expires=None, **kwargs):
         """Delay task for execution by the celery nodes."""
 
         task_id = task_id or gen_unique_id()
@@ -73,6 +74,7 @@ class TaskPublisher(Publisher):
             "kwargs": task_kwargs or {},
             "retries": kwargs.get("retries", 0),
             "eta": eta and eta.isoformat(),
+            "expires": expires and expires.isoformat(),
         }
 
         if taskset_id:

+ 1 - 2
celery/tests/test_task.py

@@ -10,14 +10,13 @@ from celery import task
 from celery import messaging
 from celery.task.schedules import crontab, crontab_parser
 from celery.utils import timeutils
-from celery.utils import gen_unique_id
+from celery.utils import gen_unique_id, parse_iso8601
 from celery.utils.functional import wraps
 from celery.result import EagerResult
 from celery.execute import send_task
 from celery.backends import default_backend
 from celery.decorators import task as task_dec
 from celery.exceptions import RetryTaskError
-from celery.worker.listener import parse_iso8601
 
 from celery.tests.utils import with_eager_tasks
 

+ 11 - 0
celery/utils/__init__.py

@@ -7,11 +7,13 @@ try:
 except ImportError:
     ctypes = None
 import importlib
+from datetime import datetime
 from uuid import UUID, uuid4, _uuid_generate_random
 from inspect import getargspec
 from itertools import islice
 
 from carrot.utils import rpartition
+from dateutil.parser import parse as parse_iso8601
 
 from celery.utils.compat import all, any, defaultdict
 from celery.utils.timeutils import timedelta_seconds # was here before
@@ -98,6 +100,15 @@ def noop(*args, **kwargs):
     pass
 
 
+def maybe_iso8601(dt):
+    """``Either datetime | str -> datetime or None -> None``"""
+    if not dt:
+        return
+    if isinstance(dt, datetime):
+        return dt
+    return parse_iso8601(dt)
+
+
 def kwdict(kwargs):
     """Make sure keyword arguments are not in unicode.
 

+ 24 - 3
celery/worker/job.py

@@ -3,6 +3,8 @@ import time
 import socket
 import warnings
 
+from datetime import datetime
+
 from celery import conf
 from celery import log
 from celery import platform
@@ -10,7 +12,7 @@ from celery.datastructures import ExceptionInfo
 from celery.execute.trace import TaskTrace
 from celery.loaders import current_loader
 from celery.registry import tasks
-from celery.utils import noop, kwdict, fun_takes_kwargs
+from celery.utils import noop, kwdict, fun_takes_kwargs, maybe_iso8601
 from celery.utils.compat import any
 from celery.utils.mail import mail_admins
 from celery.worker import state
@@ -208,12 +210,14 @@ class TaskRequest(object):
     def __init__(self, task_name, task_id, args, kwargs,
             on_ack=noop, retries=0, delivery_info=None, hostname=None,
             email_subject=None, email_body=None, logger=None,
-            eventer=None, **opts):
+            eventer=None, eta=None, expires=None, **opts):
         self.task_name = task_name
         self.task_id = task_id
         self.retries = retries
         self.args = args
         self.kwargs = kwargs
+        self.eta = eta
+        self.expires = expires
         self.on_ack = on_ack
         self.delivery_info = delivery_info or {}
         self.hostname = hostname or socket.gethostname()
@@ -224,9 +228,16 @@ class TaskRequest(object):
 
         self.task = tasks[self.task_name]
 
+    def maybe_expire(self):
+        if self.expires and datetime.now() > self.expires:
+            state.revoked.add(self.task_id)
+            self.task.backend.mark_as_revoked(self.task_id)
+
     def revoked(self):
         if self._already_revoked:
             return True
+        if self.expires:
+            self.maybe_expire()
         if self.task_id in state.revoked:
             self.logger.warn("Skipping revoked task: %s[%s]" % (
                 self.task_name, self.task_id))
@@ -253,6 +264,8 @@ class TaskRequest(object):
         args = message_data["args"]
         kwargs = message_data["kwargs"]
         retries = message_data.get("retries", 0)
+        eta = maybe_iso8601(message_data.get("eta"))
+        expires = maybe_iso8601(message_data.get("expires"))
 
         _delivery_info = getattr(message, "delivery_info", {})
         delivery_info = dict((key, _delivery_info.get(key))
@@ -265,7 +278,8 @@ class TaskRequest(object):
         return cls(task_name, task_id, args, kwdict(kwargs),
                    retries=retries, on_ack=message.ack,
                    delivery_info=delivery_info, logger=logger,
-                   eventer=eventer, hostname=hostname)
+                   eventer=eventer, hostname=hostname,
+                   eta=eta, expires=expires)
 
     def extend_with_default_kwargs(self, loglevel, logfile):
         """Extend the tasks keyword arguments with standard task arguments.
@@ -445,3 +459,10 @@ class TaskRequest(object):
                 "time_start": self.time_start,
                 "acknowledged": self.acknowledged,
                 "delivery_info": self.delivery_info}
+
+    def shortinfo(self):
+        return "%s[%s]%s%s" % (
+                    self.task_name,
+                    self.task_id,
+                    self.eta and " eta:[%s]" % (self.eta, ),
+                    self.expires and " expires:[%s]" % (self.expires, ))

+ 8 - 13
celery/worker/listener.py

@@ -80,11 +80,10 @@ import warnings
 
 from datetime import datetime
 
-from dateutil.parser import parse as parse_iso8601
 from carrot.connection import AMQPConnectionException
 
 from celery import conf
-from celery.utils import noop, retry_over_time
+from celery.utils import noop, retry_over_time, maybe_iso8601
 from celery.worker.job import TaskRequest, InvalidTaskError
 from celery.worker.control import ControlDispatch
 from celery.worker.heartbeat import Heart
@@ -249,7 +248,7 @@ class CarrotListener(object):
                 self.qos.update()
             wait_for_message()
 
-    def on_task(self, task, eta=None):
+    def on_task(self, task):
         """Handle received task.
 
         If the task has an ``eta`` we enter it into the ETA schedule,
@@ -260,21 +259,17 @@ class CarrotListener(object):
         if task.revoked():
             return
 
+        self.logger.info("Got task from broker: %s" % (task.shortinfo(), ))
+
         self.event_dispatcher.send("task-received", uuid=task.task_id,
                 name=task.task_name, args=repr(task.args),
-                kwargs=repr(task.kwargs), retries=task.retries, eta=eta)
+                kwargs=repr(task.kwargs), retries=task.retries, eta=task.eta)
 
-        if eta:
-            if not isinstance(eta, datetime):
-                eta = parse_iso8601(eta)
+        if task.eta:
             self.qos.increment()
-            self.logger.info("Got task from broker: %s[%s] eta:[%s]" % (
-                    task.task_name, task.task_id, eta))
-            self.eta_schedule.enter(task, eta=eta,
+            self.eta_schedule.enter(task, eta=task.eta,
                     callback=self.qos.decrement_eventually)
         else:
-            self.logger.info("Got task from broker: %s[%s]" % (
-                    task.task_name, task.task_id))
             self.ready_queue.put(task)
 
     def on_control(self, control):
@@ -300,7 +295,7 @@ class CarrotListener(object):
                         str(exc), message_data))
                 message.ack()
             else:
-                self.on_task(task, eta=message_data.get("eta"))
+                self.on_task(task)
             return
 
         # Handle control command

+ 8 - 0
docs/internals/protocol.rst

@@ -42,6 +42,14 @@ Message format
     format. If not provided the message is not scheduled, but will be
     executed asap.
 
+* expires (introduced after v2.0.2)
+    ``string`` (ISO 8601)
+
+    Expiration date. This is the date and time in ISO 8601 format.
+    If not provided the message will never expire. The message
+    will be expired when the message is received and the expiration date
+    has been exceeded.
+
 Example message
 ===============