Ask Solem, 13 years ago
parent
commit
2438c1f64d

+ 1 - 1
celery/app/task/__init__.py

@@ -731,7 +731,7 @@ class BaseTask(object):
     def execute(self, request, pool, loglevel, logfile, **kwargs):
         """The method the worker calls to execute the task.
 
-        :param request: A :class:`~celery.worker.job.TaskRequest`.
+        :param request: A :class:`~celery.worker.job.Request`.
         :param pool: A task pool.
         :param loglevel: Current loglevel.
         :param logfile: Name of the currently used logfile.

+ 1 - 1
celery/concurrency/solo.py

@@ -7,7 +7,7 @@ from .base import BasePool, apply_target
 
 
 class TaskPool(BasePool):
-    """Solo task pool (blocking, inline)."""
+    """Solo task pool (blocking, inline, fast)."""
 
     def __init__(self, *args, **kwargs):
         super(TaskPool, self).__init__(*args, **kwargs)

+ 7 - 7
celery/tests/test_worker/__init__.py

@@ -20,7 +20,7 @@ from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.worker import WorkController
 from celery.worker.buckets import FastQueue
-from celery.worker.job import TaskRequest
+from celery.worker.job import Request
 from celery.worker.consumer import Consumer as MainConsumer
 from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
 from celery.utils.serialization import pickle
@@ -344,7 +344,7 @@ class test_Consumer(unittest.TestCase):
         l.receive_message(m.decode(), m)
 
         in_bucket = self.ready_queue.get_nowait()
-        self.assertIsInstance(in_bucket, TaskRequest)
+        self.assertIsInstance(in_bucket, Request)
         self.assertEqual(in_bucket.task_name, foo_task.name)
         self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
         self.assertTrue(self.eta_schedule.empty())
@@ -572,7 +572,7 @@ class test_Consumer(unittest.TestCase):
         self.assertEqual(len(in_hold), 3)
         eta, priority, entry = in_hold
         task = entry.args[0]
-        self.assertIsInstance(task, TaskRequest)
+        self.assertIsInstance(task, Request)
         self.assertEqual(task.task_name, foo_task.name)
         self.assertEqual(task.execute(), 2 * 4 * 8)
         with self.assertRaises(Empty):
@@ -815,7 +815,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = TaskRequest.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode())
         worker.process_task(task)
         self.assertEqual(worker.pool.apply_async.call_count, 1)
         worker.pool.stop()
@@ -827,7 +827,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = TaskRequest.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode())
         worker.components = []
         worker._state = worker.RUN
         with self.assertRaises(KeyboardInterrupt):
@@ -841,7 +841,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = TaskRequest.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode())
         worker.components = []
         worker._state = worker.RUN
         with self.assertRaises(SystemExit):
@@ -855,7 +855,7 @@ class test_WorkController(AppCase):
         backend = Mock()
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
-        task = TaskRequest.from_message(m, m.decode())
+        task = Request.from_message(m, m.decode())
         worker.process_task(task)
         worker.pool.stop()
 

+ 1 - 1
celery/tests/test_worker/test_worker_job.py

@@ -29,7 +29,7 @@ from celery.task import task as task_dec
 from celery.task.base import Task
 from celery.utils import uuid
 from celery.utils.encoding import from_utf8, default_encode
-from celery.worker.job import TaskRequest, execute_and_trace
+from celery.worker.job import Request, execute_and_trace
 from celery.worker.state import revoked
 
 from celery.tests.compat import catch_warnings

+ 1 - 1
celery/worker/buckets.py

@@ -70,7 +70,7 @@ class TaskBucket(object):
         self.not_empty = threading.Condition(self.mutex)
 
     def put(self, request):
-        """Put a :class:`~celery.worker.job.TaskRequest` into
+        """Put a :class:`~celery.worker.job.Request` into
         the appropriate bucket."""
         with self.mutex:
             if request.task_name not in self.buckets:

+ 3 - 2
celery/worker/consumer.py

@@ -34,7 +34,7 @@ up and running.
   a `task` key or a `control` key.
 
   If the message is a task, it verifies the validity of the message
