Jelajahi Sumber

Tests passing

Ask Solem 13 tahun lalu
induk
melakukan
0ea86986dc

+ 2 - 41
celery/concurrency/base.py

@@ -7,8 +7,6 @@ import sys
 import time
 import time
 import traceback
 import traceback
 
 
-from functools import partial
-
 from .. import log
 from .. import log
 from ..datastructures import ExceptionInfo
 from ..datastructures import ExceptionInfo
 from ..utils import timer2
 from ..utils import timer2
@@ -76,57 +74,20 @@ class BasePool(object):
         self.on_start()
         self.on_start()
         self._state = self.RUN
         self._state = self.RUN
 
 
-    def apply_async(self, target, args=None, kwargs=None, callback=None,
-            errback=None, accept_callback=None, timeout_callback=None,
-            soft_timeout=None, timeout=None, **compat):
+    def apply_async(self, target, args=[], kwargs={}, **options):
         """Equivalent of the :func:`apply` built-in function.
         """Equivalent of the :func:`apply` built-in function.
 
 
         Callbacks should optimally return as soon as possible since
         Callbacks should optimally return as soon as possible since
         otherwise the thread which handles the result will get blocked.
         otherwise the thread which handles the result will get blocked.
 
 
         """
         """
-        args = args or []
-        kwargs = kwargs or {}
-
-        on_ready = partial(self.on_ready, callback, errback)
-        on_worker_error = partial(self.on_worker_error, errback)
-
         if self._does_debug:
         if self._does_debug:
             self.logger.debug("TaskPool: Apply %s (args:%s kwargs:%s)",
             self.logger.debug("TaskPool: Apply %s (args:%s kwargs:%s)",
                             target, safe_repr(args), safe_repr(kwargs))
                             target, safe_repr(args), safe_repr(kwargs))
 
 
         return self.on_apply(target, args, kwargs,
         return self.on_apply(target, args, kwargs,
-                             callback=on_ready,
-                             accept_callback=accept_callback,
-                             timeout_callback=timeout_callback,
-                             error_callback=on_worker_error,
                              waitforslot=self.putlocks,
                              waitforslot=self.putlocks,
-                             soft_timeout=soft_timeout,
-                             timeout=timeout)
-
-    def on_ready(self, callback, errback, ret_value):
-        """What to do when a worker task is ready and its return value has
-        been collected."""
-
-        if isinstance(ret_value, ExceptionInfo):
-            if isinstance(ret_value.exception, (
-                    SystemExit, KeyboardInterrupt)):
-                raise ret_value.exception
-            self.safe_apply_callback(errback, ret_value)
-        else:
-            self.safe_apply_callback(callback, ret_value)
-
-    def on_worker_error(self, errback, exc_info):
-        errback(exc_info)
-
-    def safe_apply_callback(self, fun, *args):
-        if fun:
-            try:
-                fun(*args)
-            except BaseException:
-                self.logger.error("Pool callback raised exception: %s",
-                                  traceback.format_exc(),
-                                  exc_info=sys.exc_info())
+                             **options)
 
 
     def _get_info(self):
     def _get_info(self):
         return {}
         return {}

+ 22 - 9
celery/concurrency/processes/pool.py

@@ -12,17 +12,18 @@ from __future__ import absolute_import
 # Imports
 # Imports
 #
 #
 
 
+import collections
+import errno
+import itertools
+import logging
 import os
 import os
+import signal
 import sys
 import sys
-import errno
 import threading
 import threading
-import Queue
-import itertools
-import collections
 import time
 import time
-import signal
+import traceback
+import Queue
 import warnings
 import warnings
-import logging
 
 
 from multiprocessing import Process, cpu_count, TimeoutError, Event
 from multiprocessing import Process, cpu_count, TimeoutError, Event
 from multiprocessing import util
 from multiprocessing import util
@@ -68,6 +69,15 @@ def error(msg, *args, **kwargs):
         util._logger.error(msg, *args, **kwargs)
         util._logger.error(msg, *args, **kwargs)
 
 
 
 
+def safe_apply_callback(fun, *args):
+    if fun:
+        try:
+            fun(*args)
+        except BaseException, exc:
+            error("Pool callback raised exception: %r", exc,
+                  exc_info=sys.exc_info())
+
+
 class LaxBoundedSemaphore(threading._Semaphore):
 class LaxBoundedSemaphore(threading._Semaphore):
     """Semaphore that checks that # release is <= # acquires,
     """Semaphore that checks that # release is <= # acquires,
     but ignores if # releases >= value."""
     but ignores if # releases >= value."""
@@ -1020,9 +1030,11 @@ class ApplyResult(object):
 
 
             # apply callbacks last
             # apply callbacks last
             if self._callback and self._success:
             if self._callback and self._success:
