Ask Solem 13 lat temu
rodzic
commit
8af6e58598

+ 21 - 17
celery/app/amqp.py

@@ -22,6 +22,7 @@ from celery import signals
 from celery.utils import cached_property, uuid
 from celery.utils.text import indent as textindent
 
+from . import app_or_default
 from . import routes as _routes
 
 #: Human readable queue declaration.
@@ -48,7 +49,7 @@ class Queues(dict):
     #: The rest of the queues are then used for routing only.
     _consume_from = None
 
-    def __init__(self, queues, default_exchange=None, create_missing=True):
+    def __init__(self, queues=None, default_exchange=None, create_missing=True):
         dict.__init__(self)
         self.aliases = WeakValueDictionary()
         self.default_exchange = default_exchange
@@ -65,11 +66,9 @@ class Queues(dict):
             return dict.__getitem__(self, name)
 
     def __setitem__(self, name, queue):
-        if self.default_exchange:
-            if not queue.exchange or not queue.exchange.name:
-                queue.exchange = self.default_exchange
-            if queue.exchange.type == 'direct' and not queue.routing_key:
-                queue.routing_key = name
+        if self.default_exchange and (not queue.exchange or
+                                      not queue.exchange.name):
+            queue.exchange = self.default_exchange
         dict.__setitem__(self, name, queue)
         if queue.alias:
             self.aliases[queue.alias] = queue
@@ -135,19 +134,16 @@ class Queues(dict):
 
 
 class TaskProducer(Producer):
+    app = None
     auto_declare = False
     retry = False
     retry_policy = None
 
     def __init__(self, channel=None, exchange=None, *args, **kwargs):
-        self.app = kwargs.get("app") or self.app
         self.retry = kwargs.pop("retry", self.retry)
         self.retry_policy = kwargs.pop("retry_policy",
                                         self.retry_policy or {})
         exchange = exchange or self.exchange
-        if not isinstance(exchange, Exchange):
-            exchange = Exchange(exchange,
-                    kwargs.get("exchange_type") or self.exchange_type)
         self.queues = self.app.amqp.queues  # shortcut
         super(TaskProducer, self).__init__(channel, exchange, *args, **kwargs)
 
@@ -216,7 +212,21 @@ class TaskProducer(Producer):
                                                expires=expires,
                                                queue=queue)
         return task_id
-TaskPublisher = TaskProducer  # compat
+
+class TaskPublisher(TaskProducer):
+    """Deprecated version of :class:`TaskProducer`."""
+
+    def __init__(self, channel=None, exchange=None, *args, **kwargs):
+        self.app = app_or_default(kwargs.pop("app", self.app))
+        self.retry = kwargs.pop("retry", self.retry)
+        self.retry_policy = kwargs.pop("retry_policy",
+                                        self.retry_policy or {})
+        exchange = exchange or self.exchange
+        if not isinstance(exchange, Exchange):
+            exchange = Exchange(exchange,
+                                kwargs.pop("exchange_type", "direct"))
+        self.queues = self.app.amqp.queues  # shortcut
+        super(TaskPublisher, self).__init__(channel, exchange, *args, **kwargs)
 
 
 class TaskConsumer(Consumer):
@@ -267,11 +277,6 @@ class AMQP(object):
                                            reverse="amqp.TaskConsumer")
     get_task_consumer = TaskConsumer  # XXX compat
 
-    def queue_or_default(self, q):
-        if q:
-            return self.queues[q] if not isinstance(q, Queue) else q
-        return self.default_queue
-
     @cached_property
     def TaskProducer(self):
         """Returns publisher used to send tasks.
@@ -283,7 +288,6 @@ class AMQP(object):
         return self.app.subclass_with_self(TaskProducer,
                 reverse="amqp.TaskProducer",
                 exchange=self.default_exchange,
-                exchange_type=self.default_exchange.type,
                 routing_key=conf.CELERY_DEFAULT_ROUTING_KEY,
                 serializer=conf.CELERY_TASK_SERIALIZER,
                 compression=conf.CELERY_MESSAGE_COMPRESSION,

+ 2 - 1
celery/app/builtins.py

@@ -121,7 +121,8 @@ def add_group_task(app):
                         [subtask(task).apply(taskset_id=setid)
                             for task in tasks])
             with app.default_producer() as pub:
-                [subtask(task).apply_async(taskset_id=setid, publisher=pub)
+                [subtask(task).apply_async(taskset_id=setid, publisher=pub,
+                                           add_to_parent=False)
                         for task in tasks]
             parent = get_current_worker_task()
             if parent:

+ 10 - 4
celery/app/task.py

@@ -376,7 +376,8 @@ class Task(object):
 
     def apply_async(self, args=None, kwargs=None,
             task_id=None, producer=None, connection=None, router=None,
-            link=None, link_error=None, publisher=None, **options):
+            link=None, link_error=None, publisher=None, add_to_parent=True,
+            **options):
         """Apply tasks asynchronously by sending a message.
 
         :keyword args: The positional arguments to pass on to the
