Browse Source

Optimization: Now at 10.000 tasks/s with librabbitmq

Ask Solem 13 years ago
parent
commit
29ad3bde87

+ 1 - 0
celery/app/task/__init__.py

@@ -737,6 +737,7 @@ class BaseTask(object):
         :keyword consumer: The :class:`~celery.worker.consumer.Consumer`.
 
         """
+        #request.execute(loglevel, logfile)
         request.execute_using_pool(pool, loglevel, logfile)
 
     def __repr__(self):

+ 18 - 12
celery/platforms.py

@@ -561,20 +561,26 @@ def set_process_title(progname, info=None):
     return proctitle
 
 
-def set_mp_process_title(progname, info=None, hostname=None, rate_limit=False):
-    """Set the ps name using the multiprocessing process name.
+# The NOSETPS environment variable is read once, when this module is
+# imported: if it is set, process title updates become a no-op.
+if os.environ.get("NOSETPS"):
 
-    Only works if :mod:`setproctitle` is installed.
+    def set_mp_process_title(*a, **k):
+        """Do nothing; accepts and ignores any arguments."""
+        pass
+else:
 
-    """
-    if not rate_limit or _setps_bucket.can_consume(1):
-        if hostname:
-            progname = "%s@%s" % (progname, hostname.split(".")[0])
-        if current_process is not None:
-            return set_process_title(
-                "%s:%s" % (progname, current_process().name), info=info)
-        else:
-            return set_process_title(progname, info=info)
+    def set_mp_process_title(progname, info=None, hostname=None, rate_limit=False):
+        """Set the ps name using the multiprocessing process name.
+
+        Only works if :mod:`setproctitle` is installed.
+
+        :param progname: Base name for the process title.
+        :keyword info: Passed on to :func:`set_process_title`.
+        :keyword hostname: If given, the part before the first ``.`` is
+            appended to the name as ``progname@host``.
+        :keyword rate_limit: If true, the title is only updated when
+            ``_setps_bucket`` has a token available.
+
+        """
+        if not rate_limit or _setps_bucket.can_consume(1):
+            if hostname:
+                progname = "%s@%s" % (progname, hostname.split(".")[0])
+            # current_process is None when it could not be imported
+            # (NOTE(review): presumably multiprocessing is optional -- confirm).
+            if current_process is not None:
+                return set_process_title(
+                    "%s:%s" % (progname, current_process().name), info=info)
+            else:
+                return set_process_title(progname, info=info)
 
 
 def shellsplit(s, posix=True):

+ 0 - 7
celery/registry.py

@@ -17,7 +17,6 @@ from .exceptions import NotRegistered
 
 
 class TaskRegistry(dict):
-
     NotRegistered = NotRegistered
 
     def regular(self):
@@ -59,12 +58,6 @@ class TaskRegistry(dict):
         return dict((name, task) for name, task in self.iteritems()
                                     if task.type == type)
 
-    def __getitem__(self, key):
-        try:
-            return dict.__getitem__(self, key)
-        except KeyError:
-            raise self.NotRegistered(key)
-
     def pop(self, key, *args):
         try:
             return dict.pop(self, key, *args)

+ 2 - 2
celery/tests/test_worker/test_worker_job.py

@@ -19,7 +19,7 @@ from celery import states
 from celery.app import app_or_default
 from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
-from celery.exceptions import (RetryTaskError, NotRegistered,
+from celery.exceptions import (RetryTaskError,
                                WorkerLostError, InvalidTaskError)
 from celery.execute.trace import eager_trace_task, TraceInfo
 from celery.log import setup_logger
@@ -579,7 +579,7 @@ class test_TaskRequest(unittest.TestCase):
         m = Message(None, body=anyjson.serialize(body), backend="foo",
                           content_type="application/json",
                           content_encoding="utf-8")
-        with self.assertRaises(NotRegistered):
+        with self.assertRaises(KeyError):
             TaskRequest.from_message(m, m.decode())
 
     def test_execute(self):

+ 1 - 11
celery/worker/consumer.py

@@ -409,16 +409,6 @@ class Consumer(object):
         :param message: The kombu message object.
 
         """
