Переглянути джерело

Now uses inspect.getargspec to only pass on the default kwargs the task supports.

Ask Solem 15 роки тому
батько
коміт
6374355a1a
4 змінених файлів з 97 додано та 17 видалено
  1. 26 13
      celery/task/base.py
  2. 33 0
      celery/tests/test_worker_job.py
  3. 28 0
      celery/utils.py
  4. 10 4
      celery/worker/job.py

+ 26 - 13
celery/task/base.py

@@ -16,9 +16,9 @@ class Task(object):
 
     All subclasses of :class:`Task` must define the :meth:`run` method,
     which is the actual method the ``celery`` daemon executes.
-    The :meth:`run` method must always take the positional keyword arguments
-    (\*\*kwargs), this is because of the standard arguments always passed to
-    a task (see :meth:`run` for more info)
+
+    The :meth:`run` method can take use of the default keyword arguments,
+    as listed in the :meth:`run` documentation.
 
     The :meth:`run` method supports both positional, and keyword arguments.
 
@@ -142,7 +142,8 @@ class Task(object):
     def run(self, *args, **kwargs):
         """The body of the task executed by the worker.
 
-        The following standard keyword arguments is passed by the worker:
+        The following standard keyword arguments are reserved and is passed
+        by the worker if the function/method supports them:
 
             * task_id
 
@@ -152,6 +153,11 @@ class Task(object):
 
                 Name of the currently executing task (same as :attr:`name`)
 
+            * task_retries
+
+                How many times the current task has been retried
+                (an integer starting at ``0``).
+
             * logfile
 
                 Name of the worker log file.
@@ -163,24 +169,29 @@ class Task(object):
                 ``logging.ERROR``, ``logging.CRITICAL``, ``logging.WARNING``,
                 ``logging.FATAL``.
 
-        Additional standard keyword arguments may be added in the future,
-        so the :meth:`run` method must always take an arbitrary list of
-        keyword arguments (\*\*kwargs).
+        Additional standard keyword arguments may be added in the future.
+        To take these default arguments, the task can either list the ones
+        it wants explicitly or just take an arbitrary list of keyword
+        arguments (\*\*kwargs).
 
-        Example:
+        Example using an explicit list of default arguments to take:
 
         .. code-block:: python
 
-            def run(self, x, y): # WRONG!
+            def run(self, x, y, logfile=None, loglevel=None):
+                self.get_logger(loglevel=loglevel, logfile=logfile)
                 return x * y
 
-        Will fail with an exception because the worker can't send the default
-        arguments. The correct way to define the run method would be:
+
+        Example taking all default keyword arguments, and any extra arguments
+        passed on by the caller:
 
         .. code-block:: python
 
             def run(self, x, y, **kwargs): # CORRECT!
-                return x * y
+                logger = self.get_logger(**kwargs)
+                adjust = kwargs.get("adjust", 0)
+                return x * y - adjust
 
         """
         raise NotImplementedError("Tasks must define a run method.")
@@ -191,7 +202,9 @@ class Task(object):
         See :func:`celery.log.setup_logger`.
 
         """
-        return setup_logger(**kwargs)
+        logfile = kwargs.get("logfile")
+        loglevel = kwargs.get("loglevel")
+        return setup_logger(loglevel=loglevel, logfile=logfile)
 
     def get_publisher(self, connect_timeout=AMQP_CONNECTION_TIMEOUT):
         """Get a celery task message publisher.

+ 33 - 0
celery/tests/test_worker_job.py

@@ -31,6 +31,19 @@ def mytask(i, **kwargs):
 tasks.register(mytask, name="cu.mytask")
 
 
+def mytask_no_kwargs(i):
+    return i ** i
+tasks.register(mytask_no_kwargs, name="mytask_no_kwargs")
+
+
+some_kwargs_scratchpad = {}
+
+def mytask_some_kwargs(i, logfile):
+    some_kwargs_scratchpad["logfile"] = logfile
+    return i ** i
+tasks.register(mytask_some_kwargs, name="mytask_some_kwargs")
+
+
 def mytask_raising(i, **kwargs):
     raise KeyError(i)
 tasks.register(mytask_raising, name="cu.mytask-raising")
@@ -48,6 +61,7 @@ class TestJail(unittest.TestCase):
         ret = jail(gen_unique_id(), gen_unique_id(), mytask, [2], {})
         self.assertEquals(ret, 4)
 