@@ -459,6 +460,10 @@ class Task(object):
                       if an error occurs while executing the task.
 
         :keyword producer: :class:~@amqp.TaskProducer` instance to use.
+        :keyword add_to_parent: If set to True (default) and the task
+            is applied while executing another task, then the result
+            will be appended to the parent tasks ``request.children``
+            attribute.
         :keyword publisher: Deprecated alias to ``producer``.
 
         .. note::
@@ -495,9 +500,10 @@ class Task(object):
                                    errbacks=maybe_list(link_error),
                                    **options)
         result = self.AsyncResult(task_id)
-        parent = get_current_worker_task()
-        if parent:
-            parent.request.children.append(result)
+        if add_to_parent:
+            parent = get_current_worker_task()
+            if parent:
+                parent.request.children.append(result)
         return result
 
     def retry(self, args=None, kwargs=None, exc=None, throw=True,

+ 1 - 1
celery/bin/celeryd.py

@@ -153,7 +153,7 @@ class WorkerCommand(Command):
         if loglevel:
             try:
                 kwargs["loglevel"] = mlevel(loglevel)
-            except KeyError:
+            except KeyError:  # pragma: no cover
                 self.die("Unknown level %r. Please use one of %s." % (
                     loglevel, "|".join(l for l in LOG_LEVELS.keys()
                       if isinstance(l, basestring))))

+ 1 - 0
celery/tests/__init__.py

@@ -16,6 +16,7 @@ os.environ["CELERY_LOADER"] = "default"
 os.environ["EVENTLET_NOPATCH"] = "yes"
 os.environ["GEVENT_NOPATCH"] = "yes"
 os.environ["KOMBU_DISABLE_LIMIT_PROTECTION"] = "yes"
+os.environ["CELERY_BROKER_URL"] = "memory://"
 
 try:
     WindowsError = WindowsError  # noqa

+ 37 - 1
celery/tests/app/test_amqp.py

@@ -1,9 +1,10 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
+from kombu import Exchange, Queue
 from mock import Mock
 
-from celery.app.amqp import Queues
+from celery.app.amqp import Queues, TaskPublisher
 from celery.tests.utils import AppCase
 
 
@@ -37,6 +38,22 @@ class test_TaskProducer(AppCase):
         self.assertFalse(pub.connection.ensure.call_count)
 
 
+class test_compat_TaskPublisher(AppCase):
+
+    def test_compat_exchange_is_string(self):
+        producer = TaskPublisher(exchange="foo", app=self.app)
+        self.assertIsInstance(producer.exchange, Exchange)
+        self.assertEqual(producer.exchange.name, "foo")
+        self.assertEqual(producer.exchange.type, "direct")
+        producer = TaskPublisher(exchange="foo", exchange_type="topic",
+                                 app=self.app)
+        self.assertEqual(producer.exchange.type, "topic")
+
+    def test_compat_exchange_is_Exchange(self):
+        producer = TaskPublisher(exchange=Exchange("foo"))
+        self.assertEqual(producer.exchange.name, "foo")
+
+
 class test_PublisherPool(AppCase):
 
     def test_setup_nolimit(self):
@@ -100,3 +117,22 @@ class test_Queues(AppCase):
 
     def test_with_defaults(self):
         self.assertEqual(Queues(None), {})
+
+    def test_add(self):
+        q = Queues()
+        q.add("foo", exchange="ex", routing_key="rk")
+        self.assertIn("foo", q)
+        self.assertIsInstance(q["foo"], Queue)
+        self.assertEqual(q["foo"].routing_key, "rk")
+
+    def test_add_default_exchange(self):
+        ex = Exchange("fff", "fanout")
+        q = Queues(default_exchange=ex)
+        q.add(Queue("foo"))
+        self.assertEqual(q["foo"].exchange, ex)
+
+    def test_alias(self):
+        q = Queues()
+        q.add(Queue("foo", alias="barfoo"))
+        self.assertIs(q["barfoo"], q["foo"])
+

+ 173 - 3
celery/tests/app/test_app.py

@@ -10,8 +10,8 @@ from kombu import Exchange
 
 from celery import Celery
 from celery import app as _app
+from celery import state
 from celery.app import defaults
-from celery.app import state
 from celery.loaders.base import BaseLoader
 from celery.platforms import pyimplementation
 from celery.utils.serialization import pickle
@@ -66,12 +66,182 @@ class test_App(Case):
         self.assertEqual(task.name, app.main + ".fun")
 
     def test_with_broker(self):
-        app = Celery(set_as_current=False, broker="foo://baribaz")
-        self.assertEqual(app.conf.BROKER_HOST, "foo://baribaz")
+        prev = os.environ.get("CELERY_BROKER_URL")
+        os.environ.pop("CELERY_BROKER_URL", None)
+        try:
+            app = Celery(set_as_current=False, broker="foo://baribaz")
+            self.assertEqual(app.conf.BROKER_HOST, "foo://baribaz")
+        finally:
+            os.environ["CELERY_BROKER_URL"] = prev
 
     def test_repr(self):
         self.assertTrue(repr(self.app))
 
+    def test_custom_task_registry(self):
+        app1 = Celery(set_as_current=False)
+        app2 = Celery(set_as_current=False, tasks=app1.tasks)
+        self.assertIs(app2.tasks, app1.tasks)
+
+    def test_include_argument(self):
+        app = Celery(set_as_current=False, include=("foo", "bar.foo"))
+        self.assertEqual(app.conf.CELERY_IMPORTS, ("foo", "bar.foo"))
+
+    def test_set_as_current(self):
+        current = state._tls.current_app
+        try:
+            app = Celery(set_as_current=True)
+            self.assertIs(state._tls.current_app, app)
+        finally:
+            state._tls.current_app = current
+
+    def test_current_task(self):
+        app = Celery(set_as_current=False)
+
+        @app.task
+        def foo():
+            pass
+
+        state._task_stack.push(foo)
+        try:
+            self.assertEqual(app.current_task.name, foo.name)
+        finally:
+            state._task_stack.pop()
+
+    def test_task_not_shared(self):
+        with patch("celery.app.base.shared_task") as shared_task:
+            app = Celery(set_as_current=False)
+            @app.task(shared=False)
+            def foo():
+                pass
+            self.assertFalse(shared_task.called)
+
+    def test_task_compat_with_filter(self):
+        app = Celery(set_as_current=False, accept_magic_kwargs=True)
+        check = Mock()
+
+        def filter(task):
+            check(task)
+            return task
+
+        @app.task(filter=filter)
+        def foo():
+            pass
+        check.assert_called_with(foo)
+
+    def test_task_with_filter(self):
+        app = Celery(set_as_current=False, accept_magic_kwargs=False)
+        check = Mock()
+
+        def filter(task):
+            check(task)
+            return task
+
+        @app.task(filter=filter)
+        def foo():
+            pass
+        check.assert_called_with(foo)
+
+    def test_task_sets_main_name_MP_MAIN_FILE(self):
+        from celery.app import task as _task
+        _task.MP_MAIN_FILE = __file__
+        try:
+            app = Celery("xuzzy", set_as_current=False)
+
+            @app.task
+            def foo():
+                pass
+
+            self.assertEqual(foo.name, "xuzzy.foo")
+        finally:
+            _task.MP_MAIN_FILE = None
+
+    def test_base_task_inherits_magic_kwargs_from_app(self):
+        from celery.app.task import Task
+
+        class timkX(Task):
+            abstract = True
+
+        app = Celery(set_as_current=False, accept_magic_kwargs=True)
+        timkX.bind(app)
+        self.assertTrue(timkX.accept_magic_kwargs)
+
+    def test_annotate_decorator(self):
+        from celery.app.task import Task
+
+        class adX(Task):
+            abstract = True
+
+            def run(self, y, z, x):
+                return y, z, x
+
+        check = Mock()
+        def deco(fun):
+            def _inner(*args, **kwargs):
+                check(*args, **kwargs)
+                return fun(*args, **kwargs)
+            return _inner
+
+        app = Celery(set_as_current=False)
+        app.conf.CELERY_ANNOTATIONS = {
+                adX.name: {"@__call__": deco}
+        }
+        adX.bind(app)
+        self.assertIs(adX.app, app)
+
+        i = adX()
+        i(2, 4, x=3)
+        check.assert_called_with(i, 2, 4, x=3)
+
+        i.annotate()
+        i.annotate()
+
+    def test_apply_async_has__self__(self):
+        app = Celery(set_as_current=False)
+
+        @app.task(__self__="hello")
+        def aawsX():
+            pass
+
+        with patch("celery.app.amqp.TaskProducer.delay_task") as dt:
+            aawsX.apply_async((4, 5))
+            args = dt.call_args[0][1]
+            self.assertEqual(args, ("hello", 4, 5))
+
+    def test_apply_async__connection_arg(self):
+        app = Celery(set_as_current=False)
+
+        @app.task()
+        def aacaX():
+            pass
+
+        connection = app.broker_connection("asd://")
+        with self.assertRaises(KeyError):
+            aacaX.apply_async(connection=connection)
+
+    def test_apply_async_adds_children(self):
+        from celery.state import _task_stack
+        app = Celery(set_as_current=False)
+
+        @app.task()
+        def a3cX1(self):
+            pass
+
+        @app.task()
+        def a3cX2(self):
+            pass
+
+        _task_stack.push(a3cX1)
+        try:
+            a3cX1.push_request(called_directly=False)
+            try:
+                res = a3cX2.apply_async(add_to_parent=True)
+                self.assertIn(res, a3cX1.request.children)
+            finally:
+                a3cX1.pop_request()
+        finally:
+            _task_stack.pop()
+
+
     def test_TaskSet(self):
         ts = self.app.TaskSet()
         self.assertListEqual(ts.tasks, [])

+ 47 - 3
celery/tests/app/test_builtins.py

@@ -1,6 +1,6 @@
 from __future__ import absolute_import
 
-from mock import Mock
+from mock import Mock, patch
 
 from celery import current_app as app, group, task, chord
 from celery.app import builtins
@@ -32,6 +32,44 @@ class test_backend_cleanup(Case):
             app.backend = prev
 
 
+class test_map(Case):
+
+    def test_run(self):
+
+        @app.task()
+        def map_mul(x):
+            return x[0] * x[1]
+
+        res = app.tasks["celery.map"](map_mul, [(2, 2), (4, 4), (8, 8)])
+        self.assertEqual(res, [4, 16, 64])
+
+
+class test_starmap(Case):
+
+    def test_run(self):
+
+        @app.task()
+        def smap_mul(x, y):
+            return x * y
+
+        res = app.tasks["celery.starmap"](smap_mul, [(2, 2), (4, 4), (8, 8)])
+        self.assertEqual(res, [4, 16, 64])
+
+
+class test_chunks(Case):
+
+    @patch("celery.canvas.chunks.apply_chunks")
+    def test_run(self, apply_chunks):
+
+        @app.task()
+        def chunks_mul(l):
+            return x * y
+
+        res = app.tasks["celery.chunks"](chunks_mul,
+                [(2, 2), (4, 4), (8, 8)], 1)
+        self.assertTrue(apply_chunks.called)
+
+
 class test_group(Case):
 
     def setUp(self):
@@ -65,9 +103,12 @@ class test_group(Case):
         try:
             add.push_request(called_directly=False)
             try:
+                assert not add.request.children
                 x = group([add.s(4, 4), add.s(8, 8)])
-                x.apply_async()
+                res = x()
                 self.assertTrue(add.request.children)
+                self.assertIn(res, add.request.children)
+                self.assertEqual(len(add.request.children), 1)
             finally:
                 add.pop_request()
         finally:
@@ -95,7 +136,7 @@ class test_chord(Case):
 
     def setUp(self):
         self.prev = app.tasks.get("celery.chord")
-        self.task = builtins.add_chain_task(app)()
+        self.task = builtins.add_chord_task(app)()
 
     def tearDown(self):
         app.tasks["celery.chord"] = self.prev
@@ -106,6 +147,9 @@ class test_chord(Case):
         self.assertTrue(r)
         self.assertTrue(r.parent)
 
+    def test_run_header_not_group(self):
+        self.task([add.s(i, i) for i in xrange(10)], xsum.s())
+
     def test_apply_eager(self):
         app.conf.CELERY_ALWAYS_EAGER = True
         try:

+ 9 - 0
celery/tests/backends/test_backends.py

@@ -1,6 +1,8 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
+from mock import patch
+
 from celery import current_app
 from celery import backends
 from celery.backends.amqp import AMQPBackend
@@ -38,3 +40,10 @@ class test_backends(Case):
         backend, url_ = backends.get_backend_by_url(url)
         self.assertIs(backend, RedisBackend)
         self.assertEqual(url_, url)
+
+    def test_sym_raises_ValuError(self):
+        with patch("celery.backends.symbol_by_name") as sbn:
+            sbn.side_effect = ValueError()
+            with self.assertRaises(ValueError):
+                backends.get_backend_cls("xxx.xxx:foo")
+

+ 18 - 0
celery/tests/bin/test_base.py

@@ -119,6 +119,24 @@ class test_Command(AppCase):
         finally:
             cmd.app.conf.BROKER_URL = "memory://"
 
+    def test_find_app(self):
+        cmd = MockCommand()
+        with patch("celery.bin.base.symbol_by_name") as sbn:
+            from types import ModuleType
+            x = ModuleType("proj")
+            def on_sbn(*args, **kwargs):
+
+                def after(*args, **kwargs):
+                    x.celery = "quick brown fox"
+                    x.__path__ = None
+                    return x
+                sbn.side_effect = after
+                return x
+            sbn.side_effect = on_sbn
+            x.__path__ = [True]
+            self.assertEqual(cmd.find_app("proj"), "quick brown fox")
+
+
     def test_parse_preload_options_shortopt(self):
         cmd = Command()
         cmd.preload_options = (Option("-s", action="store", dest="silent"), )

+ 56 - 0
celery/tests/bin/test_celeryd.py

@@ -91,6 +91,23 @@ class test_Worker(AppCase):
         with self.assertRaises(SystemExit):
             WorkerCommand(app=celery).run(beat=True)
 
+    def test_setup_concurrency_very_early(self):
+        x = WorkerCommand()
+        x.run = Mock()
+        with self.assertRaises(ImportError):
+            x.execute_from_commandline(["celeryd", "-P", "xyzybox"])
+
+    @disable_stdouts
+    def test_invalid_loglevel_gives_error(self):
+        x = WorkerCommand(app=Celery(set_as_current=False))
+        with self.assertRaises(SystemExit):
+            x.run(loglevel="GRIM_REAPER")
+
+    def test_no_loglevel(self):
+        app = Celery(set_as_current=False)
+        app.Worker = Mock()
+        WorkerCommand(app=app).run(loglevel=None)
+
     def test_tasklist(self):
         celery = Celery(set_as_current=False)
         worker = celery.Worker()
@@ -151,6 +168,38 @@ class test_Worker(AppCase):
         worker.autoscale = 13, 10
         self.assertTrue(worker.startup_info())
 
+        worker = self.Worker(queues="foo,bar,baz,xuzzy,do,re,mi")
+        app = worker.app
+        prev, app.loader = app.loader, Mock()
+        try:
+            app.loader.__module__ = "acme.baked_beans"
+            self.assertTrue(worker.startup_info())
+        finally:
+            app.loader = prev
+
+        prev, app.loader = app.loader, Mock()
+        try:
+            app.loader.__module__ = "celery.loaders.foo"
+            self.assertTrue(worker.startup_info())
+        finally:
+            app.loader = prev
+
+        from celery.loaders.app import AppLoader
+        prev, app.loader = app.loader, AppLoader()
+        try:
+            self.assertTrue(worker.startup_info())
+        finally:
+            app.loader = prev
+
+        worker.send_events = True
+        self.assertTrue(worker.startup_info())
+
+        # test when there are too few output lines
+        # to draft the ascii art onto
+        prev, cd.ARTLINES = (cd.ARTLINES,
+            ["the quick brown fox"])
+        self.assertTrue(worker.startup_info())
+
     @disable_stdouts
     def test_run(self):
         self.Worker().run()
@@ -328,6 +377,9 @@ class test_Worker(AppCase):
 
 class test_funs(AppCase):
 
+    def test_active_thread_count(self):
+        self.assertTrue(cd.active_thread_count())
+
     @disable_stdouts
     def test_set_process_status(self):
         try:
@@ -583,6 +635,10 @@ class test_signal_handlers(AppCase):
             handlers["SIGHUP"]("SIGHUP", object())
             self.assertTrue(state.should_stop)
             self.assertTrue(argv)
+            argv[:] = []
+            fork.return_value = 1
+            handlers["SIGHUP"]("SIGHUP", object())
+            self.assertFalse(argv)
         finally:
             os.execv = execv
             state.should_stop = False

+ 19 - 0
celery/tests/tasks/test_tasks.py

@@ -181,6 +181,25 @@ class test_task_retries(Case):
         self.assertEqual(retry_task.iterations, 2)
 
 
+class test_canvas_utils(Case):
+
+    def test_si(self):
+        self.assertTrue(retry_task.si())
+        self.assertTrue(retry_task.si().immutable)
+
+    def test_chunks(self):
+        self.assertTrue(retry_task.chunks(range(100), 10))
+
+    def test_map(self):
+        self.assertTrue(retry_task.map(range(100)))
+
+    def test_starmap(self):
+        self.assertTrue(retry_task.starmap(range(100)))
+
+    def test_on_success(self):
+        retry_task.on_success(1, 1, (), {})
+
+
 class test_tasks(Case):
 
     def test_unpickle_task(self):