-        # need to guard against errors occurring while acking the message.
-        def ack():
-            try:
-                message.ack()
-            except self.connection_errors + (AttributeError, ), exc:
-                self.logger.critical(
-                    "Couldn't ack %r: %s reason:%r",
-                        message.delivery_tag,
-                        self._message_report(body, message), exc)
-
         try:
             name = body["task"]
         except (KeyError, TypeError):
@@ -430,7 +420,7 @@ class Consumer(object):
             return
 
         try:
-            self.strategies[name](message, body, ack)
+            self.strategies[name](message, body, message.ack_log_error)
         except KeyError, exc:
             self.logger.error(UNKNOWN_TASK_ERROR, exc, safe_repr(body),
                               exc_info=sys.exc_info())

+ 131 - 99
celery/worker/job.py

@@ -17,13 +17,15 @@ import time
 import socket
 
 from datetime import datetime
+from operator import itemgetter
 
 from .. import exceptions
 from ..registry import tasks
 from ..app import app_or_default
 from ..execute.trace import build_tracer, trace_task, report_internal_error
 from ..platforms import set_mp_process_title as setps
-from ..utils import noop, kwdict, fun_takes_kwargs, truncate_text
+from ..utils import (noop, kwdict, fun_takes_kwargs,
+                     cached_property, truncate_text)
 from ..utils.encoding import safe_repr, safe_str
 from ..utils.timeutils import maybe_iso8601, timezone
 
@@ -31,7 +33,11 @@ from . import state
 
 #: Keys to keep from the message delivery info.  The values
 #: of these keys must be pickleable.
-WANTED_DELIVERY_INFO = ("exchange", "routing_key", "consumer_tag", )
+WANTED_DELIVERY_INFO = itemgetter("exchange", "routing_key")
+
+tz_to_local = timezone.to_local
+tz_or_local = timezone.tz_or_local
+tz_utc = timezone.utc
 
 
 def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
@@ -39,7 +45,7 @@ def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
 
     It's the same as::
 
