
Chords were broken + pyflakes

Ask Solem, 13 years ago
parent commit d60fa8d40c

+ 13 - 19
celery/app/amqp.py

@@ -16,7 +16,7 @@ from datetime import timedelta
 from kombu import BrokerConnection, Exchange
 from kombu import compat as messaging
 from kombu import pools
-from kombu.common import declaration_cached, maybe_declare
+from kombu.common import maybe_declare
 
 from celery import signals
 from celery.utils import cached_property, lpmerge, uuid
@@ -24,11 +24,6 @@ from celery.utils import text
 
 from . import routes as _routes
 
-#: List of known options to a Kombu producers send method.
-#: Used to extract the message related options out of any `dict`.
-MSG_OPTIONS = ("mandatory", "priority", "immediate", "routing_key",
-               "serializer", "delivery_mode", "compression")
-
 #: Human readable queue declaration.
 QUEUE_FORMAT = """
 . %(name)s exchange:%(exchange)s (%(exchange_type)s) \
@@ -36,12 +31,6 @@ binding:%(binding_key)s
 """
 
 
-def extract_msg_options(options, keep=MSG_OPTIONS):
-    """Extracts known options to `basic_publish` from a dict,
-    and returns a new dict."""
-    return dict((name, options.get(name)) for name in keep)
-
-
 class Queues(dict):
     """Queue name⇒ declaration mapping.
 
@@ -157,11 +146,6 @@ class TaskPublisher(messaging.Publisher):
         self.utc = kwargs.pop("enable_utc", False)
         super(TaskPublisher, self).__init__(*args, **kwargs)
 
-    #def declare(self):
-    #    if self.exchange.name and \
-    #            not declaration_cached(self.exchange, self.channel):
-    #        super(TaskPublisher, self).declare()
-
     def _get_queue(self, name):
         if name not in self._queue_cache:
             options = self.app.amqp.queues[name]
@@ -190,7 +174,9 @@ class TaskPublisher(messaging.Publisher):
             expires=None, exchange=None, exchange_type=None,
             event_dispatcher=None, retry=None, retry_policy=None,
             queue=None, now=None, retries=0, chord=None, callbacks=None,
-            errbacks=None, **kwargs):
+            errbacks=None, mandatory=None, priority=None, immediate=None,
+            routing_key=None, serializer=None, delivery_mode=None,
+            compression=None, **kwargs):
         """Send task message."""
 
         connection = self.connection
@@ -240,7 +226,11 @@ class TaskPublisher(messaging.Publisher):
         send = self.send
         if do_retry:
             send = connection.ensure(self, self.send, **_retry_policy)
-        send(body, exchange=exchange, **extract_msg_options(kwargs))
+        send(body, exchange=exchange, mandatory=mandatory,
+             priority=priority, immediate=immediate, routing_key=routing_key,
+             serializer=serializer or self.serializer,
+             delivery_mode=delivery_mode or self.delivery_mode,
+             compression=compression or self.compression)
         signals.task_sent.send(sender=task_name, **body)
         if event_dispatcher:
             event_dispatcher.send("task-sent", uuid=task_id,
@@ -357,6 +347,10 @@ class AMQP(object):
             self.flush_routes()
         return self._rtable
 
+    @cached_property
+    def router(self):
+        return self.Router()
+
     @cached_property
     def publisher_pool(self):
         return PublisherPool(self.app)
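
A note on the send() change above: the old MSG_OPTIONS whitelist plus
extract_msg_options() pulled message options out of an arbitrary **kwargs
dict, silently turning absent or misspelled keys into None, and pyflakes
could not flag unused names. Explicit keyword arguments make the signature
self-documenting and let mistakes fail loudly. A minimal sketch of the two
styles (hypothetical names, not the celery API):

    # Old pattern: whitelist-extract options out of **kwargs.
    MSG_OPTIONS = ("mandatory", "priority", "immediate", "routing_key")

    def extract_msg_options(options, keep=MSG_OPTIONS):
        # Missing keys silently become None; misspelled ones vanish.
        return dict((name, options.get(name)) for name in keep)

    # New pattern: explicit keyword arguments with visible defaults.
    def send_task(body, mandatory=None, priority=None, immediate=None,
                  routing_key=None):
        return body, mandatory, priority, immediate, routing_key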

+ 20 - 11
celery/app/builtins.py

@@ -80,17 +80,17 @@ def add_group_task(app):
         name = "celery.group"
         accept_magic_kwargs = False
 
-        def run(self, tasks, result):
+        def run(self, tasks, result, setid):
             app = self.app
             result = from_serializable(result)
             if self.request.is_eager or app.conf.CELERY_ALWAYS_EAGER:
                 return app.TaskSetResult(result.id,
-                        [subtask(task).apply(taskset_id=self.request.taskset)
+                        [subtask(task).apply(taskset_id=setid)
                             for task in tasks])
             with app.pool.acquire(block=True) as conn:
                 with app.amqp.TaskPublisher(conn) as publisher:
                     [subtask(task).apply_async(
-                                    taskset_id=self.request.taskset,
+                                    taskset_id=setid,
                                     publisher=publisher)
                             for task in tasks]
             parent = get_current_task()
@@ -103,16 +103,20 @@ def add_group_task(app):
             options["taskset_id"] = group_id = \
                     options.setdefault("task_id", uuid())
             for task in tasks:
-                tid = task.options.setdefault("task_id", uuid())
-                task.options["taskset_id"] = group_id
+                opts = task.options
+                opts["taskset_id"] = group_id
+                try:
+                    tid = opts["task_id"]
+                except KeyError:
+                    tid = opts["task_id"] = uuid()
                 r.append(self.AsyncResult(tid))
-            return tasks, self.app.TaskSetResult(group_id, r)
+            return tasks, self.app.TaskSetResult(group_id, r), group_id
 
         def apply_async(self, args=(), kwargs={}, **options):
             if self.app.conf.CELERY_ALWAYS_EAGER:
                 return self.apply(args, kwargs, **options)
-            tasks, result = self.prepare(options, **kwargs)
-            super(Group, self).apply_async((tasks, result), **options)
+            tasks, result, gid = self.prepare(options, **kwargs)
+            super(Group, self).apply_async((tasks, result, gid), **options)
             return result
 
         def apply(self, args=(), kwargs={}, **options):
@@ -173,15 +177,20 @@ def add_chord_task(app):
             r = []
             setid = uuid()
             for task in header.tasks:
-                tid = task.options.setdefault("task_id", uuid())
-                task.options["chord"] = body
+                opts = task.options
+                try:
+                    tid = opts["task_id"]
+                except KeyError:
+                    tid = opts["task_id"] = uuid()
+                opts["chord"] = body
+                opts["taskset_id"] = setid
                 r.append(app.AsyncResult(tid))
             app.backend.on_chord_apply(setid, body,
                                        interval=interval,
                                        max_retries=max_retries,
                                        propagate=propagate,
                                        result=r)
-            return header(taskset_id=setid)
+            return header(task_id=setid)
 
         def apply_async(self, args=(), kwargs={}, task_id=None, **options):
             if self.app.conf.CELERY_ALWAYS_EAGER:
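
Why the try/except instead of setdefault above: dict.setdefault always
evaluates its default argument, so opts.setdefault("task_id", uuid())
generated a fresh uuid on every task even when an id was already present.
The try/except form only pays for the uuid on the KeyError path. A small
self-contained illustration (uuid4 standing in for celery's uuid helper):

    import uuid

    opts = {"task_id": "predefined"}

    # setdefault: uuid4() runs even though "task_id" is present.
    tid = opts.setdefault("task_id", str(uuid.uuid4()))

    # try/except: an id is generated only when actually missing.
    try:
        tid = opts["task_id"]
    except KeyError:
        tid = opts["task_id"] = str(uuid.uuid4())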

+ 4 - 4
celery/app/task.py

@@ -235,7 +235,7 @@ class BaseTask(object):
 
     #: The name of a serializer that are registered with
     #: :mod:`kombu.serialization.registry`.  Default is `"pickle"`.
-    serializer = "pickle"
+    serializer = None
 
     #: Hard time limit.
     #: Defaults to the :setting:`CELERY_TASK_TIME_LIMIT` setting.
@@ -274,7 +274,7 @@ class BaseTask(object):
     #:
     #: The application default can be overridden with the
     #: :setting:`CELERY_ACKS_LATE` setting.
-    acks_late = False
+    acks_late = None
 
     #: Default task expiry time.
     expires = None
@@ -434,7 +434,7 @@ class BaseTask(object):
 
     def apply_async(self, args=None, kwargs=None,
             task_id=None, publisher=None, connection=None,
-            router=None, queues=None, link=None, link_error=None, **options):
+            router=None, link=None, link_error=None, **options):
         """Apply tasks asynchronously by sending a message.
 
         :keyword args: The positional arguments to pass on to the
@@ -522,7 +522,7 @@ class BaseTask(object):
 
         """
         app = self._get_app()
-        router = app.amqp.Router(queues)
+        router = router or self.app.amqp.router
         conf = app.conf
 
         if conf.CELERY_ALWAYS_EAGER:
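
About the None defaults above: setting serializer and acks_late to None
lets the app configuration win whenever the task class has no opinion,
and apply_async now reuses the Router cached on app.amqp instead of
building a new one per call. Roughly how None-means-inherit resolves
(a sketch, not celery's actual lookup code):

    class Conf(object):
        CELERY_TASK_SERIALIZER = "json"
        CELERY_ACKS_LATE = False

    class Task(object):
        serializer = None          # None => defer to the app setting
        acks_late = None

        def effective(self, conf):
            serializer = (self.serializer if self.serializer is not None
                          else conf.CELERY_TASK_SERIALIZER)
            acks_late = (self.acks_late if self.acks_late is not None
                         else conf.CELERY_ACKS_LATE)
            return serializer, acks_late

    assert Task().effective(Conf()) == ("json", False)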

+ 2 - 1
celery/backends/base.py

@@ -454,7 +454,8 @@ class KeyValueStoreBackend(BaseDictBackend):
             return
         key = self.get_key_for_chord(setid)
         deps = TaskSetResult.restore(setid, backend=task.backend)
-        if self.incr(key) >= deps.total:
+        val = self.incr(key)
+        if val >= deps.total:
             subtask(task.request.chord).delay(deps.join(propagate=propagate))
             deps.delete()
             self.client.delete(key)
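
Binding self.incr(key) to a local above makes the counter value easy to
inspect when debugging. The surrounding code is the chord barrier: every
finished header task bumps a counter, and once it reaches the number of
header tasks the body is applied with the joined results. A toy in-memory
version of the same barrier:

    class ChordBarrier(object):
        def __init__(self, total, callback):
            self.total, self.callback = total, callback
            self.count, self.results = 0, []

        def on_task_done(self, result):
            self.results.append(result)
            self.count += 1             # the backend uses an atomic incr
            if self.count >= self.total:
                self.callback(self.results)

    out = []
    barrier = ChordBarrier(2, out.extend)
    barrier.on_task_done(1)
    barrier.on_task_done(2)
    assert out == [1, 2]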

+ 11 - 10
celery/canvas.py

@@ -17,7 +17,7 @@ from kombu.utils import kwdict, reprcall
 from celery import current_app
 from celery.local import Proxy
 from celery.utils import cached_property, uuid
-from celery.utils.functional import maybe_list
+from celery.utils.functional import maybe_list, is_list
 from celery.utils.compat import chain_from_iterable
 
 Chord = Proxy(lambda: current_app.tasks["celery.chord"])
@@ -104,9 +104,9 @@ class Signature(dict):
         return self.type.apply(args, kwargs, **options)
 
     def _merge(self, args=(), kwargs={}, options={}):
-        return (tuple(args) + tuple(self.args),
-                dict(self.kwargs, **kwargs),
-                dict(self.options, **options))
+        return (tuple(args) + tuple(self.args) if args else self.args,
+                dict(self.kwargs, **kwargs) if kwargs else self.kwargs,
+                dict(self.options, **options) if options else self.options)
 
     def clone(self, args=(), kwargs={}, **options):
         args, kwargs, options = self._merge(args, kwargs, options)
@@ -179,7 +179,7 @@ class Signature(dict):
 
     @cached_property
     def type(self):
-        return self._type or current_app.tasks[self.task]
+        return self._type or current_app.tasks[self["task"]]
     task = _getitem_property("task")
     args = _getitem_property("args")
     kwargs = _getitem_property("kwargs")
@@ -190,6 +190,7 @@ class Signature(dict):
 class chain(Signature):
 
     def __init__(self, *tasks, **options):
+        tasks = tasks[0] if len(tasks) == 1 and is_list(tasks[0]) else tasks
         Signature.__init__(self, "celery.chain", (), {"tasks": tasks}, options)
         self.tasks = tasks
         self.subtask_type = "chain"
@@ -208,19 +209,19 @@ Signature.register_type(chain)
 
 class group(Signature):
 
-    def __init__(self, tasks, **options):
-        self.tasks = tasks = [maybe_subtask(t) for t in tasks]
+    def __init__(self, *tasks, **options):
+        tasks = tasks[0] if len(tasks) == 1 and is_list(tasks[0]) else tasks
         Signature.__init__(self, "celery.group", (), {"tasks": tasks}, options)
-        self.subtask_type = "group"
+        self.tasks, self.subtask_type = tasks, "group"
 
     @classmethod
     def from_dict(self, d):
         return group(d["kwargs"]["tasks"], **kwdict(d["options"]))
 
     def __call__(self, **options):
-        tasks, result = self.type.prepare(options,
+        tasks, result, gid = self.type.prepare(options,
                                 map(Signature.clone, self.tasks))
-        return self.type(tasks, result)
+        return self.type(tasks, result, gid)
 
     def __repr__(self):
         return repr(self.tasks)
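
Two things change above: _merge skips allocating new tuples/dicts when the
caller passed nothing (a hot path, since signatures are cloned a lot), and
group/chain now accept either positional tasks or a single iterable. The
normalization idiom, spelled out on its own:

    def normalize(*tasks):
        # group(a, b) and group([a, b]) now mean the same thing.
        if len(tasks) == 1 and hasattr(tasks[0], "__iter__") \
                and not isinstance(tasks[0], dict):
            return tasks[0]
        return tasks

    assert normalize(1, 2) == (1, 2)
    assert normalize([1, 2]) == [1, 2]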

+ 1 - 0
celery/concurrency/processes/__init__.py

@@ -47,6 +47,7 @@ def process_initializer(app, hostname):
                   str(os.environ.get("CELERY_LOG_REDIRECT_LEVEL")))
     app.loader.init_worker()
     app.loader.init_worker_process()
+    app.finalize()
     signals.worker_process_init.send(sender=None)
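
app.finalize() forces any lazily registered tasks into the registry before
the pool process starts taking work, so a task-name lookup in the child
cannot miss something the parent knows about. The pattern in miniature
(hypothetical code, not celery's implementation):

    class LazyApp(object):
        def __init__(self):
            self._pending, self.tasks = [], {}

        def task(self, fn):
            self._pending.append(fn)    # registration is deferred
            return fn

        def finalize(self):
            while self._pending:
                fn = self._pending.pop(0)
                self.tasks[fn.__name__] = fn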
 
 

+ 8 - 3
celery/local.py

@@ -179,9 +179,14 @@ class PromiseProxy(Proxy):
         return self._get_current_object()
 
     def __evaluate__(self):
-        thing = Proxy._get_current_object(self)
-        object.__setattr__(self, "__thing", thing)
-        return thing
+        try:
+            thing = Proxy._get_current_object(self)
+            object.__setattr__(self, "__thing", thing)
+            return thing
+        finally:
+            object.__delattr__(self, "_Proxy__local")
+            object.__delattr__(self, "_Proxy__args")
+            object.__delattr__(self, "_Proxy__kwargs")
 
 
 def maybe_evaluate(obj):
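
The try/finally added to __evaluate__ drops the proxy's references to the
factory callable and its arguments once evaluation has been attempted, so
whatever those objects hold can be garbage collected. The same idea without
the name mangling:

    class Promise(object):
        def __init__(self, factory, *args, **kwargs):
            self._factory, self._args, self._kwargs = factory, args, kwargs

        def evaluate(self):
            try:
                self._thing = self._factory(*self._args, **self._kwargs)
                return self._thing
            finally:
                # Release constructor references after the first
                # evaluation attempt, even if the factory raised.
                del self._factory, self._args, self._kwargs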

+ 6 - 3
celery/utils/__init__.py

@@ -69,9 +69,12 @@ def deprecated(description=None, deprecation=None, removal=None,
 
 
 def lpmerge(L, R):
-    """Left precedent dictionary merge.  Keeps values from `l`, if the value
-    in `r` is :const:`None`."""
-    return dict(L, **dict((k, v) for k, v in R.iteritems() if v is not None))
+    """In place left precedent dictionary merge.
+
+    Keeps values from `L`, if the value in `R` is :const:`None`."""
+    set = L.__setitem__
+    [set(k, v) for k, v in R.iteritems() if v is not None]
+    return L
 
 
 def is_iterable(obj):
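
The rewritten lpmerge mutates L in place rather than returning a fresh
dict, so callers must not rely on getting a new mapping back. Its behavior,
restated as a plain loop equivalent to the patched version:

    def lpmerge(L, R):
        """In-place merge: R wins, unless its value is None."""
        for k, v in R.iteritems():        # Python 2, as in the module
            if v is not None:
                L[k] = v
        return L

    d = {"a": 1, "b": 2}
    assert lpmerge(d, {"b": None, "c": 3}) == {"a": 1, "b": 2, "c": 3}
    assert d == {"a": 1, "b": 2, "c": 3}  # L itself was modified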

+ 5 - 11
celery/utils/functional.py

@@ -18,12 +18,6 @@ from functools import partial, wraps
 from itertools import islice
 from threading import Lock, RLock
 
-try:
-    from collections import Sequence
-except ImportError:             # pragma: no cover
-    # <= Py2.5
-    Sequence = (list, tuple)    # noqa
-
 from kombu.utils.functional import promise, maybe_promise
 
 from .compat import UserDict, OrderedDict
@@ -97,12 +91,12 @@ class LRUCache(UserDict):
         return newval
 
 
+def is_list(l):
+    return hasattr(l, "__iter__") and not isinstance(l, dict)
+
+
 def maybe_list(l):
-    if l is None:
-        return l
-    elif not isinstance(l, basestring) and isinstance(l, Sequence):
-        return l
-    return [l]
+    return l if l is None or is_list(l) else [l]
 
 
 def memoize(maxsize=None, Cache=LRUCache):
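
The duck-typed is_list replaces the Sequence isinstance check: anything
with __iter__ that is not a dict counts as a list. On Python 2, strings
have no __iter__, so they are still wrapped instead of iterated:

    def is_list(l):
        return hasattr(l, "__iter__") and not isinstance(l, dict)

    def maybe_list(l):
        return l if l is None or is_list(l) else [l]

    assert maybe_list(None) is None
    assert maybe_list([1, 2]) == [1, 2]
    assert maybe_list("abc") == ["abc"]   # Py2 str has no __iter__
    assert maybe_list(42) == [42]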

+ 2 - 1
celery/worker/autoscale.py

@@ -60,11 +60,12 @@ class Autoscaler(bgThread):
 
     def body(self):
         with self.mutex:
+            procs = self.processes
             cur = min(self.qty, self.max_concurrency)
             if cur > procs:
                 self.scale_up(cur - procs)
             elif cur < procs:
-                self.scale_down((self.processes - cur) - self.min_concurrency)
+                self.scale_down((procs - cur) - self.min_concurrency)
         sleep(1.0)
 
     def update(self, max=None, min=None):

+ 9 - 12
celery/worker/job.py

@@ -22,9 +22,9 @@ from datetime import datetime
 from kombu.utils import kwdict, reprcall
 from kombu.utils.encoding import safe_repr, safe_str
 
-from celery import current_app
 from celery import exceptions
 from celery.app import app_or_default
+from celery.app.state import _tls
 from celery.datastructures import ExceptionInfo
 from celery.task.trace import build_tracer, trace_task, report_internal_error
 from celery.platforms import set_mp_process_title as setps
@@ -39,6 +39,8 @@ from . import state
 logger = get_logger(__name__)
 debug, info, warn, error = (logger.debug, logger.info,
                             logger.warn, logger.error)
+_does_debug = logger.isEnabledFor(logging.DEBUG)
+_does_info = logger.isEnabledFor(logging.INFO)
 
 # Localize
 tz_to_local = timezone.to_local
@@ -56,7 +58,7 @@ def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
         >>> trace_task(name, *args, **kwargs)[0]
 
     """
-    task = current_app.tasks[name]
+    task = _tls.current_app._tasks[name]
     try:
         hostname = opts.get("hostname")
         setps("celeryd", name, hostname, rate_limit=True)
@@ -77,9 +79,8 @@ class Request(object):
                  "callbacks", "errbacks",
                  "eventer", "connection_errors",
                  "task", "eta", "expires",
-                 "_does_debug", "_does_info", "request_dict",
-                 "acknowledged", "success_msg", "error_msg",
-                 "retry_msg", "time_start", "worker_pid",
+                 "request_dict", "acknowledged", "success_msg",
+                 "error_msg", "retry_msg", "time_start", "worker_pid",
                  "_already_revoked", "_terminate_on_ack", "_tzlocal")
 
     #: Format string used to log task success.
@@ -148,10 +149,6 @@ class Request(object):
             "routing_key": delivery_info.get("routing_key"),
         }
 
-        ## shortcuts
-        self._does_debug = logger.isEnabledFor(logging.DEBUG)
-        self._does_info = logger.isEnabledFor(logging.INFO)
-
         self.request_dict = body
 
     @classmethod
@@ -288,7 +285,7 @@ class Request(object):
         if not self.task.acks_late:
             self.acknowledge()
         self.send_event("task-started", uuid=self.id, pid=pid)
-        if self._does_debug:
+        if _does_debug:
             debug("Task accepted: %s[%s] pid:%r", self.name, self.id, pid)
         if self._terminate_on_ack is not None:
             _, pool, signal = self._terminate_on_ack
@@ -327,7 +324,7 @@ class Request(object):
             self.send_event("task-succeeded", uuid=self.id,
                             result=safe_repr(ret_value), runtime=runtime)
 
-        if self._does_info:
+        if _does_info:
             now = now or time.time()
             runtime = self.time_start and (time.time() - self.time_start) or 0
             info(self.success_msg.strip(), {
@@ -341,7 +338,7 @@ class Request(object):
                          exception=safe_repr(exc_info.exception.exc),
                          traceback=safe_str(exc_info.traceback))
 
-        if self._does_info:
+        if _does_info:
             info(self.retry_msg.strip(), {
                 "id": self.id, "name": self.name,
                 "exc": safe_repr(exc_info.exception.exc)}, exc_info=exc_info)

+ 5 - 4
funtests/benchmarks/bench_worker.py

@@ -11,7 +11,7 @@ if JSONIMP:
 
 print("anyjson implementation: %r" % (anyjson.implementation.name, ))
 
-from celery import Celery
+from celery import Celery, group
 
 DEFAULT_ITS = 20000
 
@@ -66,7 +66,7 @@ def it(_, n):
 
 def bench_apply(n=DEFAULT_ITS):
     time_start = time.time()
-    celery.TaskSet(it.subtask((i, n)) for i in xrange(n)).apply_async()
+    group(it.s(i, n) for i in xrange(n))()
     print("-- apply %s tasks: %ss" % (n, time.time() - time_start, ))
 
 
@@ -81,6 +81,7 @@ def bench_work(n=DEFAULT_ITS, loglevel="CRITICAL"):
         print("STARTING WORKER")
         worker.start()
     except SystemExit:
+        raise
         assert sum(worker.state.total_count.values()) == n + 1
 
 
@@ -103,8 +104,8 @@ def main(argv=sys.argv):
         return {"apply": bench_apply,
                 "work": bench_work,
                 "both": bench_both}[argv[1]](n=n)
-    except KeyboardInterrupt:
-        pass
+    except:
+        raise
 
 
 if __name__ == "__main__":
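
The benchmark now goes through the canvas API: it.s(i, n) is the signature
shorthand for it.subtask((i, n)), and calling the group applies every
signature asynchronously. For reference (assuming the celery app and the
it task defined in this file):

    # Signature shorthand:
    #   it.s(i, n)  ==  it.subtask((i, n))
    # Applying a group by calling it:
    #   group(it.s(i, n) for i in xrange(n))()
    # replaces the deprecated TaskSet API:
    #   celery.TaskSet(it.subtask((i, n)) for i in xrange(n)).apply_async()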