-                self._callback(self._value)
+                safe_apply_callback(
+                    self._callback, self._value)
             if self._errback and not self._success:
             if self._errback and not self._success:
-                self._errback(self._value)
+                safe_apply_callback(
+                    self._errback, self._value)
         finally:
         finally:
             self._mutex.release()
             self._mutex.release()
 
 
@@ -1035,7 +1047,8 @@ class ApplyResult(object):
             if self._ready:
             if self._ready:
                 self._cache.pop(self._job, None)
                 self._cache.pop(self._job, None)
             if self._accept_callback:
             if self._accept_callback:
-                self._accept_callback(pid, time_accepted)
+                safe_apply_callback(
+                    self._accept_callback, pid, time_accepted)
         finally:
         finally:
             self._mutex.release()
             self._mutex.release()
 
 

+ 20 - 51
celery/tests/test_concurrency/test_concurrency_processes.py

@@ -7,19 +7,33 @@ import time
 
 
 from itertools import cycle
 from itertools import cycle
 
 
-from mock import patch
+from mock import Mock, patch
 from nose import SkipTest
 from nose import SkipTest
 
 
 try:
 try:
     from celery.concurrency import processes as mp
     from celery.concurrency import processes as mp
+    from celery.concurrency.processes.pool import safe_apply_callback
 except ImportError:
 except ImportError:
 
 
     class _mp(object):
     class _mp(object):
         RUN = 0x1
         RUN = 0x1
 
 
         class TaskPool(object):
         class TaskPool(object):
-            pass
+            _pool = Mock()
+
+            def __init__(self, *args, **kwargs):
+                pass
+
+            def start(self):
+                pass
+
+            def stop(self):
+                pass
+
+            def apply_async(self, *args, **kwargs):
+                pass
     mp = _mp()  # noqa
     mp = _mp()  # noqa
+    safe_apply_callback = None
 
 
 from celery.datastructures import ExceptionInfo
 from celery.datastructures import ExceptionInfo
 from celery.utils import noop
 from celery.utils import noop
@@ -128,36 +142,9 @@ class test_TaskPool(unittest.TestCase):
         pool.terminate()
         pool.terminate()
         self.assertTrue(_pool.terminated)
         self.assertTrue(_pool.terminated)
 
 
-    def test_on_worker_error(self):
-        scratch = [None]
-
-        def errback(einfo):
-            scratch[0] = einfo
-
-        pool = TaskPool(10)
-        exc_info = None
-        try:
-            raise KeyError("foo")
-        except KeyError:
-            exc_info = ExceptionInfo(sys.exc_info())
-        pool.on_worker_error(errback, exc_info)
-
-        self.assertTrue(scratch[0])
-        self.assertIs(scratch[0], exc_info)
-
-    def test_on_ready_exception(self):
-        scratch = [None]
-
-        def errback(retval):
-            scratch[0] = retval
-
-        pool = TaskPool(10)
-        exc = to_excinfo(KeyError("foo"))
-        pool.on_ready(None, errback, exc)
-        self.assertEqual(exc, scratch[0])
-
     def test_safe_apply_callback(self):
     def test_safe_apply_callback(self):
-
+        if safe_apply_callback is None:
+            raise SkipTest("multiprocessing not supported")
         _good_called = [0]
         _good_called = [0]
         _evil_called = [0]
         _evil_called = [0]
 
 
@@ -169,29 +156,11 @@ class test_TaskPool(unittest.TestCase):
             _evil_called[0] = 1
             _evil_called[0] = 1
             raise KeyError(x)
             raise KeyError(x)
 
 
-        pool = TaskPool(10)
-        self.assertIsNone(pool.safe_apply_callback(good, 10))
-        self.assertIsNone(pool.safe_apply_callback(evil, 10))
+        self.assertIsNone(safe_apply_callback(good, 10))
+        self.assertIsNone(safe_apply_callback(evil, 10))
         self.assertTrue(_good_called[0])
         self.assertTrue(_good_called[0])
         self.assertTrue(_evil_called[0])
         self.assertTrue(_evil_called[0])
 
 
