Bladeren bron

Tests passing

Ask Solem 13 jaren geleden
bovenliggende
commit
0ea86986dc

+ 2 - 41
celery/concurrency/base.py

@@ -7,8 +7,6 @@ import sys
 import time
 import traceback
 
-from functools import partial
-
 from .. import log
 from ..datastructures import ExceptionInfo
 from ..utils import timer2
@@ -76,57 +74,20 @@ class BasePool(object):
         self.on_start()
         self._state = self.RUN
 
-    def apply_async(self, target, args=None, kwargs=None, callback=None,
-            errback=None, accept_callback=None, timeout_callback=None,
-            soft_timeout=None, timeout=None, **compat):
+    def apply_async(self, target, args=[], kwargs={}, **options):
         """Equivalent of the :func:`apply` built-in function.
 
         Callbacks should optimally return as soon as possible since
         otherwise the thread which handles the result will get blocked.
 
         """
-        args = args or []
-        kwargs = kwargs or {}
-
-        on_ready = partial(self.on_ready, callback, errback)
-        on_worker_error = partial(self.on_worker_error, errback)
-
         if self._does_debug:
             self.logger.debug("TaskPool: Apply %s (args:%s kwargs:%s)",
                             target, safe_repr(args), safe_repr(kwargs))
 
         return self.on_apply(target, args, kwargs,
-                             callback=on_ready,
-                             accept_callback=accept_callback,
-                             timeout_callback=timeout_callback,
-                             error_callback=on_worker_error,
                              waitforslot=self.putlocks,
-                             soft_timeout=soft_timeout,
-                             timeout=timeout)
-
-    def on_ready(self, callback, errback, ret_value):
-        """What to do when a worker task is ready and its return value has
-        been collected."""
-
-        if isinstance(ret_value, ExceptionInfo):
-            if isinstance(ret_value.exception, (
-                    SystemExit, KeyboardInterrupt)):
-                raise ret_value.exception
-            self.safe_apply_callback(errback, ret_value)
-        else:
-            self.safe_apply_callback(callback, ret_value)
-
-    def on_worker_error(self, errback, exc_info):
-        errback(exc_info)
-
-    def safe_apply_callback(self, fun, *args):
-        if fun:
-            try:
-                fun(*args)
-            except BaseException:
-                self.logger.error("Pool callback raised exception: %s",
-                                  traceback.format_exc(),
-                                  exc_info=sys.exc_info())
+                             **options)
 
     def _get_info(self):
         return {}

+ 22 - 9
celery/concurrency/processes/pool.py

@@ -12,17 +12,18 @@ from __future__ import absolute_import
 # Imports
 #
 
+import collections
+import errno
+import itertools
+import logging
 import os
+import signal
 import sys
-import errno
 import threading
-import Queue
-import itertools
-import collections
 import time
-import signal
+import traceback
+import Queue
 import warnings
-import logging
 
 from multiprocessing import Process, cpu_count, TimeoutError, Event
 from multiprocessing import util
@@ -68,6 +69,15 @@ def error(msg, *args, **kwargs):
         util._logger.error(msg, *args, **kwargs)
 
 
+def safe_apply_callback(fun, *args):
+    if fun:
+        try:
+            fun(*args)
+        except BaseException, exc:
+            error("Pool callback raised exception: %r", exc,
+                  exc_info=sys.exc_info())
+
+
 class LaxBoundedSemaphore(threading._Semaphore):
     """Semaphore that checks that # release is <= # acquires,
     but ignores if # releases >= value."""
@@ -1020,9 +1030,11 @@ class ApplyResult(object):
 
             # apply callbacks last
             if self._callback and self._success:
-                self._callback(self._value)
+                safe_apply_callback(
+                    self._callback, self._value)
             if self._errback and not self._success:
-                self._errback(self._value)
+                safe_apply_callback(
+                    self._errback, self._value)
         finally:
             self._mutex.release()
 
@@ -1035,7 +1047,8 @@ class ApplyResult(object):
             if self._ready:
                 self._cache.pop(self._job, None)
             if self._accept_callback:
-                self._accept_callback(pid, time_accepted)
+                safe_apply_callback(
+                    self._accept_callback, pid, time_accepted)
         finally:
             self._mutex.release()
 

+ 20 - 51
celery/tests/test_concurrency/test_concurrency_processes.py

@@ -7,19 +7,33 @@ import time
 
 from itertools import cycle
 
-from mock import patch
+from mock import Mock, patch
 from nose import SkipTest
 
 try:
     from celery.concurrency import processes as mp
+    from celery.concurrency.processes.pool import safe_apply_callback
 except ImportError:
 
     class _mp(object):
         RUN = 0x1
 
         class TaskPool(object):
-            pass
+            _pool = Mock()
+
+            def __init__(self, *args, **kwargs):
+                pass
+
+            def start(self):
+                pass
+
+            def stop(self):
+                pass
+
+            def apply_async(self, *args, **kwargs):
+                pass
     mp = _mp()  # noqa
+    safe_apply_callback = None
 
 from celery.datastructures import ExceptionInfo
 from celery.utils import noop
@@ -128,36 +142,9 @@ class test_TaskPool(unittest.TestCase):
         pool.terminate()
         self.assertTrue(_pool.terminated)
 
