Browse Source

100% coverage for celery.worker.hub

Ask Solem 12 years ago
parent
commit
a686b662f9
6 changed files with 162 additions and 127 deletions
  1. celery/app/amqp.py (+2 −9)
  2. celery/app/task.py (+6 −7)
  3. celery/task/trace.py (+4 −28)
  4. celery/tests/worker/test_request.py (+34 −16)
  5. celery/worker/hub.py (+113 −30)
  6. celery/worker/job.py (+3 −37)

+ 2 - 9
celery/app/amqp.py

@@ -30,10 +30,6 @@ QUEUE_FORMAT = """
 . %(name)s exchange:%(exchange)s(%(exchange_type)s) binding:%(routing_key)s
 """
 
-TASK_BARE = 0x004
-TASK_DEFAULT = 0
-
-
 class Queues(dict):
     """Queue name⇒ declaration mapping.
 
@@ -154,7 +150,7 @@ class TaskProducer(Producer):
             queue=None, now=None, retries=0, chord=None, callbacks=None,
             errbacks=None, mandatory=None, priority=None, immediate=None,
             routing_key=None, serializer=None, delivery_mode=None,
-            compression=None, bare=False, **kwargs):
+            compression=None, **kwargs):
         """Send task message."""
         # merge default and custom policy
         _rp = (dict(self.retry_policy, **retry_policy) if retry_policy
@@ -174,8 +170,6 @@ class TaskProducer(Producer):
             expires = now + timedelta(seconds=expires)
         eta = eta and eta.isoformat()
         expires = expires and expires.isoformat()
-        flags = TASK_DEFAULT
-        flags |= TASK_BARE if bare else 0
 
         body = {"task": task_name,
                 "id": task_id,
@@ -186,8 +180,7 @@ class TaskProducer(Producer):
                 "expires": expires,
                 "utc": self.utc,
                 "callbacks": callbacks,
-                "errbacks": errbacks,
-                "flags": flags}
+                "errbacks": errbacks}
         if taskset_id:
             body["taskset"] = taskset_id
         if chord:

+ 6 - 7
celery/app/task.py

@@ -34,13 +34,12 @@ from celery.utils.mail import ErrorMail
 from .annotations import resolve_all as resolve_all_annotations
 from .registry import _unpickle_task
 
-#: extracts options related to publishing a message from a dict.
-extract_exec_options = mattrgetter("queue", "routing_key",
-                                   "exchange", "immediate",
-                                   "mandatory", "priority",
-                                   "serializer", "delivery_mode",
-                                   "compression", "expires", "bare")
-
+#: extracts attributes related to publishing a message from an object.
+extract_exec_options = mattrgetter(
+    "queue", "routing_key", "exchange",
+    "immediate", "mandatory", "priority", "expires",
+    "serializer", "delivery_mode", "compression",
+)
 
 #: Billiard sets this when execv is enabled.
 #: We use it to find out the name of the original ``__main__``

+ 4 - 28
celery/task/trace.py

@@ -130,30 +130,6 @@ class TraceInfo(object):
             del(tb)
 
 
-def execute_bare(task, uuid, args, kwargs, request=None, Info=TraceInfo):
-    R = I = None
-    kwargs = kwdict(kwargs)
-    try:
-        try:
-            R = retval = task(*args, **kwargs)
-            state = SUCCESS
-        except Exception, exc:
-            I = Info(FAILURE, exc)
-            state, retval = I.state, I.retval
-            R = I.handle_error_state(task)
-        except BaseException, exc:
-            raise
-        except:  # pragma: no cover
-            # For Python2.5 where raising strings are still allowed
-            # (but deprecated)
-            I = Info(FAILURE, None)
-            state, retval = I.state, I.retval
-            R = I.handle_error_state(task)
-    except Exception, exc:
-        R = report_internal_error(task, exc)
-    return R
-
-
 def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         Info=TraceInfo, eager=False, propagate=False):
     # If the task doesn't define a custom __call__ method
@@ -282,16 +258,16 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
     return trace_task
 
 
-def trace_task(task, uuid, args, kwargs, request=None, **opts):
+def trace_task(task, uuid, args, kwargs, request={}, **opts):
     try:
         if task.__trace__ is None:
             task.__trace__ = build_tracer(task.name, task, **opts)
-        return task.__trace__(uuid, args, kwargs, request)
+        return task.__trace__(uuid, args, kwargs, request)[0]
     except Exception, exc:
-        return report_internal_error(task, exc), None
+        return report_internal_error(task, exc)
 
 
-def trace_task_ret(task, uuid, args, kwargs, request):
+def trace_task_ret(task, uuid, args, kwargs, request={}):
     return _tasks[task].__trace__(uuid, args, kwargs, request)[0]
 
 

+ 34 - 16
celery/tests/worker/test_request.py

@@ -21,18 +21,23 @@ from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import (RetryTaskError,
                                WorkerLostError, InvalidTaskError)
-from celery.task.trace import eager_trace_task, TraceInfo, mro_lookup
+from celery.task.trace import (
+    trace_task,
+    trace_task_ret,
+    TraceInfo,
+    mro_lookup,
+    build_tracer,
+)
 from celery.result import AsyncResult
 from celery.task import task as task_dec
 from celery.task.base import Task
 from celery.utils import uuid
 from celery.worker import job as module
-from celery.worker.job import Request, TaskRequest, execute_and_trace
+from celery.worker.job import Request, TaskRequest
 from celery.worker.state import revoked
 
 from celery.tests.utils import Case
 
-
 scratch = {"ACK": False}
 some_kwargs_scratchpad = {}
 
@@ -68,8 +73,10 @@ class test_mro_lookup(Case):
 
 def jail(task_id, name, args, kwargs):
     request = {"id": task_id}
-    return eager_trace_task(current_app.tasks[name],
-            task_id, args, kwargs, request=request, eager=False)[0]
+    task = current_app.tasks[name]
+    task.__trace__ = None  # rebuild
+    return trace_task(task,
+            task_id, args, kwargs, request=request, eager=False)
 
 
 def on_ack(*args, **kwargs):
@@ -221,6 +228,7 @@ class MockEventDispatcher(object):
 
 class test_TaskRequest(Case):
 
+
     def test_task_wrapper_repr(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
         self.assertTrue(repr(tw))
@@ -262,8 +270,11 @@ class test_TaskRequest(Case):
             einfo = ExceptionInfo()
             tw.on_failure(einfo)
             self.assertIn("task-retried", tw.eventer.sent)
-            tw._does_info = False
-            tw.on_failure(einfo)
+            prev, module._does_info = module._does_info, False
+            try:
+                tw.on_failure(einfo)
+            finally:
+                module._does_info = prev
             einfo.internal = True
             tw.on_failure(einfo)
 
@@ -408,8 +419,11 @@ class test_TaskRequest(Case):
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
         tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
         self.assertTrue(tw.acknowledged)
-        tw._does_debug = False
-        tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
+        prev, module._does_debug = module._does_debug, False
+        try:
+            tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
+        finally:
+            module._does_debug = prev
 
     def test_on_accepted_acks_late(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
@@ -432,9 +446,12 @@ class test_TaskRequest(Case):
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
         tw.time_start = 1
         tw.on_success(42)
-        tw._does_info = False
-        tw.on_success(42)
-        self.assertFalse(tw.acknowledged)
+        prev, module._does_info = module._does_info, False
+        try:
+            tw.on_success(42)
+            self.assertFalse(tw.acknowledged)
+        finally:
+            module._does_info = prev
 
     def test_on_success_BaseException(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
@@ -539,8 +556,10 @@ class test_TaskRequest(Case):
         finally:
             mytask.ignore_result = False
 
-    def test_execute_and_trace(self):
-        res = execute_and_trace(mytask.name, uuid(), [4], {})
+    def test_trace_task_ret(self):
+        mytask.__trace__ = build_tracer(mytask.name, mytask,
+                                        current_app.loader, "test")
+        res = trace_task_ret(mytask.name, uuid(), [4], {})
         self.assertEqual(res, 4 ** 4)
 
     def test_execute_safe_catches_exception(self):
@@ -554,8 +573,7 @@ class test_TaskRequest(Case):
 
         with self.assertWarnsRegex(RuntimeWarning,
                 r'Exception raised outside'):
-            res = execute_and_trace(raising.name, uuid(),
-                                    [], {})
+            res = trace_task(raising, uuid(), [], {})
             self.assertIsInstance(res, ExceptionInfo)
 
     def test_worker_task_trace_handle_retry(self):

+ 113 - 30
celery/worker/hub.py

@@ -1,76 +1,136 @@
 from __future__ import absolute_import
 
 from kombu.utils import cached_property
-from kombu.utils.eventio import poll, READ, WRITE, ERR
+from kombu.utils import eventio
 
 from celery.utils.timer2 import Schedule
 
+READ, WRITE, ERR = eventio.READ, eventio.WRITE, eventio.ERR
 
-class DummyLock(object):
 
-    def __enter__(self):
-        return self
+class BoundedSemaphore(object):
+    """Asynchronous Bounded Semaphore.
 
