Browse Source

Pick stable optimizations from optimizations branch

Author: Ask Solem — 13 years ago
parent
commit 4b1f822774

+ 2 - 2
celery/app/task.py

@@ -24,7 +24,6 @@ from celery.__compat__ import class_property
 from celery.state import get_current_task
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
-from celery.local import LocalStack
 from celery.result import EagerResult
 from celery.utils import fun_takes_kwargs, uuid, maybe_reraise
 from celery.utils.functional import mattrgetter, maybe_list
@@ -176,7 +175,7 @@ class BaseTask(object):
 
     #: If disabled the worker will not forward magic keyword arguments.
     #: Deprecated and scheduled for removal in v3.0.
-    accept_magic_kwargs = False
+    accept_magic_kwargs = None
 
     #: Destination queue.  The queue needs to exist
     #: in :setting:`CELERY_QUEUES`.  The `routing_key`, `exchange` and
@@ -327,6 +326,7 @@ class BaseTask(object):
         if not was_bound:
             self.annotate()
 
+        from celery.utils.threads import LocalStack
         self.request_stack = LocalStack()
         self.request_stack.push(Context())
 

+ 6 - 1
celery/apps/worker.py

@@ -19,7 +19,6 @@ from celery.utils import cry, isatty
 from celery.utils.imports import qualname
 from celery.utils.log import LOG_LEVELS, get_logger, mlevel, set_in_sighandler
 from celery.utils.text import pluralize
-from celery.utils.threads import active_count as active_thread_count
 from celery.worker import WorkController
 
 try:
@@ -31,6 +30,12 @@ except ImportError:  # pragma: no cover
 logger = get_logger(__name__)
 
 
+def active_thread_count():
+    from threading import enumerate
+    return sum(1 for t in enumerate()
+        if not t.name.startswith("Dummy-"))
+
+
 def safe_say(msg):
     sys.__stderr__.write("\n%s\n" % msg)
 

+ 6 - 4
celery/concurrency/eventlet.py

@@ -106,6 +106,8 @@ class TaskPool(base.BasePool):
     def on_start(self):
         self._pool = self.Pool(self.limit)
         signals.eventlet_pool_started.send(sender=self)
+        self._quick_put = self._pool.spawn_n
+        self._quick_apply_sig = signals.eventlet_pool_apply.send
 
     def on_stop(self):
         signals.eventlet_pool_preshutdown.send(sender=self)
@@ -115,8 +117,8 @@ class TaskPool(base.BasePool):
 
     def on_apply(self, target, args=None, kwargs=None, callback=None,
             accept_callback=None, **_):
-        signals.eventlet_pool_apply.send(sender=self,
+        self._quick_apply_sig(sender=self,
                 target=target, args=args, kwargs=kwargs)
-        self._pool.spawn_n(apply_target, target, args, kwargs,
-                           callback, accept_callback,
-                           self.getpid)
+        self._quick_put(apply_target, target, args, kwargs,
+                        callback, accept_callback,
+                        self.getpid)

+ 3 - 2
celery/concurrency/gevent.py

@@ -87,6 +87,7 @@ class TaskPool(BasePool):
 
     def on_start(self):
         self._pool = self.Pool(self.limit)
+        self._quick_put = self._pool.spawn
 
     def on_stop(self):
         if self._pool is not None:
@@ -94,8 +95,8 @@ class TaskPool(BasePool):
 
     def on_apply(self, target, args=None, kwargs=None, callback=None,
             accept_callback=None, **_):
-        return self._pool.spawn(apply_target, target, args, kwargs,
-                                callback, accept_callback)
+        return self._quick_put(apply_target, target, args, kwargs,
+                               callback, accept_callback)
 
     def grow(self, n=1):
         self._pool._semaphore.counter += n

+ 4 - 0
celery/concurrency/processes/__init__.py

@@ -48,6 +48,10 @@ def process_initializer(app, hostname):
     app.loader.init_worker()
     app.loader.init_worker_process()
     app.finalize()
+
+    from celery.task.trace import build_tracer
+    for name, task in app.tasks.iteritems():
+        task.__tracer__ = build_tracer(name, task, app.loader, hostname)
     signals.worker_process_init.send(sender=None)
 
 

+ 4 - 2
celery/concurrency/threads.py

@@ -29,6 +29,8 @@ class TaskPool(BasePool):
         # threadpool stores all work requests until they are processed
         # we don't need this dict, and it occupies way too much memory.
         self._pool.workRequests = NullDict()
+        self._quick_put = self._pool.putRequest
+        self._quick_clear = self._pool._results_queue.queue.clear
 
     def on_stop(self):
         self._pool.dismissWorkers(self.limit, do_join=True)
@@ -37,10 +39,10 @@ class TaskPool(BasePool):
             accept_callback=None, **_):
         req = self.WorkRequest(apply_target, (target, args, kwargs, callback,
                                               accept_callback))