+
     def test_execute_jail_failure(self):
         ret = jail(gen_unique_id(), gen_unique_id(), mytask_raising, [4], {})
         self.assertTrue(isinstance(ret, ExceptionInfo))
@@ -177,6 +191,25 @@ class TestTaskWrapper(unittest.TestCase):
         meta = TaskMeta.objects.get(task_id=tid)
         self.assertEquals(meta.result, 256)
         self.assertEquals(meta.status, "DONE")
+    
+    def test_execute_success_no_kwargs(self):
+        tid = gen_unique_id()
+        tw = TaskWrapper("cu.mytask_no_kwargs", tid, mytask_no_kwargs,
+                         [4], {})
+        self.assertEquals(tw.execute(), 256)
+        meta = TaskMeta.objects.get(task_id=tid)
+        self.assertEquals(meta.result, 256)
+        self.assertEquals(meta.status, "DONE")
+    
+    def test_execute_success_some_kwargs(self):
+        tid = gen_unique_id()
+        tw = TaskWrapper("cu.mytask_some_kwargs", tid, mytask_some_kwargs,
+                         [4], {})
+        self.assertEquals(tw.execute(logfile="foobaz.log"), 256)
+        meta = TaskMeta.objects.get(task_id=tid)
+        self.assertEquals(some_kwargs_scratchpad.get("logfile"), "foobaz.log")
+        self.assertEquals(meta.result, 256)
+        self.assertEquals(meta.status, "DONE")
 
     def test_execute_ack(self):
         tid = gen_unique_id()

+ 28 - 0
celery/utils.py

@@ -5,8 +5,11 @@ Utility functions
 """
 import time
 from itertools import repeat
+from inspect import getargspec
+from functools import partial as curry
 from uuid import UUID, uuid4, _uuid_generate_random
 import ctypes
+import operator
 
 noop = lambda *args, **kwargs: None
 
@@ -112,3 +115,28 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=noop,
             time.sleep(interval)
         else:
             return retval
+
+
+def fun_takes_kwargs(fun, kwlist=[]):
+    """With a function, and a list of keyword arguments, returns arguments
+    in the list which the function takes.
+
+    :param fun: The function to inspect arguments of.
+    :param kwlist: The list of keyword arguments.
+
+    Examples
+
+        >>> def foo(self, x, y, logfile=None, loglevel=None):
+        ...     return x * y
+        >>> fun_takes_kwargs(foo, ["logfile", "loglevel", "task_id"])
+        ["logfile", "loglevel"]
+
+        >>> def foo(self, x, y, **kwargs):
+        >>> fun_takes_kwargs(foo, ["logfile", "loglevel", "task_id"])
+        ["logfile", "loglevel", "task_id"]
+
+    """
+    args, _varargs, keywords, _defaults = getargspec(fun)
+    if keywords != None:
+        return kwlist
+    return filter(curry(operator.contains, args), kwlist)

+ 10 - 4
celery/worker/job.py

@@ -6,7 +6,7 @@ Jobs Executable by the Worker Server.
 from celery.registry import tasks
 from celery.exceptions import NotRegistered
 from celery.execute import ExecuteWrapper
-from celery.utils import noop
+from celery.utils import noop, fun_takes_kwargs
 from django.core.mail import mail_admins
 import multiprocessing
 import socket
@@ -130,16 +130,22 @@ class TaskWrapper(object):
     def extend_with_default_kwargs(self, loglevel, logfile):
         """Extend the tasks keyword arguments with standard task arguments.
 
-        These are ``logfile``, ``loglevel``, ``task_id`` and ``task_name``.
+        Currently these are ``logfile``, ``loglevel``, ``task_id``,
+        ``task_name`` and ``task_retries``.
+
+        See :meth:`celery.task.base.Task.run` for more information.
 
         """
         kwargs = dict(self.kwargs)
-        task_func_kwargs = {"logfile": logfile,
+        default_kwargs = {"logfile": logfile,
                             "loglevel": loglevel,
                             "task_id": self.task_id,
                             "task_name": self.task_name,
                             "task_retries": self.retries}
-        kwargs.update(task_func_kwargs)
+        supported_keys = fun_takes_kwargs(self.task_func, default_kwargs)
+        extend_with = dict((key, val) for key, val in default_kwargs.items()
+                                if key in supported_keys)
+        kwargs.update(extend_with)
         return kwargs
 
     def _executeable(self, loglevel=None, logfile=None):