-    def __exit__(self, *exc_info):
-        pass
+    Bounded means that the value will stay within the specified
+    range even if it is released more times than it was acquired.
 
+    This type is *not thread safe*.
 
-class BoundedSemaphore(object):
+    Example:
+
+        >>> x = BoundedSemaphore(2)
+
+        >>> def callback(i):
+        ...     print("HELLO %r" % i)
+
+        >>> x.acquire(callback, 1)
+        HELLO 1
+
+        >>> x.acquire(callback, 2)
+        HELLO 2
+
+        >>> x.acquire(callback, 3)
+        >>> x._waiters   # private, do not access directly
+        [(callback, 3)]
+
+        >>> x.release()
+        HELLO 3
 
-    def __init__(self, value=1):
+    """
+
+    def __init__(self, value):
         self.initial_value = self.value = value
         self._waiting = []
 
-    def grow(self):
-        self.initial_value += 1
-        self.release()
+    def acquire(self, callback, *partial_args):
+        """Acquire semaphore, applying ``callback`` when
+        the semaphore is ready.
 
-    def shrink(self):
-        self.initial_value -= 1
+        :param callback: The callback to apply.
+        :param *partial_args: partial arguments to callback.
 
-    def acquire(self, callback, *partial_args, **partial_kwargs):
+        """
         if self.value <= 0:
             self._waiting.append((callback, partial_args))
             return False
         else:
             self.value = max(self.value - 1, 0)
-            callback(*partial_args, **partial_kwargs)
+            callback(*partial_args)
             return True
 
     def release(self):
+        """Release semaphore.
+
+        This will apply any waiting callbacks from previous
+        calls to :meth:`acquire` done when the semaphore was busy.
+
+        """
         self.value = min(self.value + 1, self.initial_value)
         if self._waiting:
             waiter, args = self._waiting.pop()
             waiter(*args)
 
+    def grow(self, n=1):
+        """Change the size of the semaphore to hold more values."""
+        self.initial_value += n
+        self.value += n
+        [self.release() for _ in xrange(n)]
+
+    def shrink(self, n=1):
+        """Change the size of the semaphore to hold fewer values."""
+        self.initial_value = max(self.initial_value - n, 0)
+        self.value = max(self.value - n, 0)
+
     def clear(self):
-        pass
+        """Reset the semaphore, including wiping out any waiting callbacks."""
+        self._waiting[:] = []
+        self.value = self.initial_value
 
 
 class Hub(object):
-    READ, WRITE, ERR = READ, WRITE, ERR
+    """Event loop object.
+
+    :keyword timer: Specify custom :class:`~celery.utils.timer2.Schedule`.
+
+    """
+    #: Flag set if reading from an fd will not block.
+    READ = READ
+
+    #: Flag set if writing to an fd will not block.
+    WRITE = WRITE
+
+    #: Flag set on error, and the fd should be read from asap.
+    ERR = ERR
+
+    #: List of callbacks to be called when the loop is initialized,
+    #: applied with the hub instance as sole argument.
+    on_init = None
+
+    #: List of callbacks to be called when the loop is exiting,
+    #: applied with the hub instance as sole argument.
+    on_close = None
+
+    #: List of callbacks to be called when a task is received.
+    #: Takes no arguments.
+    on_task = None
 
     def __init__(self, timer=None):
+        self.timer = Schedule() if timer is None else timer
+
         self.readers = {}
         self.writers = {}
-        self.timer = Schedule() if timer is None else timer
         self.on_init = []
         self.on_close = []
         self.on_task = []
 
     def start(self):
-        self.poller = poll()
+        """Called by StartStopComponent at worker startup."""
+        self.poller = eventio.poll()
 
     def stop(self):
+        """Called by StartStopComponent at worker shutdown."""
         self.poller.close()
 
-    def __enter__(self):
-        self.init()
-        return self
-
-    def __exit__(self, *exc_info):
-        return self.close()
-
     def init(self):
         for callback in self.on_init:
             callback(self)
@@ -106,18 +166,41 @@ class Hub(object):
     def update_writers(self, map):
         [self.add_writer(*x) for x in map.iteritems()]
 
-    def remove(self, fd):
+    def _unregister(self, fd):
         try:
             self.poller.unregister(fd)
         except (KeyError, OSError):
             pass
 
-    def close(self):
-        [self.remove(fd) for fd in self.readers.keys()]
-        [self.remove(fd) for fd in self.writers.keys()]
+    def remove(self, fd):
+        fileno = fd.fileno() if not isinstance(fd, int) else fd
+        self.readers.pop(fileno, None)
+        self.writers.pop(fileno, None)
+        self._unregister(fd)
+
+    def __enter__(self):
+        self.init()
+        return self
+
+    def close(self, *args):
+        [self._unregister(fd) for fd in self.readers]
+        self.readers.clear()
+        [self._unregister(fd) for fd in self.writers]
+        self.writers.clear()
         for callback in self.on_close:
             callback(self)
+    __exit__ = close
 
     @cached_property
     def scheduler(self):
         return iter(self.timer)
+
+
+class DummyLock(object):
+    """Pretending to be a lock."""
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *exc_info):
+        pass

+ 3 - 37
celery/worker/job.py

@@ -31,7 +31,6 @@ from celery.task.trace import (
     trace_task,
     trace_task_ret,
     report_internal_error,
-    execute_bare,
 )
 from celery.platforms import set_mp_process_title as setps
 from celery.utils import fun_takes_kwargs
@@ -60,35 +59,13 @@ revoked_tasks = state.revoked
 NEEDS_KWDICT = sys.version_info <= (2, 6)
 
 
-def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
-    """This is a pickleable method used as a target when applying to pools.
-
-    It's the same as::
-
-        >>> trace_task(name, *args, **kwargs)[0]
-
-    """
-    task = current_app.tasks[name]
-    try:
-        hostname = opts.get("hostname")
-        setps("celeryd", name, hostname, rate_limit=True)
-        try:
-            if task.__trace__ is None:
-                task.__trace__ = build_tracer(name, task, **opts)
-            return task.__trace__(uuid, args, kwargs, request)[0]
-        finally:
-            setps("celeryd", "-idle-", hostname, rate_limit=True)
-    except Exception, exc:
-        return report_internal_error(task, exc)
-
-
 class Request(object):
     """A request for task execution."""
     __slots__ = ("app", "name", "id", "args", "kwargs",
                  "on_ack", "delivery_info", "hostname",
                  "callbacks", "errbacks",
                  "eventer", "connection_errors",
-                 "task", "eta", "expires", "flags",
+                 "task", "eta", "expires",
                  "request_dict", "acknowledged", "success_msg",
                  "error_msg", "retry_msg", "time_start", "worker_pid",
                  "_already_revoked", "_terminate_on_ack", "_tzlocal")
@@ -130,7 +107,6 @@ class Request(object):
         eta = body.get("eta")
         expires = body.get("expires")
         utc = body.get("utc", False)
-        self.flags = body.get("flags", False)
         self.on_ack = on_ack
         self.hostname = hostname or socket.gethostname()
         self.eventer = eventer
@@ -202,15 +178,6 @@ class Request(object):
 
         """
         task = self.task
-        if self.flags & 0x004:
-            return pool.apply_async(execute_bare,
-                    args=(task, self.id, self.args, self.kwargs),
-                    accept_callback=self.on_accepted,
-                    timeout_callback=self.on_timeout,
-                    callback=self.on_success,
-                    error_callback=self.on_failure,
-                    soft_timeout=task.soft_time_limit,
-                    timeout=task.time_limit)
         if self.revoked():
             return
 
@@ -253,10 +220,9 @@ class Request(object):
         request.update({"loglevel": loglevel, "logfile": logfile,
                         "hostname": self.hostname, "is_eager": False,
                         "delivery_info": self.delivery_info})
-        retval, _ = trace_task(self.task, self.id, self.args, kwargs,
+        retval = trace_task(self.task, self.id, self.args, kwargs, request,
                                **{"hostname": self.hostname,
-                                  "loader": self.app.loader,
-                                  "request": request})
+                                  "loader": self.app.loader})
         self.acknowledge()
         return retval