-        self._pool.putRequest(req)
+        self._quick_put(req)
         # threadpool also has callback support,
         # but for some reason the callback is not triggered
         # before you've collected the results.
         # Clear the results (if any), so it doesn't grow too large.
-        self._pool._results_queue.queue.clear()
+        self._quick_clear()
         return req

+ 3 - 5
celery/datastructures.py

@@ -386,12 +386,13 @@ class LimitedSet(object):
     :keyword expires: Time in seconds, before a membership expires.
 
     """
-    __slots__ = ("maxlen", "expires", "_data")
+    __slots__ = ("maxlen", "expires", "_data", "__len__")
 
     def __init__(self, maxlen=None, expires=None):
         self.maxlen = maxlen
         self.expires = expires
         self._data = {}
+        self.__len__ = self._data.__len__
 
     def add(self, value):
         """Add a new member."""
@@ -432,10 +433,7 @@ class LimitedSet(object):
         return self._data
 
     def __iter__(self):
-        return iter(self._data.keys())
-
-    def __len__(self):
-        return len(self._data.keys())
+        return iter(self._data)
 
     def __repr__(self):
         return "LimitedSet(%r)" % (self._data.keys(), )

+ 0 - 80
celery/local.py

@@ -277,86 +277,6 @@ class Local(object):
             raise AttributeError(name)
 
 
-class LocalStack(object):
-    """This class works similar to a :class:`Local` but keeps a stack
-    of objects instead.  This is best explained with an example::
-
-        >>> ls = LocalStack()
-        >>> ls.push(42)
-        >>> ls.top
-        42
-        >>> ls.push(23)
-        >>> ls.top
-        23
-        >>> ls.pop()
-        23
-        >>> ls.top
-        42
-
-    They can be force released by using a :class:`LocalManager` or with
-    the :func:`release_local` function but the correct way is to pop the
-    item from the stack after using.  When the stack is empty it will
-    no longer be bound to the current context (and as such released).
-
-    By calling the stack without arguments it returns a proxy that resolves to
-    the topmost item on the stack.
-
-    """
-
-    def __init__(self):
-        self._local = Local()
-
-    def __release_local__(self):
-        self._local.__release_local__()
-
-    def _get__ident_func__(self):
-        return self._local.__ident_func__
-
-    def _set__ident_func__(self, value):
-        object.__setattr__(self._local, '__ident_func__', value)
-    __ident_func__ = property(_get__ident_func__, _set__ident_func__)
-    del _get__ident_func__, _set__ident_func__
-
-    def __call__(self):
-        def _lookup():
-            rv = self.top
-            if rv is None:
-                raise RuntimeError('object unbound')
-            return rv
-        return Proxy(_lookup)
-
-    def push(self, obj):
-        """Pushes a new item to the stack"""
-        rv = getattr(self._local, 'stack', None)
-        if rv is None:
-            self._local.stack = rv = []
-        rv.append(obj)
-        return rv
-
-    def pop(self):
-        """Removes the topmost item from the stack, will return the
-        old value or `None` if the stack was already empty.
-        """
-        stack = getattr(self._local, 'stack', None)
-        if stack is None:
-            return None
-        elif len(stack) == 1:
-            release_local(self._local)
-            return stack[-1]
-        else:
-            return stack.pop()
-
-    @property
-    def top(self):
-        """The topmost item on the stack.  If the stack is empty,
-        `None` is returned.
-        """
-        try:
-            return self._local.stack[-1]
-        except (AttributeError, IndexError):
-            return None
-
-
 class LocalManager(object):
     """Local objects cannot manage themselves. For that you need a local
     manager.  You can pass a local manager multiple locals or add them later

+ 2 - 1
celery/state.py

@@ -2,7 +2,8 @@ from __future__ import absolute_import
 
 import threading
 
-from celery.local import Proxy, LocalStack
+from celery.local import Proxy
+from celery.utils.threads import LocalStack
 
 default_app = None
 

+ 8 - 2
celery/task/trace.py

@@ -185,6 +185,8 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
     request_stack = task.request_stack
     push_request = request_stack.push
     pop_request = request_stack.pop
