
Chords was broken + pyflakes

Ask Solem 13 years ago
commit d60fa8d40c

+ 13 - 19
celery/app/amqp.py

@@ -16,7 +16,7 @@ from datetime import timedelta
 from kombu import BrokerConnection, Exchange
 from kombu import compat as messaging
 from kombu import pools
-from kombu.common import declaration_cached, maybe_declare
+from kombu.common import maybe_declare
 
 from celery import signals
 from celery.utils import cached_property, lpmerge, uuid
@@ -24,11 +24,6 @@ from celery.utils import text
 
 from . import routes as _routes
 
-#: List of known options to a Kombu producers send method.
-#: Used to extract the message related options out of any `dict`.
-MSG_OPTIONS = ("mandatory", "priority", "immediate", "routing_key",
-               "serializer", "delivery_mode", "compression")
-
 #: Human readable queue declaration.
 QUEUE_FORMAT = """
 . %(name)s exchange:%(exchange)s (%(exchange_type)s) \
@@ -36,12 +31,6 @@ binding:%(binding_key)s
 """
 
 
-def extract_msg_options(options, keep=MSG_OPTIONS):
-    """Extracts known options to `basic_publish` from a dict,
-    and returns a new dict."""
-    return dict((name, options.get(name)) for name in keep)
-
-
 class Queues(dict):
     """Queue name⇒ declaration mapping.
 
@@ -157,11 +146,6 @@ class TaskPublisher(messaging.Publisher):
         self.utc = kwargs.pop("enable_utc", False)
         super(TaskPublisher, self).__init__(*args, **kwargs)
 
-    #def declare(self):
-    #    if self.exchange.name and \
-    #            not declaration_cached(self.exchange, self.channel):
-    #        super(TaskPublisher, self).declare()
-
     def _get_queue(self, name):
         if name not in self._queue_cache:
             options = self.app.amqp.queues[name]
@@ -190,7 +174,9 @@ class TaskPublisher(messaging.Publisher):
             expires=None, exchange=None, exchange_type=None,
             event_dispatcher=None, retry=None, retry_policy=None,
             queue=None, now=None, retries=0, chord=None, callbacks=None,
-            errbacks=None, **kwargs):
+            errbacks=None, mandatory=None, priority=None, immediate=None,
+            routing_key=None, serializer=None, delivery_mode=None,
+            compression=None, **kwargs):
         """Send task message."""
 
         connection = self.connection
@@ -240,7 +226,11 @@ class TaskPublisher(messaging.Publisher):
         send = self.send
         if do_retry:
             send = connection.ensure(self, self.send, **_retry_policy)
-        send(body, exchange=exchange, **extract_msg_options(kwargs))
+        send(body, exchange=exchange, mandatory=mandatory,
+             immediate=immediate, routing_key=routing_key,
+             serializer=serializer or self.serializer,
+             delivery_mode=delivery_mode or self.delivery_mode,
+             compression=compression or self.compression)
         signals.task_sent.send(sender=task_name, **body)
         if event_dispatcher:
             event_dispatcher.send("task-sent", uuid=task_id,
@@ -357,6 +347,10 @@ class AMQP(object):
             self.flush_routes()
         return self._rtable
 
+    @cached_property
+    def router(self):
+        return self.Router()
+
     @cached_property
     def publisher_pool(self):
         return PublisherPool(self.app)

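A minimal sketch of the pattern the delay_task hunks move to: message options become explicit keyword arguments with `or`-fallbacks to instance defaults, instead of being filtered out of **kwargs via MSG_OPTIONS/extract_msg_options. The Publisher class and defaults below are illustrative stand-ins, not celery's actual API.

    class Publisher(object):
        # Instance-level defaults, normally derived from app configuration.
        serializer = "json"
        delivery_mode = 2      # persistent
        compression = None

        def send(self, body, serializer=None, delivery_mode=None,
                 compression=None):
            # Named parameters let pyflakes flag unused or misspelled
            # options; "value or default" preserves the old fallback.
            return (body,
                    serializer or self.serializer,
                    delivery_mode or self.delivery_mode,
                    compression or self.compression)

    # Caller-supplied values win; omitted ones inherit instance defaults:
    print(Publisher().send({"task": "add"}, serializer="pickle"))
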
+ 20 - 11
celery/app/builtins.py

@@ -80,17 +80,17 @@ def add_group_task(app):
         name = "celery.group"
         accept_magic_kwargs = False
 
-        def run(self, tasks, result):
+        def run(self, tasks, result, setid):
             app = self.app
             result = from_serializable(result)
             if self.request.is_eager or app.conf.CELERY_ALWAYS_EAGER:
                 return app.TaskSetResult(result.id,
-                        [subtask(task).apply(taskset_id=self.request.taskset)
+                        [subtask(task).apply(taskset_id=setid)
                             for task in tasks])
             with app.pool.acquire(block=True) as conn:
                 with app.amqp.TaskPublisher(conn) as publisher:
                     [subtask(task).apply_async(
-                                    taskset_id=self.request.taskset,
+                                    taskset_id=setid,
                                     publisher=publisher)
                             for task in tasks]
             parent = get_current_task()
@@ -103,16 +103,20 @@ def add_group_task(app):
             options["taskset_id"] = group_id = \
                     options.setdefault("task_id", uuid())
             for task in tasks:
-                tid = task.options.setdefault("task_id", uuid())
-                task.options["taskset_id"] = group_id
+                opts = task.options
+                opts["taskset_id"] = group_id
+                try:
+                    tid = opts["task_id"]
+                except KeyError:
+                    tid = opts["task_id"] = uuid()
                 r.append(self.AsyncResult(tid))
-            return tasks, self.app.TaskSetResult(group_id, r)
+            return tasks, self.app.TaskSetResult(group_id, r), group_id
 
         def apply_async(self, args=(), kwargs={}, **options):
             if self.app.conf.CELERY_ALWAYS_EAGER:
                 return self.apply(args, kwargs, **options)
-            tasks, result = self.prepare(options, **kwargs)
-            super(Group, self).apply_async((tasks, result), **options)
+            tasks, result, gid = self.prepare(options, **kwargs)
+            super(Group, self).apply_async((tasks, result, gid), **options)
             return result
 
         def apply(self, args=(), kwargs={}, **options):
@@ -173,15 +177,20 @@ def add_chord_task(app):
             r = []
             setid = uuid()
             for task in header.tasks:
-                tid = task.options.setdefault("task_id", uuid())
-                task.options["chord"] = body
+                opts = task.options
+                try:
+                    tid = opts["task_id"]
+                except KeyError:
+                    tid = opts["task_id"] = uuid()
+                opts["chord"] = body
+                opts["taskset_id"] = setid
                 r.append(app.AsyncResult(tid))
             app.backend.on_chord_apply(setid, body,
                                        interval=interval,
                                        max_retries=max_retries,
                                        propagate=propagate,
                                        result=r)
-            return header(taskset_id=setid)
+            return header(task_id=setid)
 
         def apply_async(self, args=(), kwargs={}, task_id=None, **options):
             if self.app.conf.CELERY_ALWAYS_EAGER:

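A side note on the setdefault-to-try/except change in both hunks above: dict.setdefault("task_id", uuid()) evaluates uuid() on every iteration even when the key already exists, so the try/except form avoids generating ids that are immediately thrown away. A standalone sketch (not celery code):

    import uuid

    calls = [0]

    def make_id():
        calls[0] += 1
        return str(uuid.uuid4())

    opts = {"task_id": "preset"}

    # setdefault: make_id() runs even though the key is present.
    opts.setdefault("task_id", make_id())
    # try/except: make_id() only runs on a miss.
    try:
        tid = opts["task_id"]
    except KeyError:
        tid = opts["task_id"] = make_id()

    print(calls[0])  # 1 -- only the setdefault call paid for an unused id
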
+ 4 - 4
celery/app/task.py

@@ -235,7 +235,7 @@ class BaseTask(object):
 
     #: The name of a serializer that are registered with
     #: :mod:`kombu.serialization.registry`.  Default is `"pickle"`.
-    serializer = "pickle"
+    serializer = None
 
     #: Hard time limit.
     #: Defaults to the :setting:`CELERY_TASK_TIME_LIMIT` setting.
@@ -274,7 +274,7 @@ class BaseTask(object):
     #:
     #: The application default can be overridden with the
     #: :setting:`CELERY_ACKS_LATE` setting.
-    acks_late = False
+    acks_late = None
 
     #: Default task expiry time.
     expires = None
@@ -434,7 +434,7 @@ class BaseTask(object):
 
     def apply_async(self, args=None, kwargs=None,
             task_id=None, publisher=None, connection=None,
-            router=None, queues=None, link=None, link_error=None, **options):
+            router=None, link=None, link_error=None, **options):
         """Apply tasks asynchronously by sending a message.
 
         :keyword args: The positional arguments to pass on to the
@@ -522,7 +522,7 @@ class BaseTask(object):
 
         """
         app = self._get_app()
-        router = app.amqp.Router(queues)
+        router = router or self.app.amqp.router
         conf = app.conf
 
         if conf.CELERY_ALWAYS_EAGER:

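Flipping serializer and acks_late from concrete defaults to None lets the app distinguish "explicitly set on the task class" from "inherit the application setting". A rough sketch of that resolution order, using invented names (TaskDefaults, resolve) for illustration:

    APP_CONF = {"CELERY_TASK_SERIALIZER": "json", "CELERY_ACKS_LATE": False}

    class TaskDefaults(object):
        serializer = None  # None means: defer to the app configuration
        acks_late = None

    def resolve(task, attr, conf_key):
        value = getattr(task, attr)
        # Fall back only when nothing was set on the task itself, so an
        # explicit False on the task is still honored.
        return APP_CONF[conf_key] if value is None else value

    print(resolve(TaskDefaults(), "serializer", "CELERY_TASK_SERIALIZER"))  # json
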
+ 2 - 1
celery/backends/base.py

@@ -454,7 +454,8 @@ class KeyValueStoreBackend(BaseDictBackend):
             return
         key = self.get_key_for_chord(setid)
         deps = TaskSetResult.restore(setid, backend=task.backend)
-        if self.incr(key) >= deps.total:
+        val = self.incr(key)
+        if val >= deps.total:
             subtask(task.request.chord).delay(deps.join(propagate=propagate))
             deps.delete()
             self.client.delete(key)

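For context, this is the chord completion check: each finished header task increments a shared counter, and the increment that reaches the group size fires the body exactly once. A toy in-memory version of the idea (real key/value backends rely on an atomic incr, e.g. Redis INCR):

    counters = {}

    def incr(key):
        counters[key] = counters.get(key, 0) + 1
        return counters[key]

    def on_task_done(setid, total, fire):
        # Compare the value this increment returned, not a re-read, so two
        # finishers cannot both observe the final count.
        if incr(setid) >= total:
            fire()

    fired = []
    on_task_done("chord-1", 2, lambda: fired.append(1))  # 1 of 2: no fire
    on_task_done("chord-1", 2, lambda: fired.append(1))  # 2 of 2: fires
    print(len(fired))  # 1
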
+ 11 - 10
celery/canvas.py

@@ -17,7 +17,7 @@ from kombu.utils import kwdict, reprcall
 from celery import current_app
 from celery.local import Proxy
 from celery.utils import cached_property, uuid
-from celery.utils.functional import maybe_list
+from celery.utils.functional import maybe_list, is_list
 from celery.utils.compat import chain_from_iterable
 
 Chord = Proxy(lambda: current_app.tasks["celery.chord"])
@@ -104,9 +104,9 @@ class Signature(dict):
         return self.type.apply(args, kwargs, **options)
 
     def _merge(self, args=(), kwargs={}, options={}):
-        return (tuple(args) + tuple(self.args),
-                dict(self.kwargs, **kwargs),
-                dict(self.options, **options))
+        return (tuple(args) + tuple(self.args) if args else self.args,
+                dict(self.kwargs, **kwargs) if kwargs else self.kwargs,
+                dict(self.options, **options) if options else self.options)
 
     def clone(self, args=(), kwargs={}, **options):
         args, kwargs, options = self._merge(args, kwargs, options)
@@ -179,7 +179,7 @@ class Signature(dict):
 
     @cached_property
     def type(self):
-        return self._type or current_app.tasks[self.task]
+        return self._type or current_app.tasks[self["task"]]
     task = _getitem_property("task")
     args = _getitem_property("args")
     kwargs = _getitem_property("kwargs")
@@ -190,6 +190,7 @@ class Signature(dict):
 class chain(Signature):
 
     def __init__(self, *tasks, **options):
+        tasks = tasks[0] if len(tasks) == 1 and is_list(tasks[0]) else tasks
         Signature.__init__(self, "celery.chain", (), {"tasks": tasks}, options)
         self.tasks = tasks
         self.subtask_type = "chain"
@@ -208,19 +209,19 @@ Signature.register_type(chain)
 
 class group(Signature):
 
-    def __init__(self, tasks, **options):
-        self.tasks = tasks = [maybe_subtask(t) for t in tasks]
+    def __init__(self, *tasks, **options):
+        tasks = tasks[0] if len(tasks) == 1 and is_list(tasks[0]) else tasks
         Signature.__init__(self, "celery.group", (), {"tasks": tasks}, options)
-        self.subtask_type = "group"
+        self.tasks, self.subtask_type = tasks, "group"
 
     @classmethod
     def from_dict(self, d):
        return group(d["kwargs"]["tasks"], **kwdict(d["options"]))
 
     def __call__(self, **options):
-        tasks, result = self.type.prepare(options,
+        tasks, result, gid = self.type.prepare(options,
                                 map(Signature.clone, self.tasks))
-        return self.type(tasks, result)
+        return self.type(tasks, result, gid)
 
     def __repr__(self):
         return repr(self.tasks)

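chain and group now take either varargs or a single iterable, normalized by the `tasks[0] if len(tasks) == 1 and is_list(tasks[0]) else tasks` one-liner. A standalone sketch of how both call styles collapse to the same thing (group_like is a stand-in, not the real class):

    def is_list(l):
        # Mirrors the new helper: any non-dict iterable counts as a list.
        return hasattr(l, "__iter__") and not isinstance(l, dict)

    class group_like(object):
        def __init__(self, *tasks):
            tasks = tasks[0] if len(tasks) == 1 and is_list(tasks[0]) else tasks
            self.tasks = list(tasks)

    print(group_like("a", "b").tasks)    # ['a', 'b'] -- varargs
    print(group_like(["a", "b"]).tasks)  # ['a', 'b'] -- single list
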
+ 1 - 0
celery/concurrency/processes/__init__.py

@@ -47,6 +47,7 @@ def process_initializer(app, hostname):
                   str(os.environ.get("CELERY_LOG_REDIRECT_LEVEL")))
     app.loader.init_worker()
     app.loader.init_worker_process()
+    app.finalize()
     signals.worker_process_init.send(sender=None)
 
 

+ 8 - 3
celery/local.py

@@ -179,9 +179,14 @@ class PromiseProxy(Proxy):
         return self._get_current_object()
 
     def __evaluate__(self):
-        thing = Proxy._get_current_object(self)
-        object.__setattr__(self, "__thing", thing)
-        return thing
+        try:
+            thing = Proxy._get_current_object(self)
+            object.__setattr__(self, "__thing", thing)
+            return thing
+        finally:
+            object.__delattr__(self, "_Proxy__local")
+            object.__delattr__(self, "_Proxy__args")
+            object.__delattr__(self, "_Proxy__kwargs")
 
 
 def maybe_evaluate(obj):

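The finally block strips the name-mangled constructor references once the promise has been evaluated, so the factory callable and its arguments can be garbage collected while the cached __thing survives. A toy promise with the same shape (not celery's Proxy machinery):

    class Promise(object):
        def __init__(self, fun, *args):
            self._fun, self._args = fun, args

        def evaluate(self):
            try:
                self._value = self._fun(*self._args)
                return self._value
            finally:
                # Break the references so fun and args can be collected.
                del self._fun, self._args

    p = Promise(sorted, [3, 1, 2])
    print(p.evaluate())  # [1, 2, 3]
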
+ 6 - 3
celery/utils/__init__.py

@@ -69,9 +69,12 @@ def deprecated(description=None, deprecation=None, removal=None,
 
 
 def lpmerge(L, R):
-    """Left precedent dictionary merge.  Keeps values from `l`, if the value
-    in `r` is :const:`None`."""
-    return dict(L, **dict((k, v) for k, v in R.iteritems() if v is not None))
+    """In place left precedent dictionary merge.
+
+    Keeps values from `L`, if the value in `R` is :const:`None`."""
+    set = L.__setitem__
+    [set(k, v) for k, v in R.iteritems() if v is not None]
+    return L
 
 
 def is_iterable(obj):

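lpmerge now mutates L in place rather than building a new dict; a None value in R still never overwrites an entry from L. A quick usage sketch of the new semantics (Python 2, matching the iteritems() above):

    def lpmerge(L, R):
        set = L.__setitem__
        [set(k, v) for k, v in R.iteritems() if v is not None]
        return L

    defaults = {"serializer": "json", "retries": 3}
    updates = {"serializer": "pickle", "retries": None}
    merged = lpmerge(defaults, updates)
    print(merged)              # {'serializer': 'pickle', 'retries': 3}
    print(merged is defaults)  # True -- L itself was updated
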
+ 5 - 11
celery/utils/functional.py

@@ -18,12 +18,6 @@ from functools import partial, wraps
 from itertools import islice
 from threading import Lock, RLock
 
-try:
-    from collections import Sequence
-except ImportError:             # pragma: no cover
-    # <= Py2.5
-    Sequence = (list, tuple)    # noqa
-
 from kombu.utils.functional import promise, maybe_promise
 
 from .compat import UserDict, OrderedDict
@@ -97,12 +91,12 @@ class LRUCache(UserDict):
         return newval
 
 
+def is_list(l):
+    return hasattr(l, "__iter__") and not isinstance(l, dict)
+
+
 def maybe_list(l):
-    if l is None:
-        return l
-    elif not isinstance(l, basestring) and isinstance(l, Sequence):
-        return l
-    return [l]
+    return l if l is None or is_list(l) else [l]
 
 
 def memoize(maxsize=None, Cache=LRUCache):

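The duck-typed is_list widens what counts as a list: anything with __iter__ that is not a dict. Under Python 2 that still excludes plain strings, since str only grew __iter__ in Python 3. A few probes of the new maybe_list behavior:

    def is_list(l):
        return hasattr(l, "__iter__") and not isinstance(l, dict)

    def maybe_list(l):
        return l if l is None or is_list(l) else [l]

    print(maybe_list(None))    # None passes through untouched
    print(maybe_list([1, 2]))  # [1, 2] -- already list-like
    print(maybe_list((1, 2)))  # (1, 2) -- tuples and generators count too
    print(maybe_list("abc"))   # ['abc'] on Python 2 (str has no __iter__)
    print(maybe_list(42))      # [42]
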
+ 2 - 1
celery/worker/autoscale.py

@@ -60,11 +60,12 @@ class Autoscaler(bgThread):
 
     def body(self):
         with self.mutex:
+            procs = self.processes
             cur = min(self.qty, self.max_concurrency)
             if cur > procs:
                 self.scale_up(cur - procs)
             elif cur < procs:
-                self.scale_down((self.processes - cur) - self.min_concurrency)
+                self.scale_down((procs - cur) - self.min_concurrency)
         sleep(1.0)
 
     def update(self, max=None, min=None):

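Before this hunk, body() compared cur against a procs name with no visible assignment in the method (exactly the kind of undefined name pyflakes catches) and re-read self.processes for the scale-down math. Reading the value once under the mutex gives every branch the same snapshot; a self-contained sketch with illustrative numbers:

    import threading

    class Scaler(object):
        def __init__(self):
            self.mutex = threading.Lock()
            self.processes = 4
            self.qty = 7                # demand, normally sampled live
            self.max_concurrency = 10
            self.min_concurrency = 0

        def body(self):
            with self.mutex:
                procs = self.processes  # one snapshot for all branches
                cur = min(self.qty, self.max_concurrency)
                if cur > procs:
                    print("scale up by %d" % (cur - procs))
                elif cur < procs:
                    print("scale down by %d" % ((procs - cur) - self.min_concurrency))

    Scaler().body()  # scale up by 3
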
+ 9 - 12
celery/worker/job.py

@@ -22,9 +22,9 @@ from datetime import datetime
 from kombu.utils import kwdict, reprcall
 from kombu.utils.encoding import safe_repr, safe_str
 
-from celery import current_app
 from celery import exceptions
 from celery.app import app_or_default
+from celery.app.state import _tls
 from celery.datastructures import ExceptionInfo
 from celery.task.trace import build_tracer, trace_task, report_internal_error
 from celery.platforms import set_mp_process_title as setps
@@ -39,6 +39,8 @@ from . import state
 logger = get_logger(__name__)
 debug, info, warn, error = (logger.debug, logger.info,
                             logger.warn, logger.error)
+_does_debug = logger.isEnabledFor(logging.DEBUG)
+_does_info = logger.isEnabledFor(logging.INFO)
 
 # Localize
 tz_to_local = timezone.to_local
@@ -56,7 +58,7 @@ def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
         >>> trace_task(name, *args, **kwargs)[0]
 
     """
-    task = current_app.tasks[name]
+    task = _tls.current_app._tasks[name]
     try:
         hostname = opts.get("hostname")
         setps("celeryd", name, hostname, rate_limit=True)
@@ -77,9 +79,8 @@ class Request(object):
                  "callbacks", "errbacks",
                  "eventer", "connection_errors",
                  "task", "eta", "expires",
-                 "_does_debug", "_does_info", "request_dict",
-                 "acknowledged", "success_msg", "error_msg",
-                 "retry_msg", "time_start", "worker_pid",
+                 "request_dict", "acknowledged", "success_msg",
+                 "error_msg", "retry_msg", "time_start", "worker_pid",
                  "_already_revoked", "_terminate_on_ack", "_tzlocal")
 
     #: Format string used to log task success.
@@ -148,10 +149,6 @@ class Request(object):
             "routing_key": delivery_info.get("routing_key"),
         }
 
-        ## shortcuts
-        self._does_debug = logger.isEnabledFor(logging.DEBUG)
-        self._does_info = logger.isEnabledFor(logging.INFO)
-
         self.request_dict = body
 
     @classmethod
@@ -288,7 +285,7 @@ class Request(object):
         if not self.task.acks_late:
             self.acknowledge()
         self.send_event("task-started", uuid=self.id, pid=pid)
-        if self._does_debug:
+        if _does_debug:
             debug("Task accepted: %s[%s] pid:%r", self.name, self.id, pid)
         if self._terminate_on_ack is not None:
             _, pool, signal = self._terminate_on_ack
@@ -327,7 +324,7 @@ class Request(object):
             self.send_event("task-succeeded", uuid=self.id,
                             result=safe_repr(ret_value), runtime=runtime)
 
-        if self._does_info:
+        if _does_info:
             now = now or time.time()
             runtime = self.time_start and (time.time() - self.time_start) or 0
             info(self.success_msg.strip(), {
@@ -341,7 +338,7 @@ class Request(object):
                          exception=safe_repr(exc_info.exception.exc),
                          traceback=safe_str(exc_info.traceback))
 
-        if self._does_info:
+        if _does_info:
             info(self.retry_msg.strip(), {
                 "id": self.id, "name": self.name,
                 "exc": safe_repr(exc_info.exception.exc)}, exc_info=exc_info)

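Hoisting the isEnabledFor checks from per-Request attributes to module globals computes them once per worker process instead of once per task, with the trade-off that a log level changed after import goes unnoticed. A minimal sketch of both the win and the caveat:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("worker")

    # Evaluated once at import time -- hot paths test a plain global.
    _does_debug = logger.isEnabledFor(logging.DEBUG)
    _does_info = logger.isEnabledFor(logging.INFO)

    def on_task_accepted(task_id):
        if _does_debug:  # cheap global lookup, no logging-internals call
            logger.debug("Task accepted: %s", task_id)

    logger.setLevel(logging.DEBUG)
    on_task_accepted("abc")  # still silent: _does_debug was cached as False
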
+ 5 - 4
funtests/benchmarks/bench_worker.py

@@ -11,7 +11,7 @@ if JSONIMP:
 
 print("anyjson implementation: %r" % (anyjson.implementation.name, ))
 
-from celery import Celery
+from celery import Celery, group
 
 DEFAULT_ITS = 20000
 
@@ -66,7 +66,7 @@ def it(_, n):
 
 def bench_apply(n=DEFAULT_ITS):
     time_start = time.time()
-    celery.TaskSet(it.subtask((i, n)) for i in xrange(n)).apply_async()
+    group(it.s(i, n) for i in xrange(n))()
     print("-- apply %s tasks: %ss" % (n, time.time() - time_start, ))
 
 
@@ -81,6 +81,7 @@ def bench_work(n=DEFAULT_ITS, loglevel="CRITICAL"):
         print("STARTING WORKER")
         worker.start()
     except SystemExit:
+        raise
         assert sum(worker.state.total_count.values()) == n + 1
 
 
@@ -103,8 +104,8 @@ def main(argv=sys.argv):
         return {"apply": bench_apply,
                 "work": bench_work,
                 "both": bench_both}[argv[1]](n=n)
-    except KeyboardInterrupt:
-        pass
+    except:
+        raise
 
 
 if __name__ == "__main__":