Browse Source

Resolve TypeError on `.get` from nested groups (#4432)

* Accept and pass along the `on_interval` in ResultSet.get

Otherwise, calls to .get or .join on ResultSets fail on nested groups.
Fixes #4274

* Add a unit test that verifies the fixed behavior

Verified that the unit test fails on master, but passes on the patched version. The
nested structure of results was borrowed from #4274

* Wrap long lines

* Add integration test for #4274 use case

* Switch to a simpler, group-only-based integration test

* Flatten expected integration test result

* Added back testcase from #4274 and skip it if the backend under test does not support native joins.

* Fix lint.

* Enable only if chords are allowed.

* Fix access to message.
Misha Wolfson 7 years ago
parent
commit
dd2cdd9c4f
3 changed files with 83 additions and 3 deletions
  1. 3 2
      celery/result.py
  2. 44 0
      t/integration/test_canvas.py
  3. 36 1
      t/unit/tasks/test_result.py

+ 3 - 2
celery/result.py

@@ -628,7 +628,7 @@ class ResultSet(ResultBase):
 
     def get(self, timeout=None, propagate=True, interval=0.5,
             callback=None, no_ack=True, on_message=None,
-            disable_sync_subtasks=True):
+            disable_sync_subtasks=True, on_interval=None):
         """See :meth:`join`.
 
         This is here for API compatibility with :class:`AsyncResult`,
@@ -640,7 +640,8 @@ class ResultSet(ResultBase):
         return (self.join_native if self.supports_native_join else self.join)(
             timeout=timeout, propagate=propagate,
             interval=interval, callback=callback, no_ack=no_ack,
-            on_message=on_message, disable_sync_subtasks=disable_sync_subtasks
+            on_message=on_message, disable_sync_subtasks=disable_sync_subtasks,
+            on_interval=on_interval,
         )
 
     def join(self, timeout=None, propagate=True, interval=0.5,

+ 44 - 0
t/integration/test_canvas.py

@@ -142,6 +142,24 @@ class test_group:
             assert parent_id == expected_parent_id
             assert value == i + 2
 
+    @flaky
+    def test_nested_group(self, manager):
+        assert manager.inspect().ping()
+
+        c = group(
+            add.si(1, 10),
+            group(
+                add.si(1, 100),
+                group(
+                    add.si(1, 1000),
+                    add.si(1, 2000),
+                ),
+            ),
+        )
+        res = c()
+
+        assert res.get(timeout=TIMEOUT) == [11, 101, 1001, 2001]
+
 
 def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
     root_id, parent_id, value = r.get(timeout=TIMEOUT)
@@ -164,6 +182,32 @@ class test_chord:
         res = c()
         assert res.get(timeout=TIMEOUT) == [12, 13, 14, 15]
 
+    @flaky
+    def test_nested_group_chain(self, manager):
+        try:
+            manager.app.backend.ensure_chords_allowed()
+        except NotImplementedError as e:
+            raise pytest.skip(e.args[0])
+
+        if not manager.app.backend.supports_native_join:
+            raise pytest.skip('Requires native join support.')
+        c = chain(
+            add.si(1, 0),
+            group(
+                add.si(1, 100),
+                chain(
+                    add.si(1, 200),
+                    group(
+                        add.si(1, 1000),
+                        add.si(1, 2000),
+                    ),
+                ),
+            ),
+            add.si(1, 10),
+        )
+        res = c()
+        assert res.get(timeout=TIMEOUT) == 11
+
     @flaky
     def test_parent_ids(self, manager):
         if not manager.app.conf.result_backend.startswith('redis'):

+ 36 - 1
t/unit/tasks/test_result.py

@@ -519,12 +519,16 @@ class MockAsyncResultFailure(AsyncResult):
 class MockAsyncResultSuccess(AsyncResult):
     forgotten = False
 
+    def __init__(self, *args, **kwargs):
+        self._result = kwargs.pop('result', 42)
+        super(MockAsyncResultSuccess, self).__init__(*args, **kwargs)
+
     def forget(self):
         self.forgotten = True
 
     @property
     def result(self):
-        return 42
+        return self._result
 
     @property
     def state(self):
@@ -622,6 +626,37 @@ class test_GroupResult:
         for sub in subs:
             assert sub.forgotten
 
+    def test_get_nested_without_native_join(self):
+        backend = SimpleBackend()
+        backend.supports_native_join = False
+        ts = self.app.GroupResult(uuid(), [
+            MockAsyncResultSuccess(uuid(), result='1.1',
+                                   app=self.app, backend=backend),
+            self.app.GroupResult(uuid(), [
+                MockAsyncResultSuccess(uuid(), result='2.1',
+                                       app=self.app, backend=backend),
+                self.app.GroupResult(uuid(), [
+                    MockAsyncResultSuccess(uuid(), result='3.1',
+                                           app=self.app, backend=backend),
+                    MockAsyncResultSuccess(uuid(), result='3.2',
+                                           app=self.app, backend=backend),
+                ]),
+            ]),
+        ])
+        ts.app.backend = backend
+
+        vals = ts.get()
+        assert vals == [
+            '1.1',
+            [
+                '2.1',
+                [
+                    '3.1',
+                    '3.2',
+                ]
+            ],
+        ]
+
     def test_getitem(self):
         subs = [MockAsyncResultSuccess(uuid(), app=self.app),
                 MockAsyncResultSuccess(uuid(), app=self.app)]