-        >>> trace_task(task_name, *args, **kwargs)[0]
+        >>> trace_task(name, *args, **kwargs)[0]
 
     """
     task = tasks[name]
@@ -62,14 +68,14 @@ class TaskRequest(object):
     #: Kind of task.  Must be a name registered in the task registry.
     name = None
 
-    #: The task class (set by constructor using :attr:`task_name`).
+    #: The task class (set by constructor using :attr:`name`).
     task = None
 
     #: UUID of the task.
-    task_id = None
+    id = None
 
     #: UUID of the taskset that this task belongs to.
-    taskset_id = None
+    taskset = None
 
     #: List of positional arguments to apply to the task.
     args = None
@@ -124,85 +130,82 @@ class TaskRequest(object):
     _already_revoked = False
     _terminate_on_ack = None
 
-    def __init__(self, task_name, task_id, args, kwargs,
+    def __init__(self, task, id, args=None, kwargs=None,
             on_ack=noop, retries=0, delivery_info=None, hostname=None,
             logger=None, eventer=None, eta=None, expires=None, app=None,
-            taskset_id=None, chord=None, utc=False, **opts):
-        self.app = app_or_default(app)
-        self.task_name = task_name
-        self.task_id = task_id
-        self.taskset_id = taskset_id
-        self.retries = retries
+            taskset=None, chord=None, utc=False, connection_errors=None,
+            **opts):
+        """Create a request for the registered task named ``task``.
+
+        :param task: Name of the task; looked up in the task registry.
+        :param id: UUID of the task.
+        :keyword args: Positional task arguments (defaults to a new list).
+        :keyword kwargs: Keyword task arguments, passed through
+            :func:`kwdict` (defaults to a new dict).
+        :keyword delivery_info: Message delivery info (defaults to a
+            new dict).
+        :keyword eta: ETA, parsed with :func:`maybe_iso8601` and
+            converted with ``timezone.to_local``.
+        :keyword expires: Expiry time, handled like ``eta``.
+        :keyword utc: If true ``eta``/``expires`` are interpreted as UTC,
+            otherwise as the configured local timezone.
+        :keyword connection_errors: Stored on the request and handed to
+            ``on_ack`` by :meth:`acknowledge`.
+
+        """
+        # NOTE: the container defaults are ``None`` (not ``[]``/``{}``) so
+        # that requests never share one mutable default object.
+        self.app = app or app_or_default(app)
+        self.name = task
+        self.id = id
-        self.args = args
-        self.kwargs = kwargs
-        self.eta = eta
-        self.expires = expires
+        self.args = [] if args is None else args
+        self.kwargs = kwdict({} if kwargs is None else kwargs)
+        self.taskset = taskset
+        self.retries = retries
         self.chord = chord
         self.on_ack = on_ack
         self.delivery_info = {} if delivery_info is None else delivery_info
         self.hostname = hostname or socket.gethostname()
         self.logger = logger or self.app.log.get_default_logger()
         self.eventer = eventer
+        self.connection_errors = connection_errors or ()
 
-        self.task = tasks[self.task_name]
-        self._store_errors = True
-        if self.task.ignore_result:
-            self._store_errors = self.task.store_errors_even_if_ignored
+        task = self.task = tasks[task]
 
         # timezone means the message is timezone-aware, and the only timezone
         # supported at this point is UTC.
-        self.tzlocal = timezone.tz_or_local(self.app.conf.CELERY_TIMEZONE)
-        tz = timezone.utc if utc else self.tzlocal
-        if self.eta is not None:
-            self.eta = timezone.to_local(self.eta, self.tzlocal, tz)
-        if self.expires is not None:
-            self.expires = timezone.to_local(self.expires, self.tzlocal, tz)
+        if eta is not None:
+            tz = tz_utc if utc else self.tzlocal
+            self.eta = tz_to_local(maybe_iso8601(eta), self.tzlocal, tz)
+        if expires is not None:
+            tz = tz_utc if utc else self.tzlocal
+            self.expires = tz_to_local(maybe_iso8601(expires),
+                                       self.tzlocal, tz)
 
         # shortcuts
         self._does_debug = self.logger.isEnabledFor(logging.DEBUG)
         self._does_info = self.logger.isEnabledFor(logging.INFO)
 
+        # Use the normalised ``self.delivery_info`` so the request dict
+        # never carries ``None`` where a mapping is expected.
+        self.request_dict = {"hostname": self.hostname,
+                             "id": id, "taskset": taskset,
+                             "retries": retries, "is_eager": False,
+                             "delivery_info": self.delivery_info,
+                             "chord": chord}
+
+    @cached_property
+    def tzlocal(self):
+        return tz_or_local(self.app.conf.CELERY_TIMEZONE)
+
     @classmethod
-    def from_message(cls, message, body, on_ack=noop, **kw):
+    def from_message(cls, message, body, on_ack=noop, delivery_info=None,
+            logger=None, hostname=None, eventer=None, app=None,
+            connection_errors=None):
         """Create request from a task message.
 