+    push_task = _task_stack.push
+    pop_task = _task_stack.pop
     on_chord_part_return = backend.on_chord_part_return
 
     from celery import canvas
@@ -194,7 +196,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         R = I = None
         kwargs = kwdict(kwargs)
         try:
-            _task_stack.push(task)
+            push_task(task)
             task_request = Context(request or {}, args=args,
                                    called_directly=False, kwargs=kwargs)
             push_request(task_request)
@@ -258,7 +260,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                                  args=args, kwargs=kwargs,
                                  retval=retval, state=state)
             finally:
-                _task_stack.pop()
+                pop_task()
                 pop_request()
                 if not eager:
                     try:
@@ -287,6 +289,10 @@ def trace_task(task, uuid, args, kwargs, request=None, **opts):
         return report_internal_error(task, exc), None
 
 
+def trace_task_ret(task, uuid, args, kwargs, request):
+    task.__tracer__(uuid, args, kwargs, request)
+
+
 def eager_trace_task(task, uuid, args, kwargs, request=None, **opts):
     opts.setdefault("eager", True)
     return build_tracer(task.name, task, **opts)(

+ 15 - 0
celery/utils/threads.py

@@ -81,3 +81,18 @@ class bgThread(Thread):
         self._is_stopped.wait()
         if self.is_alive():
             self.join(1e100)
+
+
+class LocalStack(threading.local):
+
+    def __init__(self):
+        self.stack = []
+        self.push = self.stack.append
+        self.pop = self.stack.pop
+
+    @property
+    def top(self):
+        try:
+            return self.stack[-1]
+        except (AttributeError, IndexError):
+            return None

+ 1 - 1
celery/worker/__init__.py

@@ -353,7 +353,7 @@ class WorkController(configurated):
     def process_task(self, req):
         """Process task by sending it to the pool of workers."""
         try:
-            req.task.execute(req, self.pool, self.loglevel, self.logfile)
+            req.execute_using_pool(self.pool)
         except Exception, exc:
             logger.critical("Internal error: %r\n%s",
                             exc, traceback.format_exc(), exc_info=True)

+ 15 - 7
celery/worker/consumer.py

@@ -89,6 +89,7 @@ from kombu.utils.eventio import READ, WRITE, ERR
 from celery.app import app_or_default
 from celery.datastructures import AttributeDict
 from celery.exceptions import InvalidTaskError, SystemTerminate
+from celery.task.trace import build_tracer
 from celery.utils import timer2
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
@@ -145,6 +146,8 @@ Consumer: Connection to broker lost. \
 Trying to re-establish the connection...\
 """
 
+task_reserved = state.task_reserved
+
 logger = get_logger(__name__)
 info, warn, error, crit = (logger.info, logger.warn,
                            logger.error, logger.critical)
@@ -336,11 +339,16 @@ class Consumer(object):
         if hub:
             hub.on_init.append(self.on_poll_init)
         self.hub = hub
+        self._quick_put = self.ready_queue.put
 
     def update_strategies(self):
         S = self.strategies
-        for task in self.app.tasks.itervalues():
-            S[task.name] = task.start_strategy(self.app, self)
+        app = self.app
+        loader = app.loader
+        hostname = self.hostname
+        for name, task in self.app.tasks.iteritems():
+            S[name] = task.start_strategy(app, self)
+            task.__tracer__ = build_tracer(name, task, loader, hostname)
 
     def start(self):
         """Start the consumer.
@@ -456,7 +464,7 @@ class Consumer(object):
                 else:
                     sleep(min(time_to_sleep, 0.1))
 
-    def on_task(self, task):
+    def on_task(self, task, task_reserved=task_reserved):
         """Handle received task.
 
         If the task has an `eta` we enter it into the ETA schedule,
@@ -489,8 +497,8 @@ class Consumer(object):
                 self.timer.apply_at(eta, self.apply_eta_task, (task, ),
                                     priority=6)
         else:
-            state.task_reserved(task)
-            self.ready_queue.put(task)
+            task_reserved(task)
+            self._quick_put(task)
 
     def on_control(self, body, message):
         """Process remote control command message."""
@@ -505,8 +513,8 @@ class Consumer(object):
     def apply_eta_task(self, task):
         """Method called by the timer to apply a task with an
         ETA/countdown."""
-        state.task_reserved(task)
-        self.ready_queue.put(task)
+        task_reserved(task)
+        self._quick_put(task)
         self.qos.decrement_eventually()
 
     def _message_report(self, body, message):

+ 19 - 20
celery/worker/job.py

@@ -29,6 +29,7 @@ from celery.datastructures import ExceptionInfo
 from celery.task.trace import (
     build_tracer,
     trace_task,
+    trace_task_ret,
     report_internal_error,
     execute_bare,
 )
@@ -52,6 +53,10 @@ tz_to_local = timezone.to_local
 tz_or_local = timezone.tz_or_local
 tz_utc = timezone.utc
 
+task_accepted = state.task_accepted
+task_ready = state.task_ready
+revoked_tasks = state.revoked
+
 NEEDS_KWDICT = sys.version_info <= (2, 6)
 
 
@@ -190,20 +195,16 @@ class Request(object):
         kwargs.update(extend_with)
         return kwargs
 
-    def execute_using_pool(self, pool, loglevel=None, logfile=None):
+    def execute_using_pool(self, pool, **kwargs):
         """Like :meth:`execute`, but using a worker pool.
 
-        :param pool: A :class:`multiprocessing.Pool` instance.
-
-        :keyword loglevel: The loglevel used by the task.
-
-        :keyword logfile: The logfile used by the task.
+        :param pool: A :class:`celery.concurrency.base.TaskPool` instance.
 
         """
         task = self.task
         if self.flags & 0x004:
             return pool.apply_async(execute_bare,
-                    args=(self.task, self.id, self.args, self.kwargs),
+                    args=(task, self.id, self.args, self.kwargs),
                     accept_callback=self.on_accepted,
                     timeout_callback=self.on_timeout,
                     callback=self.on_success,
@@ -215,16 +216,14 @@ class Request(object):
 
         hostname = self.hostname
         kwargs = self.kwargs
-        if self.task.accept_magic_kwargs:
+        if task.accept_magic_kwargs:
             kwargs = self.extend_with_default_kwargs(loglevel, logfile)
         request = self.request_dict
-        request.update({"loglevel": loglevel, "logfile": logfile,
-                        "hostname": hostname, "is_eager": False,
+        request.update({"hostname": hostname, "is_eager": False,
                         "delivery_info": self.delivery_info})
-        result = pool.apply_async(execute_and_trace,
-                                  args=(self.name, self.id, self.args, kwargs),
-                                  kwargs={"hostname": hostname,
-                                          "request": request},
+        result = pool.apply_async(trace_task_ret,
+                                  args=(task, self.id,
+                                        self.args, kwargs, request),
                                   accept_callback=self.on_accepted,
                                   timeout_callback=self.on_timeout,
                                   callback=self.on_success,
@@ -264,7 +263,7 @@ class Request(object):
     def maybe_expire(self):
         """If expired, mark the task as revoked."""
         if self.expires and datetime.now(self.tzlocal) > self.expires:
-            state.revoked.add(self.id)
+            revoked_tasks.add(self.id)
             if self.store_errors:
                 self.task.backend.mark_as_revoked(self.id)
 
@@ -280,7 +279,7 @@ class Request(object):
             return True
         if self.expires:
             self.maybe_expire()
-        if self.id in state.revoked:
+        if self.id in revoked_tasks:
             warn("Skipping revoked task: %s[%s]", self.name, self.id)
             self.send_event("task-revoked", uuid=self.id)
             self.acknowledge()
@@ -296,7 +295,7 @@ class Request(object):
         """Handler called when task is accepted by worker pool."""
         self.worker_pid = pid
         self.time_start = time_accepted
-        state.task_accepted(self)
+        task_accepted(self)
         if not self.task.acks_late:
             self.acknowledge()
         self.send_event("task-started", uuid=self.id, pid=pid)
@@ -308,7 +307,7 @@ class Request(object):
 
     def on_timeout(self, soft, timeout):
         """Handler called if the task times out."""
-        state.task_ready(self)
+        task_ready(self)
         if soft:
             warn("Soft time limit (%ss) exceeded for %s[%s]",
                  timeout, self.name, self.id)
@@ -328,7 +327,7 @@ class Request(object):
                     SystemExit, KeyboardInterrupt)):
                 raise ret_value.exception
             return self.on_failure(ret_value)
-        state.task_ready(self)
+        task_ready(self)
 
         if self.task.acks_late:
             self.acknowledge()
@@ -360,7 +359,7 @@ class Request(object):
 
     def on_failure(self, exc_info):
         """Handler called if the task raised an exception."""
-        state.task_ready(self)
+        task_ready(self)
 
         if not exc_info.internal: