
Optimization: A tracer is now generated for each task (+ ~1500msg/s)

Ask Solem 13 years ago
parent commit
9b94556b8c

+ 2 - 0
celery/actors.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from celery.app import app_or_default
 
 import cl

+ 0 - 3
celery/app/__init__.py

@@ -15,9 +15,6 @@ from __future__ import absolute_import
 import os
 import threading
 
-from functools import wraps
-from inspect import getargspec
-
 from .. import registry
 from ..utils import cached_property, instantiate
 

+ 5 - 3
celery/app/task/__init__.py

@@ -18,7 +18,7 @@ import threading
 from ... import states
 from ...datastructures import ExceptionInfo
 from ...exceptions import MaxRetriesExceededError, RetryTaskError
-from ...execute.trace import trace_task
+from ...execute.trace import eager_trace_task
 from ...registry import tasks, _unpickle_task
 from ...result import EagerResult
 from ...utils import fun_takes_kwargs, instantiate, mattrgetter, uuid
@@ -105,6 +105,7 @@ class TaskType(type):
             try:
                 attrs["__call__"] = attrs["run"]
             except KeyError:
+
                 # the class does not yet define run,
                 # so we can't optimize this case.
                 def __call__(self, *args, **kwargs):
@@ -141,6 +142,7 @@ class BaseTask(object):
 
     """
     __metaclass__ = TaskType
+    __tracer__ = None
 
     ErrorMail = ErrorMail
     MaxRetriesExceededError = MaxRetriesExceededError
@@ -618,8 +620,8 @@ class BaseTask(object):
                                         if key in supported_keys)
             kwargs.update(extend_with)
 
-        retval, info = trace_task(task.name, task_id, args, kwargs, eager=True,
-                                  task=task, request=request, propagate=throw)
+        retval, info = eager_trace_task(task, task_id, args, kwargs,
+                                        request=request, propagate=throw)
         if isinstance(retval, ExceptionInfo):
             retval = retval.exception
         state, tb = states.SUCCESS, ''
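
Note on this hunk: BaseTask grows a __tracer__ slot (None on the class, so every task starts out uncached), and the eager path in apply() switches to eager_trace_task(), which builds a fresh tracer with eager=True instead of going through the cached worker-side tracer. A hedged usage sketch, assuming the add and raises example tasks defined in celery/tests/test_task/test_execute_trace.py further down in this diff:

    from celery.datastructures import ExceptionInfo
    from celery.execute.trace import eager_trace_task
    from celery.tests.test_task.test_execute_trace import add, raises

    retval, info = eager_trace_task(add, "id-1", (2, 2), {})
    assert retval == 4 and info is None     # success: plain return value, no TraceInfo

    retval, info = eager_trace_task(raises, "id-2", (KeyError("foo"),), {})
    # on failure retval may be an ExceptionInfo wrapper; apply() checks for
    # this and unwraps retval.exception before building the EagerResult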

+ 0 - 156
celery/execute/coro.py

@@ -1,156 +0,0 @@
-import traceback
-import socket
-import os
-import sys
-import warnings
-from .. import states
-from .. import current_app
-from .. import signals
-from ..datastructures import ExceptionInfo
-from ..exceptions import RetryTaskError
-from ..utils.serialization import get_pickleable_exception
-
-
-class Traceinfo(object):
-    __slots__ = ("state", "retval", "exc_info",
-                 "exc_type", "exc_value", "tb", "strtb")
-
-    def __init__(self, state, retval, exc_info=None):
-        self.state = state
-        self.retval = retval
-        self.exc_info = exc_info
-        if exc_info:
-            self.exc_type, self.exc_value, self.tb = exc_info
-            self.strtb = "\n".join(traceback.format_exception(*exc_info))
-
-
-def tracer(task, loader=None, hostname=None):
-    hostname = hostname or socket.gethostname()
-    pid = os.getpid()
-
-    PENDING = states.PENDING
-    SUCCESS = states.SUCCESS
-    RETRY = states.RETRY
-    FAILURE = states.FAILURE
-    EXCEPTION_STATES = states.EXCEPTION_STATES
-
-    loader = loader or current_app.loader
-    on_task_init = loader.on_task_init
-
-    task_cleanup = task.backend.process_cleanup
-    loader_cleanup = loader.on_process_cleanup
-    on_success = task.on_success
-    on_failure = task.on_failure
-    on_retry = task.on_retry
-    after_return = task.after_return
-    update_req = task.request.update
-    clear_req = task.request.clear
-    backend = task.backend
-    prepare_exception = backend.prepare_exception
-    mark_as_started = backend.mark_as_started
-    mark_as_done = backend.mark_as_done
-    mark_as_failure = backend.mark_as_failure
-    mark_as_retry = backend.mark_as_retry
-    ignore_result = task.ignore_result
-    store_errors = True
-    if ignore_result:
-        store_errors = task.store_errors_even_if_ignored
-    track_started = task.track_started
-
-    send_prerun = signals.task_prerun.send
-    send_postrun = signals.task_postrun.send
-    send_failure = signals.task_failure.send
-
-    @coroutine
-    def task_tracer(self):
-
-        while 1:
-            X = None
-            ID, ARGS, KWARGS, REQ, propagate = (yield)
-            state = PENDING
-
-            try:
-                # - init
-                on_task_init(ID, task)
-                if track_started and not ignore_result:
-                    mark_as_started(ID, pid=pid, hostname=hostname)
-                update_req(REQ, args=ARGS, kwargs=KWARGS,
-                           called_directly=False)
-                send_prerun(sender=task, task_id=ID, task=task,
-                            args=ARGS, kwargs=KWARGS)
-
-                # - trace execution
-                R = None
-                try:
-                    R = Traceinfo(SUCCESS, task(*ARGS, **KWARGS))
-                except RetryTaskError, exc:
-                    R = TraceInfo(RETRY, exc, sys.exc_info())
-                except Exception, exc:
-                    if propagate:
-                        raise
-                    R = Traceinfo(FAILURE, exc, sys.exc_info())
-                except BaseException, exc:
-                    raise
-                except:  # pragma: no cover
-                    # For Python2.5 where raising strings are still allowed
-                    # (but deprecated)
-                    if propagate:
-                        raise
-                    R = Traceinfo(FAILURE, None, sys.exc_info())
-
-                # - report state
-                state = R.state
-                retval = R.retval
-                if state == SUCCESS:
-                    if not ignore_result:
-                        mark_as_done(ID, retval)
-                    on_success(retval, ID, ARGS, KWARGS)
-                elif state == RETRY:
-                    type_, tb = R.exc_type, R.tb
-                    if not ignore_result:
-                        message, orig_exc = R.exc_value.args
-                    if store_errors:
-                        mark_as_retry(ID, orig_exc, R.strtb)
-                    expanded_msg = "%s: %s" % (message, str(orig_exc))
-                    X = ExceptionInfo((type_, type_(expanded_msg, None), tb))
-                    on_retry(exc, ID, ARGS, KWARGS, X)
-                elif state == FAILURE:
-                    if store_errors:
-                        mark_as_failure(ID, exc, R.strtb)
-                    exc = get_pickleable_exception(exc)
-                    X = ExceptionInfo((type_, exc, tb))
-                    on_failure(exc, ID, ARGS, KWARGS, X)
-                    send_failure(sender=task, task_id=ID, exception=exc,
-                                 args=ARGS, kwargs=KWARGS, traceback=tb,
-                                 einfo=X)
-
-                # - after return
-                if state in EXCEPTION_STATES:
-                    einfo = ExceptionInfo(R.exc_info)
-                after_return(state, R, ID, ARGS, KWARGS, einfo)
-
-                # - post run
-                send_postrun(sender=task, task_id=ID, task=task,
-                             args=ARGS, kwargs=KWARGS, retval=retval)
-
-                yield X
-            except Exception, exc:
-                _type, _value, _tb = sys.exc_info()
-                _value = prepare_exception(exc)
-                exc_info = ExceptionInfo((_type, _value, _tb))
-                warnings.warn("Exception outside body: %s: %s\n%s" % tuple(
-                    map(str, (exc.__class__, exc, exc_info.traceback))))
-                yield exc_info
-            finally:
-                clear_req()
-                try:
-                    task_cleanup()
-                    loader_cleanup()
-                except (KeyboardInterrupt, SystemExit, MemoryError):
-                    raise
-                except Exception, exc:
-                    logger = current_app.log.get_default_logger()
-                    logger.error("Process cleanup failed: %r", exc,
-                                 exc_info=sys.exc_info())
-
-    return task_tracer()
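
The deleted celery/execute/coro.py was the coroutine-based tracer prototype; its job moves to build_tracer() in celery/execute/trace.py below, which returns a plain closure instead of a primed generator. The prototype leaned on a coroutine() helper (celery.utils.coroutine) that primes a generator so callers can drive it with .send(); a minimal sketch of that pattern, inferred from the usage above rather than taken from the helper's actual source:

    from functools import wraps

    def coroutine(fun):
        """Advance a generator to its first (yield) so .send() can be used."""
        @wraps(fun)
        def start(*args, **kwargs):
            gen = fun(*args, **kwargs)
            next(gen)        # prime: run up to the first (yield)
            return gen
        return start

Dropping the coroutine indirection is what later lets celery/worker/consumer.py call a strategy directly instead of using strategies[name].send(...).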

+ 107 - 95
celery/execute/trace.py

@@ -37,6 +37,7 @@ send_prerun = signals.task_prerun.send
 prerun_receivers = signals.task_prerun.receivers
 send_postrun = signals.task_postrun.send
 postrun_receivers = signals.task_postrun.receivers
+STARTED = states.STARTED
 SUCCESS = states.SUCCESS
 RETRY = states.RETRY
 FAILURE = states.FAILURE
@@ -111,106 +112,117 @@ class TraceInfo(object):
             return '\n'.join(traceback.format_exception(*self.exc_info))
         return ''
 
-def trace_task(name, uuid, args, kwargs, task=None, request=None,
-        eager=False, propagate=False, loader=None, hostname=None,
-        store_errors=True, Info=TraceInfo):
-    """Wraps the task in a jail, catches all exceptions, and
-    saves the status and result of the task execution to the task
-    meta backend.
-
-    If the call was successful, it saves the result to the task result
-    backend, and sets the task status to `"SUCCESS"`.
-
-    If the call raises :exc:`~celery.exceptions.RetryTaskError`, it extracts
-    the original exception, uses that as the result and sets the task status
-    to `"RETRY"`.
-
-    If the call results in an exception, it saves the exception as the task
-    result, and sets the task status to `"FAILURE"`.
-
-    :param task_name: The name of the task to execute.
-    :param task_id: The unique id of the task.
-    :param args: List of positional args to pass on to the function.
-    :param kwargs: Keyword arguments mapping to pass on to the function.
-
-    :keyword loader: Custom loader to use, if not specified the current app
-      loader will be used.
-    :keyword hostname: Custom hostname to use, if not specified the system
-      hostname will be used.
-
-    :returns: the evaluated functions return value on success, or
-        the exception instance on failure.
-    """
-    R = I = None
-    try:
-        task = task or tasks[name]
-        backend = task.backend
-        ignore_result = task.ignore_result
-        loader = loader or current_app.loader
-        hostname = hostname or socket.gethostname()
-        task.request.update(request or {}, args=args,
-                            called_directly=False, kwargs=kwargs)
+
+def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
+        Info=TraceInfo, eager=False, propagate=False):
+    task = task or tasks[name]
+    loader = loader or current_app.loader
+    backend = task.backend
+    ignore_result = task.ignore_result
+    track_started = task.track_started
+    track_started = not eager and (task.track_started and not ignore_result)
+    publish_result = not eager and not ignore_result
+    hostname = hostname or socket.gethostname()
+
+    loader_task_init = loader.on_task_init
+    loader_cleanup = loader.on_process_cleanup
+
+    task_on_success = task.on_success
+    task_after_return = task.after_return
+    task_request = task.request
+
+    store_result = backend.store_result
+    backend_cleanup = backend.process_cleanup
+
+    pid = os.getpid()
+
+    update_request = task_request.update
+    clear_request = task_request.clear
+    on_chord_part_return = backend.on_chord_part_return
+
+    def trace_task(uuid, args, kwargs, request=None):
+        R = I = None
         try:
-            # -*- PRE -*-
-            send_prerun(sender=task, task_id=uuid, task=task,
-                        args=args, kwargs=kwargs)
-            loader.on_task_init(uuid, task)
-            if not eager and (task.track_started and not ignore_result):
-                backend.mark_as_started(uuid, pid=getpid(),
-                                        hostname=hostname)
-
-            # -*- TRACE -*-
+            update_request(request or {}, args=args,
+                           called_directly=False, kwargs=kwargs)
             try:
-                R = retval = task(*args, **kwargs)
-                state, einfo = SUCCESS, None
-                task.on_success(retval, uuid, args, kwargs)
-                if not eager and not ignore_result:
-                    backend.mark_as_done(uuid, retval)
-            except RetryTaskError, exc:
-                I = Info(RETRY, exc, sys.exc_info())
-                state, retval, einfo = I.state, I.retval, I.exc_info
-                R = I.handle_error_state(task, eager=eager)
-            except Exception, exc:
-                if propagate:
-                    raise
-                I = Info(FAILURE, exc, sys.exc_info())
-                state, retval, einfo = I.state, I.retval, I.exc_info
-                R = I.handle_error_state(task, eager=eager)
-            except BaseException, exc:
-                raise
-            except:
-                # pragma: no cover
-                # For Python2.5 where raising strings are still allowed
-                # (but deprecated)
-                if propagate:
-                    raise
-                I = Info(FAILURE, None, sys.exc_info())
-                state, retval, einfo = I.state, I.retval, I.exc_info
-                R = I.handle_error_state(task, eager=eager)
-
-            # -* POST *-
-            if task.request.chord:
-                backend.on_chord_part_return(task)
-            task.after_return(state, retval, uuid, args, kwargs, einfo)
-            send_postrun(sender=task, task_id=uuid, task=task,
-                         args=args, kwargs=kwargs, retval=retval)
-        finally:
-            task.request.clear()
-            if not eager:
+                # -*- PRE -*-
+                send_prerun(sender=task, task_id=uuid, task=task,
+                            args=args, kwargs=kwargs)
+                loader_task_init(uuid, task)
+                if track_started:
+                    store_result(uuid, {"pid": pid,
+                                        "hostname": hostname}, STARTED)
+
+                # -*- TRACE -*-
                 try:
-                    backend.process_cleanup()
-                    loader.on_process_cleanup()
-                except (KeyboardInterrupt, SystemExit, MemoryError):
-                    raise
+                    R = retval = task(*args, **kwargs)
+                    state, einfo = SUCCESS, None
+                    task_on_success(retval, uuid, args, kwargs)
+                    if publish_result:
+                        store_result(uuid, retval, SUCCESS)
+                except RetryTaskError, exc:
+                    I = Info(RETRY, exc, sys.exc_info())
+                    state, retval, einfo = I.state, I.retval, I.exc_info
+                    R = I.handle_error_state(task, eager=eager)
                 except Exception, exc:
-                    logger = current_app.log.get_default_logger()
-                    logger.error("Process cleanup failed: %r", exc,
-                                 exc_info=sys.exc_info())
+                    if propagate:
+                        raise
+                    I = Info(FAILURE, exc, sys.exc_info())
+                    state, retval, einfo = I.state, I.retval, I.exc_info
+                    R = I.handle_error_state(task, eager=eager)
+                except BaseException, exc:
+                    raise
+                except:
+                    # pragma: no cover
+                    # For Python2.5 where raising strings are still allowed
+                    # (but deprecated)
+                    if propagate:
+                        raise
+                    I = Info(FAILURE, None, sys.exc_info())
+                    state, retval, einfo = I.state, I.retval, I.exc_info
+                    R = I.handle_error_state(task, eager=eager)
+
+                # -* POST *-
+                if task_request.chord:
+                    on_chord_part_return(task)
+                task_after_return(state, retval, uuid, args, kwargs, einfo)
+                send_postrun(sender=task, task_id=uuid, task=task,
+                            args=args, kwargs=kwargs, retval=retval)
+            finally:
+                clear_request()
+                if not eager:
+                    try:
+                        backend_cleanup()
+                        loader_cleanup()
+                    except (KeyboardInterrupt, SystemExit, MemoryError):
+                        raise
+                    except Exception, exc:
+                        logger = current_app.log.get_default_logger()
+                        logger.error("Process cleanup failed: %r", exc,
+                                     exc_info=sys.exc_info())
+        except Exception, exc:
+            if eager:
+                raise
+            R = report_internal_error(task, exc)
+        return R, I
+
+    return trace_task
+
+
+def trace_task(task, uuid, args, kwargs, request=None, **opts):
+    try:
+        if task.__tracer__ is None:
+            task.__tracer__ = build_tracer(task.name, task, **opts)
+        return task.__tracer__(uuid, args, kwargs, request)
     except Exception, exc:
-        if eager:
-            raise
-        R = report_internal_error(task, exc)
-    return R, I
+        return report_internal_error(task, exc), None
+
+
+def eager_trace_task(task, uuid, args, kwargs, request=None, **opts):
+    opts.setdefault("eager", True)
+    return build_tracer(task.name, task, **opts)(
+            uuid, args, kwargs, request)
 
 
 def report_internal_error(task, exc):
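
This is the core of the optimization. The old trace_task() re-resolved the task, backend, loader and request object on every execution; build_tracer() now does that resolution once per task and returns a closure that only performs the per-call work, and the new module-level trace_task() caches that closure on task.__tracer__. A stripped-down, self-contained sketch of the pattern (hypothetical Task class, none of Celery's signal, backend or loader plumbing):

    class Task(object):
        __tracer__ = None                  # filled in lazily, once per process

        def run(self, x, y):
            return x + y

        def on_success(self, retval, uuid, args, kwargs):
            pass


    def build_tracer(task):
        # Resolve everything that would otherwise be looked up on every call.
        run = task.run
        on_success = task.on_success

        def trace(uuid, args, kwargs):
            retval = run(*args, **kwargs)  # hot path touches only local names
            on_success(retval, uuid, args, kwargs)
            return retval
        return trace


    def trace_task(task, uuid, args, kwargs):
        if task.__tracer__ is None:        # built on first use, reused afterwards
            task.__tracer__ = build_tracer(task)
        return task.__tracer__(uuid, args, kwargs)


    add = Task()
    assert trace_task(add, "id-1", (2, 2), {}) == 4

The real build_tracer() additionally pre-binds the loader hooks, the backend's store_result/process_cleanup and the request update/clear methods, which is where most of the per-message savings comes from; note also that mark_as_started/mark_as_done are replaced by a single store_result(uuid, meta, state) call with explicit STARTED/SUCCESS states.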

+ 0 - 2
celery/tests/test_task/__init__.py

@@ -4,8 +4,6 @@ from __future__ import with_statement
 from datetime import datetime, timedelta
 from functools import wraps
 
-from mock import Mock
-
 from celery import task
 from celery.app import app_or_default
 from celery.task import task as task_dec

+ 4 - 6
celery/tests/test_task/test_execute_trace.py

@@ -1,14 +1,13 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
-import operator
-
 from celery import current_app
 from celery import states
 from celery.exceptions import RetryTaskError
-from celery.execute.trace import trace_task
+from celery.execute.trace import eager_trace_task
 from celery.tests.utils import unittest
 
+
 @current_app.task
 def add(x, y):
     return x + y
@@ -20,9 +19,8 @@ def raises(exc):
 
 
 def trace(task, args=(), kwargs={}, propagate=False):
-    return trace_task(task.__name__, "id-1", args, kwargs, task,
-                      propagate=propagate, eager=True)
-
+    return eager_trace_task(task, "id-1", args, kwargs,
+                      propagate=propagate)
 
 
 class test_trace(unittest.TestCase):

+ 10 - 8
celery/tests/test_worker/test_worker_job.py

@@ -21,8 +21,9 @@ from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import (RetryTaskError, NotRegistered,
                                WorkerLostError, InvalidTaskError)
-from celery.execute.trace import trace_task, TraceInfo
+from celery.execute.trace import eager_trace_task, TraceInfo
 from celery.log import setup_logger
+from celery.registry import tasks
 from celery.result import AsyncResult
 from celery.task import task as task_dec
 from celery.task.base import Task
@@ -40,8 +41,8 @@ scratch = {"ACK": False}
 some_kwargs_scratchpad = {}
 
 
-def jail(task_id, task_name, args, kwargs):
-    return trace_task(task_name, task_id, args, kwargs)[0]
+def jail(task_id, name, args, kwargs):
+    return eager_trace_task(tasks[name], task_id, args, kwargs, eager=False)[0]
 
 
 def on_ack():
@@ -123,7 +124,8 @@ class test_trace_task(unittest.TestCase):
                 tid = uuid()
                 ret = jail(tid, mytask.name, [2], {})
                 self.assertEqual(ret, 4)
-                mytask.backend.mark_as_done.assert_called_with(tid, 4)
+                mytask.backend.store_result.assert_called_with(tid, 4,
+                                                               states.SUCCESS)
                 logs = sio.getvalue().strip()
                 self.assertIn("Process cleanup failed", logs)
         finally:
@@ -144,15 +146,16 @@ class test_trace_task(unittest.TestCase):
         self.assertEqual(ret, 4)
 
     def test_marked_as_started(self):
-        mytask.track_started = True
 
         class Backend(mytask.backend.__class__):
             _started = []
 
-            def mark_as_started(self, tid, *args, **kwargs):
-                self._started.append(tid)
+            def store_result(self, tid, meta, state):
+                if state == states.STARTED:
+                    self._started.append(tid)
 
         prev, mytask.backend = mytask.backend, Backend()
+        mytask.track_started = True
 
         try:
             tid = uuid()
@@ -469,7 +472,6 @@ class test_TaskRequest(unittest.TestCase):
         self.assertEqual(res, 4 ** 4)
 
     def test_execute_safe_catches_exception(self):
-        old_exec = mytask.__call__
         warnings.resetwarnings()
 
         def _error_exec(self, *args, **kwargs):

+ 2 - 0
celery/utils/coroutine.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from functools import wraps
 from Queue import Queue
 

+ 1 - 1
celery/worker/consumer.py

@@ -430,7 +430,7 @@ class Consumer(object):
             return
 
         try:
-            self.strategies[name].send(message, body, ack)
+            self.strategies[name](message, body, ack)
         except KeyError, exc:
             self.logger.error(UNKNOWN_TASK_ERROR, exc, safe_repr(body),
                               exc_info=sys.exc_info())

+ 20 - 13
celery/worker/job.py

@@ -19,9 +19,9 @@ import socket
 from datetime import datetime
 
 from .. import exceptions
-from .. import registry
+from ..registry import tasks
 from ..app import app_or_default
-from ..execute.trace import trace_task
+from ..execute.trace import build_tracer, trace_task, report_internal_error
 from ..platforms import set_mp_process_title as setps
 from ..utils import noop, kwdict, fun_takes_kwargs, truncate_text
 from ..utils.encoding import safe_repr, safe_str
@@ -34,7 +34,7 @@ from . import state
 WANTED_DELIVERY_INFO = ("exchange", "routing_key", "consumer_tag", )
 
 
-def execute_and_trace(task_name, *args, **kwargs):
+def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
     """This is a pickleable method used as a target when applying to pools.
 
     It's the same as::
@@ -42,12 +42,18 @@ def execute_and_trace(task_name, *args, **kwargs):
         >>> trace_task(task_name, *args, **kwargs)[0]
 
     """
-    hostname = kwargs.get("hostname")
-    setps("celeryd", task_name, hostname, rate_limit=True)
+    task = tasks[name]
     try:
-        return trace_task(task_name, *args, **kwargs)[0]
-    finally:
-        setps("celeryd", "-idle-", hostname, rate_limit=True)
+        hostname = opts.get("hostname")
+        setps("celeryd", name, hostname, rate_limit=True)
+        try:
+            if task.__tracer__ is None:
+                task.__tracer__ = build_tracer(name, task, **opts)
+            return task.__tracer__(uuid, args, kwargs, request)[0]
+        finally:
+            setps("celeryd", "-idle-", hostname, rate_limit=True)
+    except Exception, exc:
+        return report_internal_error(task, exc)
 
 
 class TaskRequest(object):
@@ -138,7 +144,7 @@ class TaskRequest(object):
         self.logger = logger or self.app.log.get_default_logger()
         self.eventer = eventer
 
-        self.task = registry.tasks[self.task_name]
+        self.task = tasks[self.task_name]
         self._store_errors = True
         if self.task.ignore_result:
             self._store_errors = self.task.store_errors_even_if_ignored
@@ -272,7 +278,7 @@ class TaskRequest(object):
             self.acknowledge()
 
         instance_attrs = self.get_instance_attrs(loglevel, logfile)
-        retval, _ = trace_task(*self._get_tracer_args(loglevel, logfile),
+        retval, _ = trace_task(*self._get_tracer_args(loglevel, logfile, True),
                                **{"hostname": self.hostname,
                                   "loader": self.app.loader,
                                   "request": instance_attrs})
@@ -406,7 +412,7 @@ class TaskRequest(object):
                                           "name": self.task_name,
                                           "hostname": self.hostname}})
 
-        task_obj = registry.tasks.get(self.task_name, object)
+        task_obj = tasks.get(self.task_name, object)
         task_obj.send_error_email(context, exc_info.exception)
 
     def acknowledge(self):
@@ -444,7 +450,8 @@ class TaskRequest(object):
                 self.__class__.__name__,
                 self.task_name, self.task_id, self.args, self.kwargs)
 
-    def _get_tracer_args(self, loglevel=None, logfile=None):
+    def _get_tracer_args(self, loglevel=None, logfile=None, use_real=False):
         """Get the task trace args for this task."""
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
-        return self.task_name, self.task_id, self.args, task_func_kwargs
+        first = self.task if use_real else self.task_name
+        return first, self.task_id, self.args, task_func_kwargs
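
execute_and_trace() is the picklable target sent to the pool, so the task.__tracer__ cache it fills lives inside each worker process: the first run of a task in a child process builds the tracer, and every later run of that task in the same process reuses it. The kind of per-call overhead this removes can be seen with a rough micro-benchmark (illustrative only; the ~1500 msg/s in the commit message is the author's own measurement and is not reproduced here):

    import timeit

    setup = """
    class T(object):
        def on_success(self, *args):
            pass
    t = T()
    bound = t.on_success   # what build_tracer() does once, up front
    """

    # attribute lookup plus method binding on every call...
    print(timeit.timeit("t.on_success(1)", setup=setup))
    # ...versus calling a name that was bound once when the tracer was built
    print(timeit.timeit("bound(1)", setup=setup))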

+ 11 - 15
celery/worker/strategy.py

@@ -1,21 +1,17 @@
-from .job import TaskRequest
+from __future__ import absolute_import
 
-from ..utils.coroutine import coroutine
+from .job import TaskRequest
 
 
 def default(task, app, consumer):
+    logger = consumer.logger
+    hostname = consumer.hostname
+    eventer = consumer.event_dispatcher
+    Request = TaskRequest.from_message
+    handle = consumer.on_task
 
-    @coroutine
-    def task_message_handler(self):
-        logger = consumer.logger
-        hostname = consumer.hostname
-        eventer = consumer.event_dispatcher
-        Request = TaskRequest.from_message
-        handle = consumer.on_task
-
-        while 1:
-            M, B, A = (yield)
-            handle(Request(M, B, A, app=app, logger=logger,
-                                    hostname=hostname, eventer=eventer))
+    def task_message_handler(M, B, A):
+        handle(Request(M, B, A, app=app, logger=logger,
+                                hostname=hostname, eventer=eventer))
 
-    return task_message_handler()
+    return task_message_handler
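
The strategy side of the same change: default() used to return a primed coroutine that the consumer fed with strategies[name].send(message, body, ack); it now returns a plain closure, matching the direct strategies[name](message, body, ack) call installed in celery/worker/consumer.py earlier in this diff, while still resolving the consumer's attributes only once, at setup time. A self-contained sketch of the new dispatch shape (hypothetical Consumer stand-in, not Celery's class):

    class Consumer(object):
        def __init__(self):
            self.strategies = {}

        def on_task(self, request):
            print("handling %r" % (request,))

        def add_strategy(self, name, factory):
            self.strategies[name] = factory(self)

        def receive(self, name, message, body, ack):
            self.strategies[name](message, body, ack)   # plain call, no .send()


    def default(consumer):
        handle = consumer.on_task                       # bound once, at setup

        def task_message_handler(message, body, ack):
            handle((message, body, ack))                # the real handler builds a TaskRequest here
        return task_message_handler


    c = Consumer()
    c.add_strategy("tasks.add", default)
    c.receive("tasks.add", "<raw message>", {"args": (2, 2)}, lambda: None)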