-  converts it to a :class:`celery.worker.job.TaskRequest`, and sends
+  converts it to a :class:`celery.worker.job.Request`, and sends
   it to :meth:`~Consumer.on_task`.
 
   If the message is a control command the message is passed to
@@ -356,7 +356,8 @@ class Consumer(object):
         if self.event_dispatcher.enabled:
             self.event_dispatcher.send("task-received", uuid=task.task_id,
                     name=task.task_name, args=safe_repr(task.args),
-                    kwargs=safe_repr(task.kwargs), retries=task.retries,
+                    kwargs=safe_repr(task.kwargs),
+                    retries=task.request_dict.get("retries", 0),
                     eta=task.eta and task.eta.isoformat(),
                     expires=task.expires and task.expires.isoformat())
 

+ 84 - 138
celery/worker/job.py

@@ -3,7 +3,7 @@
     celery.worker.job
     ~~~~~~~~~~~~~~~~~
 
-    This module defines the :class:`TaskRequest` class,
+    This module defines the :class:`Request` class,
     which specifies how tasks are executed.
 
     :copyright: (c) 2009 - 2011 by Ask Solem.
@@ -15,6 +15,7 @@ from __future__ import absolute_import
 import logging
 import time
 import socket
+import sys
 
 from datetime import datetime
 
@@ -24,8 +25,7 @@ from ..registry import tasks
 from ..app import app_or_default
 from ..execute.trace import build_tracer, trace_task, report_internal_error
 from ..platforms import set_mp_process_title as setps
-from ..utils import (noop, kwdict, fun_takes_kwargs,
-                     cached_property, truncate_text)
+from ..utils import noop, kwdict, fun_takes_kwargs, truncate_text
 from ..utils.encoding import safe_repr, safe_str
 from ..utils.timeutils import maybe_iso8601, timezone
 
@@ -36,6 +36,8 @@ tz_to_local = timezone.to_local
 tz_or_local = timezone.tz_or_local
 tz_utc = timezone.utc
 
+NEEDS_KWDICT = sys.version_info <= (2, 6)
+
 
 def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
     """This is a pickleable method used as a target when applying to pools.
@@ -59,51 +61,16 @@ def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
         return report_internal_error(task, exc)
 
 
-class TaskRequest(object):
+class Request(object):
     """A request for task execution."""
-
-    #: Kind of task.  Must be a name registered in the task registry.
-    name = None
-
-    #: The task class (set by constructor using :attr:`name`).
-    task = None
-
-    #: UUID of the task.
-    id = None
-
-    #: UUID of the taskset that this task belongs to.
-    taskset = None
-
-    #: List of positional arguments to apply to the task.
-    args = None
-
-    #: Mapping of keyword arguments to apply to the task.
-    kwargs = None
-
-    #: Number of times the task has been retried.
-    retries = 0
-
-    #: The tasks eta (for information only).
-    eta = None
-
-    #: When the task expires.
-    expires = None
-
-    #: Body of a chord depending on this task.
-    chord = None
-
-    #: Callback called when the task should be acknowledged.
-    on_ack = None
-
-    #: The message object.  Used to acknowledge the message.
-    message = None
-
-    #: Additional delivery info, e.g. contains the path from
-    #: Producer to consumer.
-    delivery_info = None
-
-    #: Flag set when the task has been acknowledged.
-    acknowledged = False
+    # NOTE: success_msg/error_msg/retry_msg stay *class* attributes below;
+    # listing a name both in __slots__ and as a class variable raises
+    # ValueError at class-creation time, so they must not appear here.
+    __slots__ = ("app", "name", "id", "args", "kwargs",
+                 "on_ack", "delivery_info", "hostname",
+                 "logger", "eventer", "connection_errors",
+                 "task", "eta", "expires",
+                 "_does_debug", "_does_info", "request_dict",
+                 "acknowledged", "time_start", "worker_pid",
+                 "_already_revoked", "_terminate_on_ack", "_tzinfo")
 
     #: Format string used to log task success.
     success_msg = """\
@@ -118,94 +85,63 @@ class TaskRequest(object):
     #: Format string used to log task retry.
     retry_msg = """Task %(name)s[%(id)s] retry: %(exc)s"""
 
-    #: Timestamp set when the task is started.
-    time_start = None
-
-    #: Process id of the worker processing this task (if any).
-    worker_pid = None
 