-        :raises UnknownTaskError: if the message does not describe a task,
-            the message is also rejected.
+        The decoded ``body`` is expanded into the constructor's keyword
+        arguments, so its keys must match the constructor's parameters.
+
+        :raises InvalidTaskError: if the body is missing the required
+            ``task`` or ``id`` field.
+        :raises KeyError: if the task name is not in the task registry.
+        :raises TypeError: if the constructor rejects the body for any
+            other reason.
 
         """
-        delivery_info = getattr(message, "delivery_info", {})
-        delivery_info = dict((key, delivery_info.get(key))
-                                for key in WANTED_DELIVERY_INFO)
-
-        kwargs = body.get("kwargs", {})
-        if not hasattr(kwargs, "items"):
-            raise exceptions.InvalidTaskError(
-                    "Task keyword arguments is not a mapping.")
+        # Fresh dict per call; a ``{}`` default would be shared by every
+        # request created from a message without delivery info.
+        if delivery_info is None:
+            delivery_info = {}
         try:
-            task_name = body["task"]
-            task_id = body["id"]
-        except KeyError, exc:
-            raise exceptions.InvalidTaskError(
-                "Task message is missing required field %r" % (exc, ))
-
-        return cls(task_name=task_name,
-                   task_id=task_id,
-                   taskset_id=body.get("taskset", None),
-                   args=body.get("args", []),
-                   kwargs=kwdict(kwargs),
-                   chord=body.get("chord"),
-                   retries=body.get("retries", 0),
-                   eta=maybe_iso8601(body.get("eta")),
-                   expires=maybe_iso8601(body.get("expires")),
-                   on_ack=on_ack,
-                   delivery_info=delivery_info,
-                   utc=body.get("utc", None),
-                   **kw)
+            D = message.delivery_info
+            delivery_info = {"exchange": D.get("exchange"),
+                             "routing_key": D.get("routing_key")}
+        except (AttributeError, KeyError):
+            # Best effort: the message has no usable delivery info, so
+            # keep whatever the caller passed in.
+            pass
+
+        try:
+            return cls(on_ack=on_ack, logger=logger, eventer=eventer, app=app,
+                       delivery_info=delivery_info, hostname=hostname,
+                       connection_errors=connection_errors, **body)
+        except TypeError:
+            for f in ("task", "id"):
+                if f not in body:
+                    raise exceptions.InvalidTaskError(
+                        "Task message is missing required field %r" % (f, ))
+            # Both required fields are present, so the TypeError has another
+            # cause; re-raise it instead of silently returning None.
+            raise
 
     def get_instance_attrs(self, loglevel, logfile):
         return {"logfile": logfile, "loglevel": loglevel,
                 "hostname": self.hostname,
-                "id": self.task_id, "taskset": self.taskset_id,
+                "id": self.id, "taskset": self.taskset,
                 "retries": self.retries, "is_eager": False,
                 "delivery_info": self.delivery_info, "chord": self.chord}
 
@@ -218,13 +221,11 @@ class TaskRequest(object):
         in version 3.0.
 
         """
-        if not self.task.accept_magic_kwargs:
-            return self.kwargs
         kwargs = dict(self.kwargs)
         default_kwargs = {"logfile": logfile,
                           "loglevel": loglevel,
-                          "task_id": self.task_id,
-                          "task_name": self.task_name,
+                          "task_id": self.id,
+                          "task_name": self.name,
                           "task_retries": self.retries,
                           "task_is_eager": False,
                           "delivery_info": self.delivery_info}