-    def test_on_worker_error(self):
-        scratch = [None]
-
-        def errback(einfo):
-            scratch[0] = einfo
-
-        pool = TaskPool(10)
-        exc_info = None
-        try:
-            raise KeyError("foo")
-        except KeyError:
-            exc_info = ExceptionInfo(sys.exc_info())
-        pool.on_worker_error(errback, exc_info)
-
-        self.assertTrue(scratch[0])
-        self.assertIs(scratch[0], exc_info)
-
-    def test_on_ready_exception(self):
-        scratch = [None]
-
-        def errback(retval):
-            scratch[0] = retval
-
-        pool = TaskPool(10)
-        exc = to_excinfo(KeyError("foo"))
-        pool.on_ready(None, errback, exc)
-        self.assertEqual(exc, scratch[0])
-
     def test_safe_apply_callback(self):
-
+        if safe_apply_callback is None:
+            raise SkipTest("multiprocessig not supported")
         _good_called = [0]
         _evil_called = [0]
 
@@ -169,29 +156,11 @@ class test_TaskPool(unittest.TestCase):
             _evil_called[0] = 1
             raise KeyError(x)
 
-        pool = TaskPool(10)
-        self.assertIsNone(pool.safe_apply_callback(good, 10))
-        self.assertIsNone(pool.safe_apply_callback(evil, 10))
+        self.assertIsNone(safe_apply_callback(good, 10))
+        self.assertIsNone(safe_apply_callback(evil, 10))
         self.assertTrue(_good_called[0])
         self.assertTrue(_evil_called[0])
 
-    def test_on_ready_value(self):
-        scratch = [None]
-
-        def callback(retval):
-            scratch[0] = retval
-
-        pool = TaskPool(10)
-        retval = "the quick brown fox"
-        pool.on_ready(callback, None, retval)
-        self.assertEqual(retval, scratch[0])
-
-    def test_on_ready_exit_exception(self):
-        pool = TaskPool(10)
-        exc = to_excinfo(SystemExit("foo"))
-        with self.assertRaises(SystemExit):
-            pool.on_ready([], [], exc)
-
     def test_apply_async(self):
         pool = TaskPool(10)
         pool.start()

+ 1 - 0
celery/tests/test_worker/test_worker_job.py

@@ -198,6 +198,7 @@ class MockEventDispatcher(object):
 
     def __init__(self):
         self.sent = []
+        self.enabled = True
 
     def send(self, event, **fields):
         self.sent.append(event)

+ 20 - 10
celery/worker/job.py

@@ -19,6 +19,7 @@ import socket
 from datetime import datetime
 
 from .. import exceptions
+from ..datastructures import ExceptionInfo
 from ..registry import tasks
 from ..app import app_or_default
 from ..execute.trace import build_tracer, trace_task, report_internal_error
@@ -258,7 +259,7 @@ class TaskRequest(object):
                                   accept_callback=self.on_accepted,
                                   timeout_callback=self.on_timeout,
                                   callback=self.on_success,
-                                  errback=self.on_failure,
+                                  error_callback=self.on_failure,
                                   soft_timeout=self.task.soft_time_limit,
                                   timeout=self.task.time_limit)
         return result
@@ -315,7 +316,7 @@ class TaskRequest(object):
         return False
 
     def send_event(self, type, **fields):
-        if self.eventer:
+        if self.eventer and self.eventer.enabled:
             self.eventer.send(type, **fields)
 
     def on_accepted(self, pid, time_accepted):
@@ -348,23 +349,32 @@ class TaskRequest(object):
         if self.store_errors:
             self.task.backend.mark_as_failure(self.id, exc)
 
-    def on_success(self, ret_value):
+    def on_success(self, ret_value, now=None):
         """Handler called if the task was successfully processed."""
+        if isinstance(ret_value, ExceptionInfo):
+            if isinstance(ret_value.exception, (
+                    SystemExit, KeyboardInterrupt)):
+                raise ret_value.exception
+            return self.on_failure(ret_value)
         state.task_ready(self)
 
         if self.task.acks_late:
             self.acknowledge()
 
-        runtime = self.time_start and (time.time() - self.time_start) or 0
-        self.send_event("task-succeeded", uuid=self.id,
-                        result=safe_repr(ret_value), runtime=runtime)
+        if self.eventer and self.eventer.enabled:
+            now = time.time()
+            runtime = self.time_start and (time.time() - self.time_start) or 0
+            self.send_event("task-succeeded", uuid=self.id,
+                            result=safe_repr(ret_value), runtime=runtime)
 
         if self._does_info:
+            now = now or time.time()
+            runtime = self.time_start and (time.time() - self.time_start) or 0
             self.logger.info(self.success_msg.strip(),
-                            {"id": self.id,
-                             "name": self.name,
-                             "return_value": self.repr_result(ret_value),
-                             "runtime": runtime})
+                        {"id": self.id,
+                        "name": self.name,
+                        "return_value": self.repr_result(ret_value),
+                        "runtime": runtime})
 
     def on_retry(self, exc_info):
         """Handler called if the task should be retried."""

+ 2 - 2
funtests/benchmarks/bench_worker.py

@@ -27,8 +27,8 @@ celery.conf.update(BROKER_TRANSPORT="librabbitmq",
                            "exchange": "bench.worker",
                            "routing_key": "bench.worker",
                            "no_ack": True,
-                           "exchange_durable": False,
-                           "queue_durable": False,
+                           #"exchange_durable": False,
+                           #"queue_durable": False,
                         }
                    },
                    CELERY_TASK_SERIALIZER="json",