-    _already_revoked = False
-    _terminate_on_ack = None
-
-    def __init__(self, task, id, args=[], kwargs={},
-            on_ack=noop, retries=0, delivery_info={}, hostname=None,
-            logger=None, eventer=None, eta=None, expires=None, app=None,
-            taskset=None, chord=None, utc=False, connection_errors=None,
-            **opts):
-        try:
-            kwargs.items
-        except AttributeError:
-            raise exceptions.InvalidTaskError(
-                    "Task keyword arguments is not a mapping")
+    def __init__(self, body, on_ack=noop,
+            hostname=None, logger=None, eventer=None, app=None,
+            connection_errors=None, request_dict=None,
+            delivery_info=None, task=None, **opts):
         self.app = app or app_or_default(app)
-        self.name = task
-        self.id = id
-        self.args = args
-        self.kwargs = kwdict(kwargs)
-        self.taskset = taskset
-        self.retries = retries
-        self.chord = chord
+        name = self.name = body["task"]
+        self.id = body["id"]
+        self.args = body.get("args", [])
+        self.kwargs = body.get("kwargs", {})
+        # `utc` is read below when normalizing eta/expires; the old
+        # signature took it as a keyword argument, the new message-body
+        # dict carries it -- without this binding it is a NameError.
+        utc = body.get("utc", False)
+        eta = body.get("eta")
+        expires = body.get("expires")
         self.on_ack = on_ack
-        self.delivery_info = {} if delivery_info is None else delivery_info
         self.hostname = hostname or socket.gethostname()
         self.logger = logger or self.app.log.get_default_logger()
         self.eventer = eventer
         self.connection_errors = connection_errors or ()
-
-        task = self.task = tasks[task]
+        self.task = task or tasks[name]
+        self.acknowledged = self._already_revoked = False
+        self.time_start = self.worker_pid = self._terminate_on_ack = None
+        self._tzinfo = None
 
         # timezone means the message is timezone-aware, and the only timezone
         # supported at this point is UTC.
         if eta is not None:
             tz = tz_utc if utc else self.tzlocal
             self.eta = tz_to_local(maybe_iso8601(eta), self.tzlocal, tz)
+        else:
+            self.eta = None
         if expires is not None:
             tz = tz_utc if utc else self.tzlocal
             self.expires = tz_to_local(maybe_iso8601(expires),
                                        self.tzlocal, tz)
+        else:
+            self.expires = None
 
-        # shortcuts
+        # delivery_info defaults to None in the signature; guard before
+        # calling .get() so an omitted value does not raise AttributeError.
+        delivery_info = delivery_info or {}
+        self.delivery_info = {
+            "exchange": delivery_info.get("exchange"),
+            "routing_key": delivery_info.get("routing_key"),
+        }
+
+        ## shortcuts
         self._does_debug = self.logger.isEnabledFor(logging.DEBUG)
         self._does_info = self.logger.isEnabledFor(logging.INFO)
 
-        self.request_dict = {"hostname": self.hostname,
-                             "id": id, "taskset": taskset,
-                             "retries": retries, "is_eager": False,
-                             "delivery_info": delivery_info, "chord": chord}
-
-    @classmethod
-    def from_message(cls, message, body, on_ack=noop, delivery_info={},
-            logger=None, hostname=None, eventer=None, app=None,
-            connection_errors=None):
-        """Create request from a task message.
+        self.request_dict = body
 
-        :raises UnknownTaskError: if the message does not describe a task,
-            the message is also rejected.
-
-        """
         try:
-            D = message.delivery_info
-            delivery_info = {"exchange": D.get("exchange"),
-                             "routing_key": D.get("routing_key")}
-        except (AttributeError, KeyError):
-            pass
+            self.kwargs.items
+        except AttributeError:
+            raise exceptions.InvalidTaskError(
+                    "Task keyword arguments is not a mapping")
 