-    def test_on_ready_value(self):
-        scratch = [None]
-
-        def callback(retval):
-            scratch[0] = retval
-
-        pool = TaskPool(10)
-        retval = "the quick brown fox"
-        pool.on_ready(callback, None, retval)
-        self.assertEqual(retval, scratch[0])
-
-    def test_on_ready_exit_exception(self):
-        pool = TaskPool(10)
-        exc = to_excinfo(SystemExit("foo"))
-        with self.assertRaises(SystemExit):
-            pool.on_ready([], [], exc)
-
     def test_apply_async(self):
     def test_apply_async(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
         pool.start()
         pool.start()

+ 1 - 0
celery/tests/test_worker/test_worker_job.py

@@ -198,6 +198,7 @@ class MockEventDispatcher(object):
 
 
     def __init__(self):
     def __init__(self):
         self.sent = []
         self.sent = []
+        self.enabled = True
 
 
     def send(self, event, **fields):
     def send(self, event, **fields):
         self.sent.append(event)
         self.sent.append(event)

+ 20 - 10
celery/worker/job.py

@@ -19,6 +19,7 @@ import socket
 from datetime import datetime
 from datetime import datetime
 
 
 from .. import exceptions
 from .. import exceptions
+from ..datastructures import ExceptionInfo
 from ..registry import tasks
 from ..registry import tasks
 from ..app import app_or_default
 from ..app import app_or_default
 from ..execute.trace import build_tracer, trace_task, report_internal_error
 from ..execute.trace import build_tracer, trace_task, report_internal_error
@@ -258,7 +259,7 @@ class TaskRequest(object):
                                   accept_callback=self.on_accepted,
                                   accept_callback=self.on_accepted,
                                   timeout_callback=self.on_timeout,
                                   timeout_callback=self.on_timeout,
                                   callback=self.on_success,
                                   callback=self.on_success,
-                                  errback=self.on_failure,
+                                  error_callback=self.on_failure,
                                   soft_timeout=self.task.soft_time_limit,
                                   soft_timeout=self.task.soft_time_limit,
                                   timeout=self.task.time_limit)
                                   timeout=self.task.time_limit)
         return result
         return result
@@ -315,7 +316,7 @@ class TaskRequest(object):
         return False
         return False
 
 
     def send_event(self, type, **fields):
     def send_event(self, type, **fields):
-        if self.eventer:
+        if self.eventer and self.eventer.enabled:
             self.eventer.send(type, **fields)
             self.eventer.send(type, **fields)
 
 
     def on_accepted(self, pid, time_accepted):
     def on_accepted(self, pid, time_accepted):
@@ -348,23 +349,32 @@ class TaskRequest(object):
         if self.store_errors:
         if self.store_errors:
             self.task.backend.mark_as_failure(self.id, exc)
             self.task.backend.mark_as_failure(self.id, exc)
 
 
-    def on_success(self, ret_value):
+    def on_success(self, ret_value, now=None):
         """Handler called if the task was successfully processed."""
         """Handler called if the task was successfully processed."""
+        if isinstance(ret_value, ExceptionInfo):
+            if isinstance(ret_value.exception, (
+                    SystemExit, KeyboardInterrupt)):
+                raise ret_value.exception
+            return self.on_failure(ret_value)
         state.task_ready(self)
         state.task_ready(self)
 
 
         if self.task.acks_late:
         if self.task.acks_late:
             self.acknowledge()
             self.acknowledge()
 
 
-        runtime = self.time_start and (time.time() - self.time_start) or 0
-        self.send_event("task-succeeded", uuid=self.id,
-                        result=safe_repr(ret_value), runtime=runtime)
+        if self.eventer and self.eventer.enabled:
+            now = time.time()
+            runtime = self.time_start and (time.time() - self.time_start) or 0
+            self.send_event("task-succeeded", uuid=self.id,
+                            result=safe_repr(ret_value), runtime=runtime)
 
 
         if self._does_info:
         if self._does_info:
+            now = now or time.time()
+            runtime = self.time_start and (time.time() - self.time_start) or 0
             self.logger.info(self.success_msg.strip(),
             self.logger.info(self.success_msg.strip(),
-                            {"id": self.id,
-                             "name": self.name,
-                             "return_value": self.repr_result(ret_value),
-                             "runtime": runtime})
+                        {"id": self.id,
+                        "name": self.name,
+                        "return_value": self.repr_result(ret_value),
+                        "runtime": runtime})
 
 
     def on_retry(self, exc_info):
     def on_retry(self, exc_info):
         """Handler called if the task should be retried."""
         """Handler called if the task should be retried."""

+ 2 - 2
funtests/benchmarks/bench_worker.py

@@ -27,8 +27,8 @@ celery.conf.update(BROKER_TRANSPORT="librabbitmq",
                            "exchange": "bench.worker",
                            "exchange": "bench.worker",
                            "routing_key": "bench.worker",
                            "routing_key": "bench.worker",
                            "no_ack": True,
                            "no_ack": True,
-                           "exchange_durable": False,
-                           "queue_durable": False,
+                           #"exchange_durable": False,
+                           #"queue_durable": False,
                         }
                         }
                    },
                    },
                    CELERY_TASK_SERIALIZER="json",
                    CELERY_TASK_SERIALIZER="json",