Bläddra i källkod

Merge branch 'master' into 5.0-devel

Ask Solem 8 år sedan
förälder
incheckning
902f1cd4ae

+ 3 - 0
celery/app/amqp.py

@@ -324,6 +324,9 @@ class AMQP:
         if kwargsrepr is None:
             kwargsrepr = saferepr(kwargs, self.kwargsrepr_maxsize)
 
+        if not root_id:  # empty root_id defaults to task_id
+            root_id = task_id
+
         return task_message(
             headers={
                 'lang': 'py',

+ 6 - 8
celery/app/base.py

@@ -684,15 +684,13 @@ class Celery:
         options = router.route(
             options, route_name or name, args, kwargs, task_type)
 
-        if root_id is None:
-            parent, have_parent = self.current_worker_task, True
+        if not root_id or not parent_id:
+            parent = self.current_worker_task
             if parent:
-                root_id = parent.request.root_id or parent.request.id
-        if parent_id is None:
-            if not have_parent:
-                parent, have_parent = self.current_worker_task, True
-            if parent:
-                parent_id = parent.request.id
+                if not root_id:
+                    root_id = parent.request.root_id or parent.request.id
+                if not parent_id:
+                    parent_id = parent.request.id
 
         message = amqp.create_task_message(
             task_id, name, args, kwargs, countdown, eta, group_id,

+ 105 - 57
celery/canvas.py

@@ -30,6 +30,39 @@ __all__ = [
 ]
 
 
+def _shorten_names(task_name, s):
+    # type: (str, str) -> str
+    """Remove repeating module names from string.
+
+    Arguments:
+        task_name (str): Task name (full path including module),
+            to use as the basis for removing module names.
+        s (str): The string we want to work on.
+
+    Example:
+
+        >>> _shorten_names(
+        ...    'x.tasks.add',
+        ...    'x.tasks.add(2, 2) | x.tasks.add(4) | x.tasks.mul(8)',
+        ... )
+        'x.tasks.add(2, 2) | add(4) | mul(8)'
+    """
+    # This is used by repr(), to remove repeating module names.
+
+    # extract the module part of the task name
+    module = task_name.rpartition('.')[0] + '.'
+    # find the first occurrence of the module name in the string.
+    index = s.find(module)
+    if index >= 0:
+        s = ''.join([
+            # leave the first occurrence of the module name untouched.
+            s[:index + len(module)],
+            # strip seen module name from the rest of the string.
+            s[index + len(module):].replace(module, ''),
+        ])
+    return s
+
+
 class _getitem_property:
     """Attribute -> dict key descriptor.
 
@@ -376,9 +409,6 @@ class Signature(dict):
     def set_immutable(self, immutable):
         self.immutable = immutable
 
-    def set_parent_id(self, parent_id):
-        self.parent_id = parent_id
-
     def _with_list_option(self, key):
         items = self.options.setdefault(key, [])
         if not isinstance(items, MutableSequence):
@@ -445,26 +475,48 @@ class Signature(dict):
             # group() | task -> chord
             return chord(self, body=other, app=self._app)
         elif isinstance(other, group):
-            # task | group() -> unroll group with one member
+            # unroll group with one member
             other = maybe_unroll_group(other)
-            return chain(self, other, app=self._app)
+            if isinstance(self, chain):
+                # chain | group() -> chain
+                sig = self.clone()
+                sig.tasks.append(other)
+                return sig
+            # task | group() -> chain
+            return chain(self, other, app=self.app)
         if not isinstance(self, chain) and isinstance(other, chain):
             # task | chain -> chain
             return chain(
                 _seq_concat_seq((self,), other.tasks), app=self._app)
         elif isinstance(other, chain):
             # chain | chain -> chain
-            return chain(
-                _seq_concat_seq(self.tasks, other.tasks), app=self._app)
+            sig = self.clone()
+            if isinstance(sig.tasks, tuple):
+                sig.tasks = list(sig.tasks)
+            sig.tasks.extend(other.tasks)
+            return sig
         elif isinstance(self, chord):
+            # chord | task -> attach to body
             sig = self.clone()
             sig.body = sig.body | other
             return sig
         elif isinstance(other, Signature):
             if isinstance(self, chain):
-                # chain | task -> chain
-                return chain(
-                    _seq_concat_item(self.tasks, other), app=self._app)
+                if isinstance(self.tasks[-1], group):
+                    # CHAIN [last item is group] | TASK -> chord
+                    sig = self.clone()
+                    sig.tasks[-1] = chord(
+                        sig.tasks[-1], other, app=self._app)
+                    return sig
+                elif isinstance(self.tasks[-1], chord):
+                    # CHAIN [last item is chord] -> chain with chord body.
+                    sig = self.clone()
+                    sig.tasks[-1].body = sig.tasks[-1].body | other
+                    return sig
+                else:
+                    # chain | task -> chain
+                    return chain(
+                        _seq_concat_item(self.tasks, other), app=self._app)
             # task | task -> chain
             return chain(self, other, app=self._app)
         return NotImplemented
@@ -694,7 +746,7 @@ class chain(Signature):
         steps_extend = steps.extend
 
         prev_task = None
-        prev_res = prev_prev_res = None
+        prev_res = None
         tasks, results = [], []
         i = 0
         while steps:
@@ -727,7 +779,6 @@ class chain(Signature):
                     task, body=prev_task,
                     task_id=prev_res.id, root_id=root_id, app=app,
                 )
-                prev_res = prev_prev_res
 
             if is_last_task:
                 # chain(task_id=id) means task id is set for the last task
@@ -745,27 +796,12 @@ class chain(Signature):
             i += 1
 
             if prev_task:
-                prev_task.set_parent_id(task.id)
-
                 if use_link:
                     # link previous task to this task.
                     task.link(prev_task)
 
-                if prev_res:
-                    if isinstance(prev_task, chord):
-                        # If previous task was a chord,
-                        # the freeze above would have set a parent for
-                        # us, but we'd be overwriting it here.
-
-                        # so fix this relationship so it's:
-                        #     chord body -> group -> THIS RES
-                        assert isinstance(prev_res.parent, GroupResult)
-                        prev_res.parent.parent = res
-                    else:
-                        prev_res.parent = res
-
-            if is_first_task and parent_id is not None:
-                task.set_parent_id(parent_id)
+                if prev_res and not prev_res.parent:
+                    prev_res.parent = res
 
             if link_error:
                 for errback in maybe_list(link_error):
@@ -774,14 +810,18 @@ class chain(Signature):
             tasks.append(task)
             results.append(res)
 
-            prev_task, prev_prev_res, prev_res = (
-                task, prev_res, res,
-            )
-
-        if root_id is None and tasks:
-            root_id = tasks[-1].id
-            for task in reversed(tasks):
-                task.options['root_id'] = root_id
+            prev_task, prev_res = task, res
+            if isinstance(task, chord):
+                # If the task is a chord, and the body is a chain
+                # the chain has already been prepared, and res is
+                # set to the last task in the callback chain.
+
+                # We need to change that so that it points to the
+                # group result object.
+                node = res
+                while node.parent:
+                    node = node.parent
+                prev_res = node
         return tasks, results
 
     def apply(self, args=(), kwargs={}, **options):
@@ -806,7 +846,9 @@ class chain(Signature):
         if not self.tasks:
             return '<{0}@{1:#x}: empty>'.format(
                 type(self).__name__, id(self))
-        return ' | '.join(repr(t) for t in self.tasks)
+        return _shorten_names(
+            self.tasks[0]['task'],
+            ' | '.join(repr(t) for t in self.tasks))
 
 
 class _basemap(Signature):
@@ -1091,10 +1133,6 @@ class group(Signature):
             options.pop('task_id', uuid()))
         return options, group_id, options.get('root_id')
 
-    def set_parent_id(self, parent_id):
-        for task in self.tasks:
-            task.set_parent_id(parent_id)
-
     def freeze(self, _id=None, group_id=None, chord=None,
                root_id=None, parent_id=None):
         # pylint: disable=redefined-outer-name
@@ -1141,7 +1179,11 @@ class group(Signature):
         return iter(self.tasks)
 
     def __repr__(self):
-        return 'group({0.tasks!r})'.format(self)
+        if self.tasks:
+            return _shorten_names(
+                self.tasks[0]['task'],
+                'group({0.tasks!r})'.format(self))
+        return 'group(<empty>)'
 
     def __len__(self):
         return len(self.tasks)
@@ -1216,21 +1258,19 @@ class chord(Signature):
             self.tasks = group(self.tasks, app=self.app)
         header_result = self.tasks.freeze(
             parent_id=parent_id, root_id=root_id, chord=self.body)
-        bodyres = self.body.freeze(
-            _id, parent_id=header_result.id, root_id=root_id)
-        bodyres.parent = header_result
+        bodyres = self.body.freeze(_id, root_id=root_id)
+        # we need to link the body result back to the group result,
+        # but the body may actually be a chain,
+        # so find the first result without a parent
+        node = bodyres
+        while node:
+            if not node.parent:
+                node.parent = header_result
+                break
+            node = node.parent
         self.id = self.tasks.id
-        self.body.set_parent_id(self.id)
         return bodyres
 
-    def set_parent_id(self, parent_id):
-        tasks = self.tasks
-        if isinstance(tasks, group):
-            tasks = tasks.tasks
-        for task in tasks:
-            task.set_parent_id(parent_id)
-        self.parent_id = parent_id
-
     def apply_async(self, args=(), kwargs={}, task_id=None,
                     producer=None, connection=None,
                     router=None, result_cls=None, **options):
@@ -1282,7 +1322,6 @@ class chord(Signature):
 
         results = header.freeze(
             group_id=group_id, chord=body, root_id=root_id).results
-        body.set_parent_id(group_id)
         bodyres = body.freeze(task_id, root_id=root_id)
 
         parent = app.backend.apply_chord(
@@ -1317,7 +1356,16 @@ class chord(Signature):
 
     def __repr__(self):
         if self.body:
-            return self.body.reprcall(self.tasks)
+            if isinstance(self.body, chain):
+                return _shorten_names(
+                    self.body.tasks[0]['task'],
+                    '({0} | {1!r})'.format(
+                        self.body.tasks[0].reprcall(self.tasks),
+                        chain(self.body.tasks[1:], app=self._app),
+                    ),
+                )
+            return _shorten_names(
+                self.body['task'], self.body.reprcall(self.tasks))
         return '<chord without body: {0.tasks!r}>'.format(self)
 
     @cached_property

+ 6 - 2
t/integration/conftest.py

@@ -1,12 +1,16 @@
+import os
 import pytest
 from celery.contrib.testing.manager import Manager
 
+TEST_BROKER = os.environ.get('TEST_BROKER', 'pyamqp://')
+TEST_BACKEND = os.environ.get('TEST_BACKEND', 'redis://')
+
 
 @pytest.fixture(scope='session')
 def celery_config():
     return {
-        'broker_url': 'pyamqp://',
-        'result_backend': 'rpc',
+        'broker_url': TEST_BROKER,
+        'result_backend': TEST_BACKEND
     }
 
 

+ 1 - 1
t/integration/tasks.py

@@ -40,4 +40,4 @@ def collect_ids(self, res, i):
         (previous_result, (root_id, parent_id, i))
 
     """
-    return res, ids(i)
+    return res, (self.request.root_id, self.request.parent_id, i)

+ 56 - 27
t/integration/test_canvas.py

@@ -1,6 +1,7 @@
 import pytest
-from celery import chain, group, uuid
+from celery import chain, chord, group
 from celery.exceptions import TimeoutError
+from celery.result import AsyncResult, GroupResult
 from .tasks import add, collect_ids, ids
 
 TIMEOUT = 120
@@ -24,7 +25,7 @@ class test_chain:
 
     def test_parent_ids(self, manager, num=10):
         assert manager.inspect().ping()
-        c = chain(ids.si(i) for i in range(num))
+        c = chain(ids.si(i=i) for i in range(num))
         c.freeze()
         res = c()
         try:
@@ -44,9 +45,9 @@ class test_chain:
         while node:
             root_id, parent_id, value = node.get(timeout=30)
             assert value == i
-            assert root_id == root.id
             if node.parent:
                 assert parent_id == node.parent.id
+            assert root_id == root.id
             node = node.parent
             i -= 1
 
@@ -55,7 +56,11 @@ class test_group:
 
     def test_parent_ids(self, manager):
         assert manager.inspect().ping()
-        g = ids.si(1) | ids.si(2) | group(ids.si(i) for i in range(2, 50))
+        g = (
+            ids.si(i=1) |
+            ids.si(i=2) |
+            group(ids.si(i=i) for i in range(2, 50))
+        )
         res = g()
         expected_root_id = res.parent.parent.id
         expected_parent_id = res.parent.id
@@ -68,48 +73,72 @@ class test_group:
             assert value == i + 2
 
 
+def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
+    root_id, parent_id, value = r.get(timeout=TIMEOUT)
+    assert expected_value == value
+    assert root_id == expected_root_id
+    assert parent_id == expected_parent_id
+
+
 @pytest.mark.celery(result_backend='redis://')
-class xxx_chord:
+class test_chord:
 
     def test_parent_ids(self, manager):
-        self.assert_parentids_chord()
-
-    def test_parent_ids__already_set(self, manager):
-        self.assert_parentids_chord(uuid(), uuid())
+        root = ids.si(i=1)
+        expected_root_id = root.freeze().id
+        g = chain(
+            root, ids.si(i=2),
+            chord(
+                group(ids.si(i=i) for i in range(3, 50)),
+                chain(collect_ids.s(i=50) | ids.si(i=51)),
+            ),
+        )
+        self.assert_parentids_chord(g(), expected_root_id)
 
-    def assert_parentids_chord(self, base_root=None, base_parent=None):
+    def test_parent_ids__OR(self, manager):
+        root = ids.si(i=1)
+        expected_root_id = root.freeze().id
         g = (
-            ids.si(1) |
-            ids.si(2) |
-            group(ids.si(i) for i in range(3, 50)) |
+            root |
+            ids.si(i=2) |
+            group(ids.si(i=i) for i in range(3, 50)) |
             collect_ids.s(i=50) |
-            ids.si(51)
+            ids.si(i=51)
         )
-        g.freeze(root_id=base_root, parent_id=base_parent)
-        res = g.apply_async(root_id=base_root, parent_id=base_parent)
-        expected_root_id = base_root or res.parent.parent.parent.id
+        self.assert_parentids_chord(g(), expected_root_id)
 
-        root_id, parent_id, value = res.get(timeout=30)
-        assert value == 51
-        assert root_id == expected_root_id
-        assert parent_id == res.parent.id
+    def assert_parentids_chord(self, res, expected_root_id):
+        assert isinstance(res, AsyncResult)
+        assert isinstance(res.parent, AsyncResult)
+        assert isinstance(res.parent.parent, GroupResult)
+        assert isinstance(res.parent.parent.parent, AsyncResult)
+        assert isinstance(res.parent.parent.parent.parent, AsyncResult)
 
+        # first we check the last task
+        assert_ids(res, 51, expected_root_id, res.parent.id)
+
+        # then the chord callback
         prev, (root_id, parent_id, value) = res.parent.get(timeout=30)
         assert value == 50
         assert root_id == expected_root_id
-        assert parent_id == res.parent.parent.id
+        # started by one of the chord header tasks.
+        assert parent_id in res.parent.parent.results
 
+        # check what the chord callback recorded
         for i, p in enumerate(prev):
             root_id, parent_id, value = p
             assert root_id == expected_root_id
-            assert parent_id == res.parent.parent.id
+            assert parent_id == res.parent.parent.parent.id
 
-        root_id, parent_id, value = res.parent.parent.get(timeout=30)
+        # ids(i=2)
+        root_id, parent_id, value = res.parent.parent.parent.get(timeout=30)
         assert value == 2
-        assert parent_id == res.parent.parent.parent.id
+        assert parent_id == res.parent.parent.parent.parent.id
         assert root_id == expected_root_id
 
-        root_id, parent_id, value = res.parent.parent.parent.get(timeout=30)
+        # ids(i=1)
+        root_id, parent_id, value = res.parent.parent.parent.parent.get(
+            timeout=30)
         assert value == 1
         assert root_id == expected_root_id
-        assert parent_id == base_parent
+        assert parent_id is None

+ 7 - 91
t/unit/tasks/test_canvas.py

@@ -1,5 +1,5 @@
 import pytest
-from case import ContextMock, MagicMock, Mock
+from case import MagicMock, Mock
 from celery._state import _task_stack
 from celery.canvas import (
     Signature,
@@ -249,7 +249,7 @@ class test_chain(CanvasCase):
 
     def test_repr(self):
         x = self.add.s(2, 2) | self.add.s(2)
-        assert repr(x) == '%s(2, 2) | %s(2)' % (self.add.name, self.add.name)
+        assert repr(x) == '%s(2, 2) | add(2)' % (self.add.name,)
 
     def test_apply_async(self):
         c = self.add.s(2, 2) | self.add.s(4) | self.add.s(8)
@@ -258,43 +258,6 @@ class test_chain(CanvasCase):
         assert result.parent.parent
         assert result.parent.parent.parent is None
 
-    def test_group_to_chord__freeze_parent_id(self):
-        def using_freeze(c):
-            c.freeze(parent_id='foo', root_id='root')
-            return c._frozen[0]
-        self.assert_group_to_chord_parent_ids(using_freeze)
-
-    def assert_group_to_chord_parent_ids(self, freezefun):
-        c = (
-            self.add.s(5, 5) |
-            group([self.add.s(i, i) for i in range(5)], app=self.app) |
-            self.add.si(10, 10) |
-            self.add.si(20, 20) |
-            self.add.si(30, 30)
-        )
-        tasks = freezefun(c)
-        assert tasks[-1].parent_id == 'foo'
-        assert tasks[-1].root_id == 'root'
-        assert tasks[-2].parent_id == tasks[-1].id
-        assert tasks[-2].root_id == 'root'
-        assert tasks[-2].body.parent_id == tasks[-2].tasks.id
-        assert tasks[-2].body.parent_id == tasks[-2].id
-        assert tasks[-2].body.root_id == 'root'
-        assert tasks[-2].tasks.tasks[0].parent_id == tasks[-1].id
-        assert tasks[-2].tasks.tasks[0].root_id == 'root'
-        assert tasks[-2].tasks.tasks[1].parent_id == tasks[-1].id
-        assert tasks[-2].tasks.tasks[1].root_id == 'root'
-        assert tasks[-2].tasks.tasks[2].parent_id == tasks[-1].id
-        assert tasks[-2].tasks.tasks[2].root_id == 'root'
-        assert tasks[-2].tasks.tasks[3].parent_id == tasks[-1].id
-        assert tasks[-2].tasks.tasks[3].root_id == 'root'
-        assert tasks[-2].tasks.tasks[4].parent_id == tasks[-1].id
-        assert tasks[-2].tasks.tasks[4].root_id == 'root'
-        assert tasks[-3].parent_id == tasks[-2].body.id
-        assert tasks[-3].root_id == 'root'
-        assert tasks[-4].parent_id == tasks[-3].id
-        assert tasks[-4].root_id == 'root'
-
     def test_splices_chains(self):
         c = chain(
             self.add.s(5, 5),
@@ -337,21 +300,12 @@ class test_chain(CanvasCase):
         assert tasks[-1].args[0] == 5
         assert isinstance(tasks[-2], chord)
         assert len(tasks[-2].tasks) == 5
-        assert tasks[-2].parent_id == tasks[-1].id
-        assert tasks[-2].root_id == tasks[-1].id
-        assert tasks[-2].body.args[0] == 10
-        assert tasks[-2].body.parent_id == tasks[-2].id
-
-        assert tasks[-3].args[0] == 20
-        assert tasks[-3].root_id == tasks[-1].id
-        assert tasks[-3].parent_id == tasks[-2].body.id
-
-        assert tasks[-4].args[0] == 30
-        assert tasks[-4].parent_id == tasks[-3].id
-        assert tasks[-4].root_id == tasks[-1].id
 
-        assert tasks[-2].body.options['link']
-        assert tasks[-2].body.options['link'][0].options['link']
+        body = tasks[-2].body
+        assert len(body.tasks) == 3
+        assert body.tasks[0].args[0] == 10
+        assert body.tasks[1].args[0] == 20
+        assert body.tasks[2].args[0] == 30
 
         c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
         c2._use_link = True
@@ -433,39 +387,6 @@ class test_chain(CanvasCase):
         assert chain(app=self.app)() is None
         assert chain(app=self.app).apply_async() is None
 
-    def test_root_id_parent_id(self):
-        self.app.conf.task_protocol = 2
-        c = chain(self.add.si(i, i) for i in range(4))
-        c.freeze()
-        tasks, _ = c._frozen
-        for i, task in enumerate(tasks):
-            assert task.root_id == tasks[-1].id
-            try:
-                assert task.parent_id == tasks[i + 1].id
-            except IndexError:
-                assert i == len(tasks) - 1
-            else:
-                valid_parents = i
-        assert valid_parents == len(tasks) - 2
-
-        self.assert_sent_with_ids(tasks[-1], tasks[-1].id, 'foo',
-                                  parent_id='foo')
-        assert tasks[-2].options['parent_id']
-        self.assert_sent_with_ids(tasks[-2], tasks[-1].id, tasks[-1].id)
-        self.assert_sent_with_ids(tasks[-3], tasks[-1].id, tasks[-2].id)
-        self.assert_sent_with_ids(tasks[-4], tasks[-1].id, tasks[-3].id)
-
-    def assert_sent_with_ids(self, task, rid, pid, **options):
-        self.app.amqp.send_task_message = Mock(name='send_task_message')
-        self.app.backend = Mock()
-        self.app.producer_or_acquire = ContextMock()
-
-        task.apply_async(**options)
-        self.app.amqp.send_task_message.assert_called()
-        message = self.app.amqp.send_task_message.call_args[0][2]
-        assert message.headers['parent_id'] == pid
-        assert message.headers['root_id'] == rid
-
     def test_call_no_tasks(self):
         x = chain()
         assert not x()
@@ -683,11 +604,6 @@ class test_chord(CanvasCase):
         x = chord(group(self.add.s(2, 2), self.add.s(4, 4), app=self.app))
         assert x.tasks
 
-    def test_set_parent_id(self):
-        x = chord(group(self.add.s(2, 2)))
-        x.tasks = [self.add.s(2, 2)]
-        x.set_parent_id('pid')
-
     def test_app_when_app(self):
         app = Mock(name='app')
         x = chord([self.add.s(4, 4)], app=app)

+ 1 - 0
t/unit/tasks/test_tasks.py

@@ -411,6 +411,7 @@ class test_tasks(TasksCase):
         with pytest.raises(Ignore):
             self.mytask.replace(sig1)
 
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_replace_callback(self):
         c = group([self.mytask.s()], app=self.app)
         c.freeze = Mock(name='freeze')

+ 10 - 0
tox.ini

@@ -38,6 +38,16 @@ basepython =
     pypy3: pypy3
     flake8,flakeplus,apicheck,linkcheck,configcheck,pydocstyle,cov: python2.7
 
+[testenv:redis]
+setenv =
+    TEST_BROKER = redis://
+    TEST_BACKEND = redis://
+
+[testenv:rabbitmq]
+setenv =
+    TEST_BROKER = pyamqp://
+    TEST_BACKEND = rpc
+
 [testenv:cov]
 commands =
     pip install -U -r{toxinidir}/requirements/dev.txt