-        try:
-            return cls(on_ack=on_ack, logger=logger, eventer=eventer, app=app,
-                       delivery_info=delivery_info, hostname=hostname,
-                       connection_errors=connection_errors, **body)
-        except TypeError:
-            for f in ("task", "id"):
-                if f not in body:
-                    raise exceptions.InvalidTaskError(
-                        "Task message is missing required field %r" % (f, ))
-
-    def get_instance_attrs(self, loglevel, logfile):
-        return {"logfile": logfile, "loglevel": loglevel,
-                "hostname": self.hostname,
-                "id": self.id, "taskset": self.taskset,
-                "retries": self.retries, "is_eager": False,
-                "delivery_info": self.delivery_info, "chord": self.chord}
+    @classmethod
+    def from_message(cls, message, body, **kwargs):
+        # should be deprecated
+        return cls(body, delivery_info=message.delivery_info, **kwargs)
 
     def extend_with_default_kwargs(self, loglevel, logfile):
         """Extend the tasks keyword arguments with standard task arguments.
@@ -224,7 +160,7 @@ class TaskRequest(object):
                           "loglevel": loglevel,
                           "task_id": self.id,
                           "task_name": self.name,
-                          "task_retries": self.retries,
+                          "task_retries": self.request_dict["retries"],
                           "task_is_eager": False,
                           "delivery_info": self.delivery_info}
         fun = self.task.run
@@ -235,7 +171,7 @@ class TaskRequest(object):
         return kwargs
 
     def execute_using_pool(self, pool, loglevel=None, logfile=None):
-        """Like :meth:`execute`, but using the :mod:`multiprocessing` pool.
+        """Like :meth:`execute`, but using a worker pool.
 
         :param pool: A :class:`multiprocessing.Pool` instance.
 
@@ -246,22 +182,26 @@ class TaskRequest(object):
         """
         if self.revoked():
             return
-        request = self.request_dict
 
+        task = self.task
+        hostname = self.hostname
         kwargs = self.kwargs
         if self.task.accept_magic_kwargs:
             kwargs = self.extend_with_default_kwargs(loglevel, logfile)
-        request.update({"loglevel": loglevel, "logfile": logfile})
+        request = self.request_dict
+        request.update({"loglevel": loglevel, "logfile": logfile,
+                        "hostname": hostname, "is_eager": False,
+                        "delivery_info": self.delivery_info})
         result = pool.apply_async(execute_and_trace,
                                   args=(self.name, self.id, self.args, kwargs),
-                                  kwargs={"hostname": self.hostname,
+                                  kwargs={"hostname": hostname,
                                           "request": request},
                                   accept_callback=self.on_accepted,
                                   timeout_callback=self.on_timeout,
                                   callback=self.on_success,
                                   error_callback=self.on_failure,
-                                  soft_timeout=self.task.soft_time_limit,
-                                  timeout=self.task.time_limit)
+                                  soft_timeout=task.soft_time_limit,
+                                  timeout=task.time_limit)
         return result
 
     def execute(self, loglevel=None, logfile=None):
@@ -279,11 +219,17 @@ class TaskRequest(object):
         if not self.task.acks_late:
             self.acknowledge()
 
-        instance_attrs = self.get_instance_attrs(loglevel, logfile)
-        retval, _ = trace_task(*self._get_tracer_args(loglevel, logfile, True),
+        kwargs = self.kwargs
+        if self.task.accept_magic_kwargs:
+            kwargs = self.extend_with_default_kwargs(loglevel, logfile)
+        request = self.request_dict
+        request.update({"loglevel": loglevel, "logfile": logfile,
+                        "hostname": self.hostname, "is_eager": False,
+                        "delivery_info": self.delivery_info})
+        retval, _ = trace_task(self.task, self.id, self.args, kwargs,
                                **{"hostname": self.hostname,
                                   "loader": self.app.loader,
-                                  "request": instance_attrs})
+                                  "request": request})
         self.acknowledge()
         return retval
 
@@ -460,17 +406,11 @@ class TaskRequest(object):
                 self.__class__.__name__,
                 self.name, self.id, self.args, self.kwargs)
 
-    def _get_tracer_args(self, loglevel=None, logfile=None, use_real=False):
-        """Get the task trace args for this task."""
-        kwargs = self.kwargs
-        if self.task.accept_magic_kwargs:
-            kwargs = self.extend_with_default_kwargs(loglevel, logfile)
-        first = self.task if use_real else self.name
-        return first, self.id, self.args, kwargs
-
-    @cached_property
+    @property
     def tzlocal(self):