@@ -247,13 +248,16 @@ class TaskRequest(object):
         """
         if self.revoked():
             return
+        request = self.request_dict
 
-        args = self._get_tracer_args(loglevel, logfile)
-        instance_attrs = self.get_instance_attrs(loglevel, logfile)
+        kwargs = self.kwargs
+        if self.task.accept_magic_kwargs:
+            kwargs = self.extend_with_default_kwargs(loglevel, logfile)
+        request.update({"loglevel": loglevel, "logfile": logfile})
         result = pool.apply_async(execute_and_trace,
-                                  args=args,
+                                  args=(self.name, self.id, self.args, kwargs),
                                   kwargs={"hostname": self.hostname,
-                                          "request": instance_attrs},
+                                          "request": request},
                                   accept_callback=self.on_accepted,
                                   timeout_callback=self.on_timeout,
                                   callback=self.on_success,
@@ -288,9 +292,9 @@ class TaskRequest(object):
     def maybe_expire(self):
         """If expired, mark the task as revoked."""
         if self.expires and datetime.now(self.tzlocal) > self.expires:
-            state.revoked.add(self.task_id)
-            if self._store_errors:
-                self.task.backend.mark_as_revoked(self.task_id)
+            state.revoked.add(self.id)
+            if self.store_errors:
+                self.task.backend.mark_as_revoked(self.id)
 
     def terminate(self, pool, signal=None):
         if self.time_start:
@@ -304,10 +308,10 @@ class TaskRequest(object):
             return True
         if self.expires:
             self.maybe_expire()
-        if self.task_id in state.revoked:
+        if self.id in state.revoked:
             self.logger.warn("Skipping revoked task: %s[%s]",
-                             self.task_name, self.task_id)
-            self.send_event("task-revoked", uuid=self.task_id)
+                             self.name, self.id)
+            self.send_event("task-revoked", uuid=self.id)
             self.acknowledge()
             self._already_revoked = True
             return True
@@ -324,10 +328,10 @@ class TaskRequest(object):
         state.task_accepted(self)
         if not self.task.acks_late:
             self.acknowledge()
-        self.send_event("task-started", uuid=self.task_id, pid=pid)
+        self.send_event("task-started", uuid=self.id, pid=pid)
         if self._does_debug:
             self.logger.debug("Task accepted: %s[%s] pid:%r",
-                              self.task_name, self.task_id, pid)
+                              self.name, self.id, pid)
         if self._terminate_on_ack is not None:
             _, pool, signal = self._terminate_on_ack
             self.terminate(pool, signal)
@@ -337,15 +341,15 @@ class TaskRequest(object):
         state.task_ready(self)
         if soft:
             self.logger.warning("Soft time limit (%ss) exceeded for %s[%s]",
-                                timeout, self.task_name, self.task_id)
+                                timeout, self.name, self.id)
             exc = exceptions.SoftTimeLimitExceeded(timeout)
         else:
             self.logger.error("Hard time limit (%ss) exceeded for %s[%s]",
-                              timeout, self.task_name, self.task_id)
+                              timeout, self.name, self.id)
             exc = exceptions.TimeLimitExceeded(timeout)
 
-        if self._store_errors:
-            self.task.backend.mark_as_failure(self.task_id, exc)
+        if self.store_errors:
+            self.task.backend.mark_as_failure(self.id, exc)
 
     def on_success(self, ret_value):
         """Handler called if the task was successfully processed."""
@@ -355,26 +359,26 @@ class TaskRequest(object):
             self.acknowledge()
 
         runtime = self.time_start and (time.time() - self.time_start) or 0
-        self.send_event("task-succeeded", uuid=self.task_id,
+        self.send_event("task-succeeded", uuid=self.id,
                         result=safe_repr(ret_value), runtime=runtime)
 
         if self._does_info:
             self.logger.info(self.success_msg.strip(),
-                            {"id": self.task_id,
-                             "name": self.task_name,
+                            {"id": self.id,
+                             "name": self.name,
                              "return_value": self.repr_result(ret_value),
                              "runtime": runtime})
 
     def on_retry(self, exc_info):
         """Handler called if the task should be retried."""
-        self.send_event("task-retried", uuid=self.task_id,
+        self.send_event("task-retried", uuid=self.id,
                          exception=safe_repr(exc_info.exception.exc),
                          traceback=safe_str(exc_info.traceback))
 
         if self._does_info:
             self.logger.info(self.retry_msg.strip(),
-                            {"id": self.task_id,
-                             "name": self.task_name,
+                            {"id": self.id,
+                             "name": self.name,
                              "exc": safe_repr(exc_info.exception.exc)},
                             exc_info=exc_info)
 
@@ -391,16 +395,16 @@ class TaskRequest(object):
         # This is a special case as the process would not have had
         # time to write the result.
         if isinstance(exc_info.exception, exceptions.WorkerLostError) and \
-                self._store_errors:
-            self.task.backend.mark_as_failure(self.task_id, exc_info.exception)
+                self.store_errors:
+            self.task.backend.mark_as_failure(self.id, exc_info.exception)
 
-        self.send_event("task-failed", uuid=self.task_id,
+        self.send_event("task-failed", uuid=self.id,
                          exception=safe_repr(exc_info.exception),
                          traceback=safe_str(exc_info.traceback))
 
         context = {"hostname": self.hostname,
-                   "id": self.task_id,
-                   "name": self.task_name,
+                   "id": self.id,
+                   "name": self.name,
                    "exc": safe_repr(exc_info.exception),
                    "traceback": safe_str(exc_info.traceback),
                    "args": safe_repr(self.args),
@@ -408,17 +412,17 @@ class TaskRequest(object):
 
         self.logger.error(self.error_msg.strip(), context,
                           exc_info=exc_info.exc_info,
-                          extra={"data": {"id": self.task_id,
-                                          "name": self.task_name,
+                          extra={"data": {"id": self.id,
+                                          "name": self.name,
                                           "hostname": self.hostname}})
 
-        task_obj = tasks.get(self.task_name, object)
+        task_obj = tasks.get(self.name, object)
         task_obj.send_error_email(context, exc_info.exception)
 
     def acknowledge(self):
         """Acknowledge task."""
         if not self.acknowledged:
-            self.on_ack()
+            self.on_ack(self.logger, self.connection_errors)
             self.acknowledged = True
 
     def repr_result(self, result, maxlen=46):
@@ -427,8 +431,8 @@ class TaskRequest(object):
         return truncate_text(safe_repr(result), maxlen)
 
     def info(self, safe=False):
-        return {"id": self.task_id,
-                "name": self.task_name,
+        return {"id": self.id,
+                "name": self.name,
                 "args": self.args if safe else safe_repr(self.args),
                 "kwargs": self.kwargs if safe else safe_repr(self.kwargs),
                 "hostname": self.hostname,
@@ -439,8 +443,7 @@ class TaskRequest(object):
 
     def shortinfo(self):
         return "%s[%s]%s%s" % (
-                    self.task_name,
-                    self.task_id,
+                    self.name, self.id,
                     " eta:[%s]" % (self.eta, ) if self.eta else "",
                     " expires:[%s]" % (self.expires, ) if self.expires else "")
     __str__ = shortinfo
@@ -448,10 +451,39 @@ class TaskRequest(object):
     def __repr__(self):
         return '<%s: {name:"%s", id:"%s", args:"%s", kwargs:"%s"}>' % (
                 self.__class__.__name__,
-                self.task_name, self.task_id, self.args, self.kwargs)
+                self.name, self.id, self.args, self.kwargs)
 
     def _get_tracer_args(self, loglevel=None, logfile=None, use_real=False):
         """Get the task trace args for this task."""
-        task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
-        first = self.task if use_real else self.task_name
-        return first, self.task_id, self.args, task_func_kwargs
+        kwargs = self.kwargs
+        if self.task.accept_magic_kwargs:
+            kwargs = self.extend_with_default_kwargs(loglevel, logfile)
+        first = self.task if use_real else self.name
+        return first, self.id, self.args, kwargs
+
+    @property
+    def store_errors(self):
+        return (not self.task.ignore_result
+                or self.task.store_errors_even_if_ignored)
+
+    def _compat_get_task_id(self):
+        return self.id
+
+    def _compat_set_task_id(self, value):
+        self.id = value
+
+    def _compat_get_task_name(self):
+        return self.name
+
+    def _compat_set_task_name(self, value):
+        self.name = value
+
+    def _compat_get_taskset_id(self):
+        return self.taskset
+
+    def _compat_set_taskset_id(self, value):
+        self.taskset = value
+
+    task_id = property(_compat_get_task_id, _compat_set_task_id)
+    task_name = property(_compat_get_task_name, _compat_set_task_name)
+    taskset_id = property(_compat_get_taskset_id, _compat_set_taskset_id)

+ 3 - 5
celery/worker/state.py

@@ -48,16 +48,14 @@ total_count = defaultdict(lambda: 0)
 #: the list of currently revoked tasks.  Persistent if statedb set.
 revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)
 
-
-def task_reserved(request):
-    """Updates global state when a task has been reserved."""
-    reserved_requests.add(request)
+#: Updates global state when a task has been reserved.
+task_reserved = reserved_requests.add
 
 
 def task_accepted(request):
     """Updates global state when a task has been accepted."""
     active_requests.add(request)
-    total_count[request.task_name] += 1
+    # Keep the per-task counter updated; the commit had left this increment
+    # commented out.  Uses the request's renamed ``name`` attribute.
+    total_count[request.name] += 1
 
 
 def task_ready(request):

+ 3 - 1
celery/worker/strategy.py

@@ -9,9 +9,11 @@ def default(task, app, consumer):
     eventer = consumer.event_dispatcher
     Request = TaskRequest.from_message
     handle = consumer.on_task
+    connection_errors = consumer.connection_errors
 
     def task_message_handler(M, B, A):
         handle(Request(M, B, A, app=app, logger=logger,
-                                hostname=hostname, eventer=eventer))
+                                hostname=hostname, eventer=eventer,
+                                connection_errors=connection_errors))
 
     return task_message_handler

+ 37 - 1
funtests/benchmarks/bench_worker.py

@@ -2,6 +2,39 @@ import os
 import sys
 import time
 
+os.environ["NOSETPS"] = "yes"
+
+from threading import Lock
+
+class DLock(object):
+    """Debugging wrapper around a real :class:`threading.Lock`.
+
+    Prints a message on every acquire and release, plus the current call
+    stack on acquire.  It is not installed anywhere by this class itself;
+    see the commented-out ``threading.Lock = DLock`` assignment below.
+
+    """
+
+    def __init__(self):
+        # ``Lock`` was imported by name above, so this is the real lock.
+        self.l = Lock()
+
+    def acquire(self, *args, **kwargs):
+        """Log the call and its stack, then acquire the wrapped lock."""
+        print("ACQUIRE: %r %r" % (args, kwargs))
+        import traceback
+        traceback.print_stack()
+        return self.l.acquire(*args, **kwargs)
+
+    def release(self):
+        """Log the call, then release the wrapped lock."""
+        print("RELEASE")
+        return self.l.release()
+
+    def __enter__(self):
+        """Acquire the lock on entering a ``with`` block."""
+        self.acquire()
+        return self
+
+    def __exit__(self, *exc_info):
+        """Release the lock on leaving a ``with`` block."""
+        self.release()
+
+
+import threading
+#threading.Lock = DLock
+
+
+
+
 import anyjson
 JSONIMP = os.environ.get("JSONIMP")
 if JSONIMP:
@@ -29,7 +62,8 @@ celery.conf.update(BROKER_TRANSPORT="librabbitmq",
                    },
                    CELERY_TASK_SERIALIZER="json",
                    CELERY_DEFAULT_QUEUE="bench.worker",
-                   CELERY_BACKEND=None)
+                   CELERY_BACKEND=None,
+                   )#CELERY_MESSAGE_COMPRESSION="zlib")
 
 
 def tdiff(then):
@@ -47,6 +81,7 @@ def it(_, n):
         it.subt = it.time_start = time.time()
     elif i == n - 1:
         total = tdiff(it.time_start)
+        print >> sys.stderr, "(%s so far: %ss)" % (i, tdiff(it.subt))
         print("-- process %s tasks: %ss total, %s tasks/s} " % (
                 n, total, n / (total + .0)))
         sys.exit()
@@ -60,6 +95,7 @@ def bench_apply(n=DEFAULT_ITS):
 
 
 def bench_work(n=DEFAULT_ITS, loglevel=None):
+    loglevel = os.environ.get("BENCH_LOGLEVEL") or loglevel
     if loglevel:
         celery.log.setup_logging_subsystem(loglevel=loglevel)
     worker = celery.WorkController(concurrency=15,