Browse Source

100% coverage for celery.canvas

Ask Solem 12 years ago
parent
commit
648b799482
3 changed files with 161 additions and 17 deletions
  1. 2 2
      celery/canvas.py
  2. 3 3
      celery/result.py
  3. 156 12
      celery/tests/tasks/test_canvas.py

+ 2 - 2
celery/canvas.py

@@ -328,8 +328,8 @@ class _basemap(Signature):
         )
 
     @classmethod
-    def from_dict(self, d):
-        return chunks(*self._unpack_args(d['kwargs']), **d['options'])
+    def from_dict(cls, d):
+        return cls(*cls._unpack_args(d['kwargs']), **d['options'])
 
 
 class xmap(_basemap):

+ 3 - 3
celery/result.py

@@ -27,6 +27,9 @@ from .five import items, range, string_t
 class ResultBase(object):
     """Base class for all results"""
 
+    #: Parent result (if part of a chain)
+    parent = None
+
 
 class AsyncResult(ResultBase):
     """Query task state.
@@ -46,9 +49,6 @@ class AsyncResult(ResultBase):
     #: The task result backend to use.
     backend = None
 
-    #: Parent result (if part of a chain)
-    parent = None
-
     def __init__(self, id, backend=None, task_name=None,
                  app=None, parent=None):
         self.app = app_or_default(app or self.app)

+ 156 - 12
celery/tests/tasks/test_canvas.py

@@ -2,11 +2,22 @@ from __future__ import absolute_import
 
 from mock import Mock
 
-from celery import current_app, task
-from celery.canvas import Signature, chain, group, chord, subtask
+from celery import shared_task
+from celery.canvas import (
+    Signature,
+    chain,
+    group,
+    chord,
+    subtask,
+    xmap,
+    xstarmap,
+    chunks,
+    _maybe_group,
+    maybe_subtask,
+)
 from celery.result import EagerResult
 
-from celery.tests.utils import Case
+from celery.tests.utils import AppCase
 
 SIG = Signature({'task': 'TASK',
                  'args': ('A1', ),
@@ -15,22 +26,22 @@ SIG = Signature({'task': 'TASK',
                  'subtask_type': ''})
 
 
-@task()
+@shared_task()
 def add(x, y):
     return x + y
 
 
-@task()
+@shared_task()
 def mul(x, y):
     return x * y
 
 
-@task()
+@shared_task()
 def div(x, y):
     return x / y
 
 
-class test_Signature(Case):
+class test_Signature(AppCase):
 
     def test_getitem_property_class(self):
         self.assertTrue(Signature.task)
@@ -94,6 +105,9 @@ class test_Signature(Case):
         self.assertEqual(len(z.tasks), 4)
         with self.assertRaises(TypeError):
             x | 10
+        ax = add.s(2, 2) | (add.s(4) | add.s(8))
+        self.assertIsInstance(ax, chain)
+        self.assertEqual(len(ax.tasks), 3, 'consolidates chain to chain')
 
     def test_INVERT(self):
         x = add.s(2, 2)
@@ -104,8 +118,84 @@ class test_Signature(Case):
         self.assertEqual(~x, 4)
         self.assertTrue(x.apply_async.called)
 
+    def test_merge_immutable(self):
+        x = add.si(2, 2, foo=1)
+        args, kwargs, options = x._merge((4, ), {'bar': 2}, {'task_id': 3})
+        self.assertTupleEqual(args, (2, 2))
+        self.assertDictEqual(kwargs, {'foo': 1})
+        self.assertDictEqual(options, {'task_id': 3})
 
-class test_chain(Case):
+    def test_set_immutable(self):
+        x = add.s(2, 2)
+        self.assertFalse(x.immutable)
+        x.set(immutable=True)
+        self.assertTrue(x.immutable)
+        x.set(immutable=False)
+        self.assertFalse(x.immutable)
+
+    def test_election(self):
+        x = add.s(2, 2)
+        x._freeze('foo')
+        prev, x.type.app.control = x.type.app.control, Mock()
+        try:
+            r = x.election()
+            self.assertTrue(x.type.app.control.election.called)
+            self.assertEqual(r.id, 'foo')
+        finally:
+            x.type.app.control = prev
+
+    def test_AsyncResult_when_not_registerd(self):
+        s = subtask('xxx.not.registered')
+        self.assertTrue(s.AsyncResult)
+
+    def test_apply_async_when_not_registered(self):
+        s = subtask('xxx.not.registered')
+        self.assertTrue(s._apply_async)
+
+
+class test_xmap_xstarmap(AppCase):
+
+    def test_apply(self):
+        for type, attr in [(xmap, 'map'), (xstarmap, 'starmap')]:
+            args = [(i, i) for i in range(10)]
+            s = getattr(add, attr)(args)
+            s.type = Mock()
+
+            s.apply_async(foo=1)
+            s.type.apply_async.assert_called_with(
+                    (), {'task': add.s(), 'it': args}, foo=1,
+            )
+
+            self.assertEqual(type.from_dict(dict(s)), s)
+            self.assertTrue(repr(s))
+
+
+class test_chunks(AppCase):
+
+    def test_chunks(self):
+        x = add.chunks(range(100), 10)
+        self.assertEqual(chunks.from_dict(dict(x)), x)
+
+        self.assertTrue(x.group())
+        self.assertEqual(len(x.group().tasks), 10)
+
+        x.group = Mock()
+        gr = x.group.return_value = Mock()
+
+        x.apply_async()
+        gr.apply_async.assert_called_with((), {})
+
+        x()
+        gr.assert_called_with()
+
+        self.app.conf.CELERY_ALWAYS_EAGER = True
+        try:
+            chunks.apply_chunks(**x['kwargs'])
+        finally:
+            self.app.conf.CELERY_ALWAYS_EAGER = False
+
+
+class test_chain(AppCase):
 
     def test_repr(self):
         x = add.s(2, 2) | add.s(2)
@@ -117,11 +207,11 @@ class test_chain(Case):
         self.assertIsInstance(subtask(dict(x)), chain)
 
     def test_always_eager(self):
-        current_app.conf.CELERY_ALWAYS_EAGER = True
+        self.app.conf.CELERY_ALWAYS_EAGER = True
         try:
             self.assertEqual(~(add.s(4, 4) | add.s(8)), 16)
         finally:
-            current_app.conf.CELERY_ALWAYS_EAGER = False
+            self.app.conf.CELERY_ALWAYS_EAGER = False
 
     def test_apply(self):
         x = chain(add.s(4, 4), add.s(8), add.s(10))
@@ -133,13 +223,30 @@ class test_chain(Case):
         self.assertEqual(res.parent.parent.get(), 8)
         self.assertIsNone(res.parent.parent.parent)
 
+    def test_call_no_tasks(self):
+        x = chain()
+        self.assertFalse(x())
+
+    def test_call_with_tasks(self):
+        x = add.s(2, 2) | add.s(4)
+        x.apply_async = Mock()
+        x(2, 2, foo=1)
+        x.apply_async.assert_called_with((2, 2), {'foo': 1})
+
+    def test_from_dict_no_args__with_args(self):
+        x = dict(add.s(2, 2) | add.s(4))
+        x['args'] = None
+        self.assertIsInstance(chain.from_dict(x), chain)
+        x['args'] = (2, )
+        self.assertIsInstance(chain.from_dict(x), chain)
+
     def test_accepts_generator_argument(self):
         x = chain(add.s(i) for i in range(10))
         self.assertTrue(x.tasks[0].type, add)
         self.assertTrue(x.type)
 
 
-class test_group(Case):
+class test_group(AppCase):
 
     def test_repr(self):
         x = group([add.s(2, 2), add.s(4, 4)])
@@ -150,8 +257,32 @@ class test_group(Case):
         self.assertIsInstance(subtask(x), group)
         self.assertIsInstance(subtask(dict(x)), group)
 
+    def test_maybe_group_sig(self):
+        self.assertListEqual(_maybe_group(add.s(2, 2)), [add.s(2, 2)])
+
+    def test_from_dict(self):
+        x = group([add.s(2, 2), add.s(4, 4)])
+        x['args'] = (2, 2)
+        self.assertTrue(group.from_dict(dict(x)))
+        x['args'] = None
+        self.assertTrue(group.from_dict(dict(x)))
+
+    def test_call_empty_group(self):
+        x = group()
+        self.assertIsNone(x())
+
+    def test_skew(self):
+        g = group([add.s(i, i) for i in range(10)])
+        g.skew(start=1, stop=10, step=1)
+        for i, task in enumerate(g.tasks):
+            self.assertEqual(task.options['countdown'], i + 1)
+
+    def test_iter(self):
+        g = group([add.s(i, i) for i in range(10)])
+        self.assertListEqual(list(iter(g)), g.tasks)
+
 
-class test_chord(Case):
+class test_chord(AppCase):
 
     def test_reverse(self):
         x = chord([add.s(2, 2), add.s(4, 4)], body=mul.s(4))
@@ -184,3 +315,16 @@ class test_chord(Case):
         self.assertTrue(repr(x))
         x.kwargs['body'] = None
         self.assertIn('without body', repr(x))
+
+
+class test_maybe_subtask(AppCase):
+
+    def test_is_None(self):
+        self.assertIsNone(maybe_subtask(None))
+
+    def test_is_dict(self):
+        self.assertIsInstance(maybe_subtask(dict(add.s())), Signature)
+
+    def test_when_sig(self):
+        s = add.s()
+        self.assertIs(maybe_subtask(s), s)