-        return tz_or_local(self.app.conf.CELERY_TIMEZONE)
+        # Lazily cached in the `_tzinfo` slot declared in __slots__
+        # (writing self._tzlocal would raise AttributeError under __slots__,
+        # and __init__ initializes self._tzinfo = None).
+        if self._tzinfo is None:
+            self._tzinfo = tz_or_local(self.app.conf.CELERY_TIMEZONE)
+        return self._tzinfo
 
     @property
     def store_errors(self):
@@ -489,12 +429,18 @@ class TaskRequest(object):
     def _compat_set_task_name(self, value):
         self.name = value
 
-    def _compat_get_taskset_id(self):
-        return self.taskset
-
-    def _compat_set_taskset_id(self, value):
-        self.taskset = value
-
     task_id = property(_compat_get_task_id, _compat_set_task_id)
     task_name = property(_compat_get_task_name, _compat_set_task_name)
-    taskset_id = property(_compat_get_taskset_id, _compat_set_taskset_id)
+
+
+class TaskRequest(Request):
+    """Backwards-compatible wrapper accepting the old positional
+    ``TaskRequest(name, id, args, kwargs, ...)`` signature."""
+
+    def __init__(self, name, id, args=(), kwargs={},
+            eta=None, expires=None, **options):
+        # Repack the old-style positional arguments into the message-body
+        # dict expected by the new Request constructor.  (``self`` was
+        # missing from the parameter list, which misbound every argument.)
+        super(TaskRequest, self).__init__({
+            "task": name, "id": id, "args": args,
+            "kwargs": kwargs, "eta": eta,
+            "expires": expires}, **options)
+

+ 2 - 2
celery/worker/state.py

@@ -36,10 +36,10 @@ REVOKES_MAX = 10000
 #: being expired when the max limit has been exceeded.
 REVOKE_EXPIRES = 3600
 
-#: set of all reserved :class:`~celery.worker.job.TaskRequest`'s.
+#: set of all reserved :class:`~celery.worker.job.Request`'s.
 reserved_requests = set()
 
-#: set of currently active :class:`~celery.worker.job.TaskRequest`'s.
+#: set of currently active :class:`~celery.worker.job.Request`'s.
 active_requests = set()
 
 #: count of tasks executed by the worker, sorted by type.

+ 9 - 5
celery/worker/strategy.py

@@ -1,19 +1,23 @@
 from __future__ import absolute_import
 
-from .job import TaskRequest
+from .job import Request
+
+from celery.execute.trace import trace_task
 
 
 def default(task, app, consumer):
     logger = consumer.logger
     hostname = consumer.hostname
     eventer = consumer.event_dispatcher
-    Request = TaskRequest.from_message
+    Req = Request
     handle = consumer.on_task
     connection_errors = consumer.connection_errors
 
     def task_message_handler(M, B, A):
-        handle(Request(M, B, A, app=app, logger=logger,
-                                hostname=hostname, eventer=eventer,
-                                connection_errors=connection_errors))
+        handle(Req(B, on_ack=A, app=app, hostname=hostname,
+                         eventer=eventer, logger=logger,
+                         connection_errors=connection_errors,
+                         delivery_info=M.delivery_info,
+                         task=task))
 
     return task_message_handler

+ 2 - 1
funtests/benchmarks/bench_worker.py

@@ -29,6 +29,7 @@ celery.conf.update(BROKER_TRANSPORT="librabbitmq",
                            "no_ack": True,
                            #"exchange_durable": False,
                            #"queue_durable": False,
+                           "auto_delete": True,
                         }
                    },
                    CELERY_TASK_SERIALIZER="json",
@@ -65,7 +66,7 @@ def bench_apply(n=DEFAULT_ITS):
     print("-- apply %s tasks: %ss" % (n, time.time() - time_start, ))
 
 
-def bench_work(n=DEFAULT_ITS, loglevel=None):
+def bench_work(n=DEFAULT_ITS, loglevel="CRITICAL"):
     loglevel = os.environ.get("BENCH_LOGLEVEL") or loglevel
     if loglevel:
         celery.log.setup_logging_subsystem(loglevel=loglevel)