Explorar o código

TaskPool.apply_async callbacks/errbacks is now singular instead of list

Ask Solem %!s(int64=14) %!d(string=hai) anos
pai
achega
6716a4552f

+ 13 - 17
celery/concurrency/base.py

@@ -65,8 +65,8 @@ class BasePool(object):
         self.on_start()
         self._state = self.RUN
 
-    def apply_async(self, target, args=None, kwargs=None, callbacks=None,
-            errbacks=None, accept_callback=None, timeout_callback=None,
+    def apply_async(self, target, args=None, kwargs=None, callback=None,
+            errback=None, accept_callback=None, timeout_callback=None,
             soft_timeout=None, timeout=None, **compat):
         """Equivalent of the :func:`apply` built-in function.
 
@@ -76,11 +76,9 @@ class BasePool(object):
         """
         args = args or []
         kwargs = kwargs or {}
-        callbacks = callbacks or []
-        errbacks = errbacks or []
 
-        on_ready = partial(self.on_ready, callbacks, errbacks)
-        on_worker_error = partial(self.on_worker_error, errbacks)
+        on_ready = partial(self.on_ready, callback, errback)
+        on_worker_error = partial(self.on_worker_error, errback)
 
         self.logger.debug("TaskPool: Apply %s (args:%s kwargs:%s)" % (
             target, args, kwargs))
@@ -94,7 +92,7 @@ class BasePool(object):
                              soft_timeout=soft_timeout,
                              timeout=timeout)
 
-    def on_ready(self, callbacks, errbacks, ret_value):
+    def on_ready(self, callback, errback, ret_value):
         """What to do when a worker task is ready and its return value has
         been collected."""
 
@@ -102,23 +100,21 @@ class BasePool(object):
             if isinstance(ret_value.exception, (
                     SystemExit, KeyboardInterrupt)):
                 raise ret_value.exception
-            [self.safe_apply_callback(errback, ret_value)
-                    for errback in errbacks]
+            self.safe_apply_callback(errback, ret_value)
         else:
-            [self.safe_apply_callback(callback, ret_value)
-                    for callback in callbacks]
+            self.safe_apply_callback(callback, ret_value)
 
     def on_worker_error(self, errbacks, exc):
         einfo = ExceptionInfo((exc.__class__, exc, None))
         [errback(einfo) for errback in errbacks]
 
     def safe_apply_callback(self, fun, *args):
-        try:
-            fun(*args)
-        except:
-            self.logger.error("Pool callback raised exception: %s" % (
-                traceback.format_exc(), ),
-                exc_info=sys.exc_info())
+        if fun:
+            try:
+                fun(*args)
+            except BaseException:
+                self.logger.error("Pool callback raised exception: %s" % (
+                    traceback.format_exc(), ), exc_info=sys.exc_info())
 
     def _get_info(self):
         return {}

+ 2 - 2
celery/tests/test_concurrency/test_concurrency_processes.py

@@ -144,7 +144,7 @@ class test_TaskPool(unittest.TestCase):
 
         pool = TaskPool(10)
         exc = to_excinfo(KeyError("foo"))
-        pool.on_ready([], [errback], exc)
+        pool.on_ready(None, errback, exc)
         self.assertEqual(exc, scratch[0])
 
     def test_safe_apply_callback(self):
@@ -174,7 +174,7 @@ class test_TaskPool(unittest.TestCase):
 
         pool = TaskPool(10)
         retval = "the quick brown fox"
-        pool.on_ready([callback], [], retval)
+        pool.on_ready(callback, None, retval)
         self.assertEqual(retval, scratch[0])
 
     def test_on_ready_exit_exception(self):