
Add support for delayed expansion of generators to groups and chords.
Prior to this change, if a generator was passed to a group or chord it was completely expanded before any tasks were submitted to the backend. After this change, tasks are submitted to the backend as they are retrieved from the generator, allowing work to begin much sooner for long-running generators.
This is currently a work in progress: only redis is supported.
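For illustration, a minimal usage sketch of what this enables (not part of the commit; the app, broker URL and the add/xsum tasks are assumptions made for the example, and a redis result backend is required):

    from celery import Celery, chord

    app = Celery('demo', broker='redis://', backend='redis://')

    @app.task
    def add(x, y):
        return x + y

    @app.task
    def xsum(numbers):
        return sum(numbers)

    # The chord header is a generator: with this change each add.s(i, i) is
    # submitted as it is pulled from the generator, instead of the whole
    # header being expanded into a list before anything is sent.
    header = (add.s(i, i) for i in range(100000))
    result = chord(header, body=xsum.s()).apply_async()

The same applies to a bare group(...): passing a generator no longer forces it to be turned into a list up front.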

Chad Dombrova committed 9 years ago
commit 23abf3576b

+ 8 - 4
celery/backends/base.py

@@ -386,6 +386,9 @@ class Backend(object):
     def on_chord_part_return(self, request, state, result, **kwargs):
         pass
 
+    def set_chord_size(self, group_id, size):
+        pass
+
     def fallback_chord_unlock(self, group_id, body, result=None,
                               countdown=1, **kwargs):
         kwargs['result'] = [r.as_tuple() for r in result]
@@ -394,11 +397,12 @@ class Backend(object):
         )
 
     def apply_chord(self, header, partial_args, group_id, body,
-                    options={}, **kwargs):
+                    options={}, result=None, **kwargs):
+        result = list(result)
         fixed_options = {k: v for k, v in items(options) if k != 'task_id'}
-        result = header(*partial_args, task_id=group_id, **fixed_options or {})
-        self.fallback_chord_unlock(group_id, body, **kwargs)
-        return result
+        res = header(*partial_args, task_id=group_id, **fixed_options or {})
+        self.fallback_chord_unlock(group_id, body, result=result, **kwargs)
+        return res
 
     def current_task_children(self, request=None):
         request = request or getattr(current_task(), 'request', None)

+ 14 - 3
celery/backends/redis.py

@@ -239,6 +239,9 @@ class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
             raise ChordError('Dependency {0} raised {1!r}'.format(tid, retval))
         return retval
 
+    def set_chord_size(self, group_id, chord_size):
+        self.set(self.get_key_for_group(group_id, '.s'), chord_size)
+
     def apply_chord(self, header, partial_args, group_id, body,
                     result=None, options={}, **kwargs):
         # avoids saving the group in the redis db.
@@ -254,17 +257,24 @@ class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
         client = self.client
         jkey = self.get_key_for_group(gid, '.j')
         tkey = self.get_key_for_group(gid, '.t')
+        skey = self.get_key_for_group(gid, '.s')
         result = self.encode_result(result, state)
         with client.pipeline() as pipe:
-            _, readycount, totaldiff, _, _ = pipe                           \
+            _, readycount, totaldiff, total, _, _, _ = pipe                 \
                 .rpush(jkey, self.encode([1, tid, state, result]))          \
                 .llen(jkey)                                                 \
                 .get(tkey)                                                  \
+                .get(skey)                                                  \
                 .expire(jkey, 86400)                                        \
                 .expire(tkey, 86400)                                        \
+                .expire(skey, 86400)                                        \
                 .execute()
 
-        totaldiff = int(totaldiff or 0)
+        if total is None:
+            # chord is not completely submitted yet
+            return
+
+        total = int(total) + int(totaldiff or 0)
 
         try:
             callback = maybe_signature(request.chord, app=app)
@@ -272,10 +282,11 @@ class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
             if readycount == total:
                 decode, unpack = self.decode, self._unpack_chord_result
                 with client.pipeline() as pipe:
-                    resl, _, _ = pipe               \
+                    resl, _, _, _ = pipe            \
                         .lrange(jkey, 0, total)     \
                         .delete(jkey)               \
                         .delete(tkey)               \
+                        .delete(skey)               \
                         .execute()
                 try:
                     callback.delay([unpack(tup, decode) for tup in resl])
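
To sketch the handshake the changes above introduce (an illustrative snippet, not from the commit; it assumes a reachable redis server and uses a made-up group id): the group publisher calls set_chord_size() only once the last header task is known, and until the '.s' key exists on_chord_part_return() returns early instead of firing the callback.

    from celery import Celery

    app = Celery('demo', backend='redis://')
    backend = app.backend            # RedisBackend, with the patch above applied

    gid = 'example-group-id'         # hypothetical group id
    skey = backend.get_key_for_group(gid, '.s')
    print(backend.client.get(skey))  # None: size unknown, callback cannot fire yet
    backend.set_chord_size(gid, 3)   # normally done when the last header task is applied
    print(backend.client.get(skey))  # b'3': the chord is now completely submitted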

+ 56 - 92
celery/canvas.py

@@ -18,18 +18,18 @@ from collections import MutableSequence, deque
 from copy import deepcopy
 from functools import partial as _partial, reduce
 from operator import itemgetter
-from itertools import chain as _chain
+from itertools import chain as _chain, tee
 
 from kombu.utils import cached_property, fxrange, reprcall, uuid
 from vine import barrier
 
 from celery._state import current_app
-from celery.five import python_2_unicode_compatible
+from celery.five import python_2_unicode_compatible, zip
 from celery.local import try_import
 from celery.result import GroupResult
 from celery.utils import abstract
 from celery.utils.functional import (
-    maybe_list, is_list, _regen, regen, chunks as _chunks,
+    maybe_list, is_list, _regen, regen, lookahead, chunks as _chunks,
 )
 from celery.utils.text import truncate
 
@@ -114,12 +114,6 @@ def task_name_from(task):
     return getattr(task, 'name', task)
 
 
-def _upgrade(fields, sig):
-    """Used by custom signatures in .from_dict, to keep common fields."""
-    sig.update(chord_size=fields.get('chord_size'))
-    return sig
-
-
 @abstract.CallableSignature.register
 @python_2_unicode_compatible
 class Signature(dict):
@@ -178,13 +172,14 @@ class Signature(dict):
         else:
             self._type = task
 
-        init(self,
-             task=task_name, args=tuple(args or ()),
-             kwargs=kwargs or {},
-             options=dict(options or {}, **ex),
-             subtask_type=subtask_type,
-             immutable=immutable,
-             chord_size=None)
+        init(
+            self,
+            task=task_name, args=tuple(args or ()),
+            kwargs=kwargs or {},
+            options=dict(options or {}, **ex),
+            subtask_type=subtask_type,
+            immutable=immutable,
+        )
 
     def __call__(self, *partial_args, **partial_kwargs):
         args, kwargs, _ = self._merge(partial_args, partial_kwargs, None)
@@ -216,7 +211,6 @@
         s = Signature.from_dict({'task': self.task, 'args': tuple(args),
                                  'kwargs': kwargs, 'options': deepcopy(opts),
                                  'subtask_type': self.subtask_type,
-                                 'chord_size': self.chord_size,
                                  'immutable': self.immutable}, app=self._app)
         s._type = self._type
         return s
@@ -410,7 +404,6 @@
     kwargs = _getitem_property('kwargs')
     options = _getitem_property('options')
     subtask_type = _getitem_property('subtask_type')
-    chord_size = _getitem_property('chord_size')
     immutable = _getitem_property('immutable')
 
 
@@ -594,7 +587,7 @@ class chain(Signature):
                 tasks = d['kwargs']['tasks'] = list(tasks)
             # First task must be signature object to get app
             tasks[0] = maybe_signature(tasks[0], app=app)
-        return _upgrade(d, chain(*tasks, app=app, **d['options']))
+        return chain(*tasks, app=app, **d['options'])
 
     @property
     def app(self):
@@ -630,9 +623,7 @@ class _basemap(Signature):
 
     @classmethod
     def from_dict(cls, d, app=None):
-        return _upgrade(
-            d, cls(*cls._unpack_args(d['kwargs']), app=app, **d['options']),
-        )
+        return cls(*cls._unpack_args(d['kwargs']), app=app, **d['options'])
 
 
 @Signature.register_type
@@ -670,10 +661,7 @@ class chunks(Signature):
 
     @classmethod
     def from_dict(self, d, app=None):
-        return _upgrade(
-            d, chunks(*self._unpack_args(
-                d['kwargs']), app=app, **d['options']),
-        )
+        return chunks(*self._unpack_args(d['kwargs']), app=app, **d['options'])
 
     def apply_async(self, args=(), kwargs={}, **opts):
         return self.group().apply_async(
@@ -705,7 +693,7 @@ def _maybe_group(tasks, app):
     elif isinstance(tasks, abstract.CallableSignature):
         tasks = [tasks]
     else:
-        tasks = [signature(t, app=app) for t in tasks]
+        tasks = regen(signature(t, app=app) for t in tasks)
     return tasks
 
 
@@ -728,45 +716,29 @@ class group(Signature):
 
     @classmethod
     def from_dict(self, d, app=None):
-        return _upgrade(
-            d, group(d['kwargs']['tasks'], app=app, **d['options']),
-        )
+        return group(d['kwargs']['tasks'], app=app, **d['options'])
 
     def __len__(self):
         return len(self.tasks)
 
-    def _prepared(self, tasks, partial_args, group_id, root_id, app,
-                  CallableSignature=abstract.CallableSignature,
-                  from_dict=Signature.from_dict,
-                  isinstance=isinstance, tuple=tuple):
-        for task in tasks:
-            if isinstance(task, CallableSignature):
-                # local sigs are always of type Signature, and we
-                # clone them to make sure we do not modify the originals.
-                task = task.clone()
-            else:
-                # serialized sigs must be converted to Signature.
-                task = from_dict(task, app=app)
-            if isinstance(task, group):
-                # needs yield_from :(
-                unroll = task._prepared(
-                    task.tasks, partial_args, group_id, root_id, app,
-                )
-                for taskN, resN in unroll:
-                    yield taskN, resN
-            else:
-                if partial_args and not task.immutable:
-                    task.args = tuple(partial_args) + tuple(task.args)
-                yield task, task.freeze(group_id=group_id, root_id=root_id)
+    def _prepared(self, tasks, partial_args, group_id, root_id, app):
+        for task in self._unroll_tasks(tasks, app=app, clone=True):
+            if partial_args and not task.immutable:
+                task.args = tuple(partial_args) + tuple(task.args)
+            yield task, task.freeze(group_id=group_id, root_id=root_id)
 
     def _apply_tasks(self, tasks, producer=None, app=None, p=None,
                      add_to_parent=None, chord=None, **options):
         app = app or self.app
         with app.producer_or_acquire(producer) as producer:
-            for sig, res in tasks:
+            for count, (curr_task, next_task) in enumerate(lookahead(tasks)):
+                sig, res = curr_task
+                ichord = sig.options.get('chord') or chord
+                if ichord is not None and next_task is None:
+                    # last task in chord: set the length *before* applying
+                    app.backend.set_chord_size(sig.options['group_id'], count)
                 sig.apply_async(producer=producer, add_to_parent=False,
-                                chord=sig.options.get('chord') or chord,
-                                **options)
+                                chord=ichord, **options)
 
                 # adding callback to result, such that it will gradually
                 # fulfill the barrier.
@@ -846,17 +818,21 @@ class group(Signature):
     def __call__(self, *partial_args, **options):
         return self.apply_async(partial_args, **options)
 
-    def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id):
-        stack = deque(self.tasks)
-        while stack:
-            task = maybe_signature(stack.popleft(), app=self._app).clone()
+    def _unroll_tasks(self, tasks, app=None, clone=True):
+        for task in tasks:
+            task = maybe_signature(task, app=app or self._app, clone=clone)
             if isinstance(task, group):
-                stack.extendleft(task.tasks)
+                it = task._unroll_tasks(task.tasks, app=app, clone=clone)
+                for subtask in it:
+                    yield subtask
             else:
-                new_tasks.append(task)
-                yield task.freeze(group_id=group_id,
-                                  chord=chord, root_id=root_id,
-                                  parent_id=parent_id)
+                yield task
+
+    def _freeze_tasks(self, tasks, group_id, chord, root_id, parent_id):
+        for task in tasks:
+            yield task.freeze(group_id=group_id,
+                              chord=chord, root_id=root_id,
+                              parent_id=parent_id)
 
     def freeze(self, _id=None, group_id=None, chord=None,
                root_id=None, parent_id=None):
@@ -871,16 +847,11 @@
             opts['chord'] = chord
         root_id = opts.setdefault('root_id', root_id)
         parent_id = opts.setdefault('parent_id', parent_id)
-        new_tasks = []
-        # Need to unroll subgroups early so that chord gets the
-        # right result instance for chord_unlock etc.
-        results = list(self._freeze_unroll(
-            new_tasks, group_id, chord, root_id, parent_id,
-        ))
-        if isinstance(self.tasks, MutableSequence):
-            self.tasks[:] = new_tasks
-        else:
-            self.tasks = new_tasks
+        tasks1, tasks2 = tee(self._unroll_tasks(self.tasks, clone=True))
+        results = regen(
+            self._freeze_tasks(tasks1, group_id, chord, root_id, parent_id))
+        # if self.tasks is consumed, it will also populate the group results
+        self.tasks = regen(x[0] for x in zip(tasks2, results))
         return self.app.GroupResult(gid, results)
     _freeze = freeze
 
@@ -942,7 +913,7 @@ class chord(Signature):
    @classmethod
     def from_dict(self, d, app=None):
         args, d['kwargs'] = self._unpack_args(**d['kwargs'])
-        return _upgrade(d, self(*args, app=app, **d))
+        return self(*args, app=app, **d)
 
     @staticmethod
     def _unpack_args(header=None, body=None, **kwargs):
@@ -990,17 +961,8 @@
             args=(tasks.apply(args, kwargs).get(propagate=propagate),),
         )
 
-    def _traverse_tasks(self, tasks, value=None):
-        stack = deque(list(tasks))
-        while stack:
-            task = stack.popleft()
-            if isinstance(task, group):
-                stack.extend(task.tasks)
-            else:
-                yield task if value is None else value
-
     def __length_hint__(self):
-        return sum(self._traverse_tasks(self.tasks, 1))
+        return len(list(self._unroll_tasks(self.tasks, clone=False)))
 
     def run(self, header, body, partial_args, app=None, interval=None,
             countdown=1, max_retries=None, eager=False,
@@ -1008,7 +970,6 @@
         app = app or self._get_app(body)
         group_id = uuid()
         root_id = body.options.get('root_id')
-        body.chord_size = self.__length_hint__()
         options = dict(self.options, **options) if options else self.options
         if options:
             options.pop('task_id', None)
@@ -1070,13 +1031,16 @@ def signature(varies, *args, **kwargs):
 subtask = signature   # XXX compat
 
 
-def maybe_signature(d, app=None):
+def maybe_signature(d, app=None, clone=False):
     if d is not None:
-        if (isinstance(d, dict) and
-                not isinstance(d, abstract.CallableSignature)):
-            d = signature(d)
+        if isinstance(d, dict):
+            if not isinstance(d, abstract.CallableSignature):
+                d = signature(d)
+            elif clone:
+                d = d.clone()
+
         if app is not None:
             d._app = app
-        return d
+    return d
 
 maybe_subtask = maybe_signature  # XXX compat

+ 2 - 0
celery/tests/app/test_builtins.py

@@ -115,6 +115,7 @@ class test_group(BuiltinsCase):
     def test_task(self, current_worker_task):
         g, result = self.mock_group(self.add.s(2), self.add.s(4))
         self.task(g.tasks, result, result.id, (2,)).results
+        print('TASKS: %r' % (g.tasks,))
         g.tasks[0].clone().apply_async.assert_called_with(
             group_id=result.id, producer=self.app.producer_or_acquire(),
             add_to_parent=False,
@@ -178,5 +179,6 @@
     def test_apply_eager_with_arguments(self):
         self.app.conf.task_always_eager = True
         x = chord([self.add.s(i) for i in range(10)], body=self.xsum.s())
+        print(list(x.tasks))
         r = x.apply_async([1])
         self.assertEqual(r.get(), 55)

+ 12 - 8
celery/tests/backends/test_base.py

@@ -8,10 +8,13 @@ from contextlib import contextmanager
 from celery.exceptions import ChordError, TimeoutError
 from celery.five import items, bytes_if_py2, range
 from celery.utils import serialization
-from celery.utils.serialization import subclass_exception
-from celery.utils.serialization import find_pickleable_exception as fnpe
-from celery.utils.serialization import UnpickleableExceptionWrapper
-from celery.utils.serialization import get_pickleable_exception as gpe
+from celery.utils.serialization import (
+    subclass_exception,
+    find_pickleable_exception as fnpe,
+    UnpickleableExceptionWrapper,
+    get_pickleable_exception as gpe,
+)
+from celery.utils.functional import pass1, regen
 
 from celery import states
 from celery import group, uuid
@@ -22,7 +25,6 @@ from celery.backends.base import (
     _nulldict,
 )
 from celery.result import result_from_tuple
-from celery.utils.functional import pass1
 
 from celery.tests.case import ANY, AppCase, Case, Mock, call, patch, skip
 
@@ -84,7 +86,7 @@ class test_BaseBackend_interface(AppCase):
         self.app.tasks[unlock] = Mock()
         self.b.apply_chord(
             group(app=self.app), (), 'dakj221', None,
-            result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
+            result=regen(self.app.AsyncResult(x) for x in [1, 2, 3]),
         )
         self.assertTrue(self.app.tasks[unlock].apply_async.call_count)
 
@@ -524,12 +526,14 @@ class test_KeyValueStoreBackend(AppCase):
     def test_chord_apply_fallback(self):
         self.b.implements_incr = False
         self.b.fallback_chord_unlock = Mock()
+        res = regen(x for x in range(10))
         self.b.apply_chord(
             group(app=self.app), (), 'group_id', 'body',
-            result='result', foo=1,
+            result=res, foo=1,
         )
+        self.assertTrue(res.fully_consumed())
         self.b.fallback_chord_unlock.assert_called_with(
-            'group_id', 'body', result='result', foo=1,
+            'group_id', 'body', result=res, foo=1,
         )
 
     def test_get_missing_meta(self):

+ 3 - 1
celery/tests/backends/test_cache.py

@@ -12,6 +12,7 @@ from celery import group, signature, uuid
 from celery.backends.cache import CacheBackend, DummyClient, backends
 from celery.exceptions import ImproperlyConfigured
 from celery.five import items, bytes_if_py2, string, text_t
+from celery.utils.functional import regen
 
 from celery.tests.case import AppCase, Mock, mock, patch, skip
 
@@ -67,7 +68,8 @@ class test_CacheBackend(AppCase):
 
     def test_apply_chord(self):
         tb = CacheBackend(backend='memory://', app=self.app)
-        gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
+        gid = uuid()
+        res = regen(self.app.AsyncResult(uuid()) for _ in range(3))
         tb.apply_chord(group(app=self.app), (), gid, {}, result=res)
 
     @patch('celery.result.GroupResult.restore')

+ 28 - 10
celery/tests/backends/test_redis.py

@@ -278,6 +278,12 @@ class test_RedisBackend(AppCase):
         b.add_to_chord(gid, 'sig')
         b.client.incr.assert_called_with(b.get_key_for_group(gid, '.t'), 1)
 
+    def test_set_chord_size(self):
+        b = self.Backend('redis://', app=self.app)
+        gid = uuid()
+        b.set_chord_size(gid, 10)
+        b.client.set.assert_called_with(b.get_key_for_group(gid, '.s'), 10)
+
     def test_expires_is_None(self):
         b = self.Backend(expires=None, app=self.app)
         self.assertEqual(
@@ -304,7 +310,6 @@
         self.app.tasks['foobarbaz'] = task
         task.request.chord = signature(task)
         task.request.id = tid
-        task.request.chord['chord_size'] = 10
         task.request.group = 'group_id'
         return task
 
@@ -312,6 +317,8 @@
     def test_on_chord_part_return(self, restore):
         tasks = [self.create_task() for i in range(10)]
 
+        self.b.set_chord_size('group_id', 10)
+
         for i in range(10):
             self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
             self.assertTrue(self.b.client.rpush.call_count)
@@ -319,20 +326,27 @@
         self.assertTrue(self.b.client.lrange.call_count)
         jkey = self.b.get_key_for_group('group_id', '.j')
         tkey = self.b.get_key_for_group('group_id', '.t')
-        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
+        skey = self.b.get_key_for_group('group_id', '.s')
+        self.b.client.delete.assert_has_calls([
+            call(jkey), call(tkey), call(skey),
+        ])
         self.b.client.expire.assert_has_calls([
-            call(jkey, 86400), call(tkey, 86400),
+            call(jkey, 86400), call(tkey, 86400), call(skey, 86400),
         ])
 
     def test_on_chord_part_return__success(self):
-        with self.chord_context(2) as (_, request, callback):
+        with self.chord_context(3) as (_, request, callback):
             self.b.on_chord_part_return(request, states.SUCCESS, 10)
             callback.delay.assert_not_called()
+            self.b.set_chord_size('gid1', 3)
             self.b.on_chord_part_return(request, states.SUCCESS, 20)
-            callback.delay.assert_called_with([10, 20])
+            callback.delay.assert_not_called()
+            self.b.on_chord_part_return(request, states.SUCCESS, 30)
+            callback.delay.assert_called_with([10, 20, 30])
 
     def test_on_chord_part_return__callback_raises(self):
         with self.chord_context(1) as (_, request, callback):
+            self.b.set_chord_size('gid1', 1)
             callback.delay.side_effect = KeyError(10)
             task = self.app._tasks['add'] = Mock(name='add_task')
             self.b.on_chord_part_return(request, states.SUCCESS, 10)
@@ -342,22 +356,27 @@
 
     def test_on_chord_part_return__ChordError(self):
         with self.chord_context(1) as (_, request, callback):
+            self.b.set_chord_size('gid1', 1)
             self.b.client.pipeline = ContextMock()
             raise_on_second_call(self.b.client.pipeline, ChordError())
-            self.b.client.pipeline.return_value.rpush().llen().get().expire(
-            ).expire().execute.return_value = (1, 1, 0, 4, 5)
+            self._set_pipeline_ret(1, 1, 0, 1, 4, 5, 6)
             task = self.app._tasks['add'] = Mock(name='add_task')
             self.b.on_chord_part_return(request, states.SUCCESS, 10)
             task.backend.fail_from_current_stack.assert_called_with(
                 callback.id, exc=ANY,
             )
 
+    def _set_pipeline_ret(self, *args):
+            self.b.client.pipeline.return_value.rpush().llen().get().get()\
+                .expire().expire().expire()\
+                .execute.return_value = args
+
     def test_on_chord_part_return__other_error(self):
         with self.chord_context(1) as (_, request, callback):
+            self.b.set_chord_size('gid1', 1)
             self.b.client.pipeline = ContextMock()
             raise_on_second_call(self.b.client.pipeline, RuntimeError())
-            self.b.client.pipeline.return_value.rpush().llen().get().expire(
-            ).expire().execute.return_value = (1, 1, 0, 4, 5)
+            self._set_pipeline_ret(1, 1, 0, 1, 4, 5, 6)
             task = self.app._tasks['add'] = Mock(name='add_task')
             self.b.on_chord_part_return(request, states.SUCCESS, 10)
             task.backend.fail_from_current_stack.assert_called_with(
@@ -373,7 +392,6 @@
             request.group = 'gid1'
             callback = ms.return_value = Signature('add')
             callback.id = 'id1'
-            callback['chord_size'] = size
             callback.delay = Mock(name='callback.delay')
             yield tasks, request, callback
 

+ 19 - 0
celery/tests/tasks/test_canvas.py

@@ -15,6 +15,7 @@ from celery.canvas import (
     maybe_unroll_group,
 )
 from celery.result import EagerResult
+from celery.utils.functional import _regen
 
 from celery.tests.case import (
     AppCase, ContextMock, MagicMock, Mock, depends_on_current_app,
@@ -328,6 +329,8 @@ class test_chain(CanvasCase):
         c._use_link = True
         tasks, results = c.prepare_steps((), c.tasks)
 
+        print(tasks[-2].tasks)
+
         self.assertEqual(tasks[-1].args[0], 5)
         self.assertIsInstance(tasks[-2], chord)
         self.assertEqual(len(tasks[-2].tasks), 5)
@@ -580,6 +583,14 @@ class test_group(CanvasCase):
         g = group([self.add.s(i, i) for i in range(10)])
         self.assertListEqual(list(iter(g)), g.tasks)
 
+    def test_maintains_generator(self):
+        g = group(self.add.s(x, x) for x in range(3))
+        self.assertIsInstance(g.tasks, _regen)
+        self.assertFalse(g.tasks.fully_consumed())
+        g.freeze()
+        self.assertIsInstance(g.tasks, _regen)
+        self.assertFalse(g.tasks.fully_consumed())
+
 
 class test_chord(CanvasCase):
 
@@ -656,6 +667,14 @@ class test_chord(CanvasCase):
         x.tasks = [self.add.s(2, 2)]
         x.freeze()
 
+    def test_maintains_generator(self):
+        x = chord((self.add.s(i, i) for i in range(3)), body=self.mul.s(4))
+        self.assertIsInstance(x.tasks, _regen)
+        self.assertFalse(x.tasks.fully_consumed())
+        x.freeze()
+        self.assertIsInstance(x.tasks, group)
+        self.assertFalse(x.tasks.tasks.fully_consumed())
+
 
 class test_maybe_signature(CanvasCase):
 

+ 58 - 33
celery/tests/utils/test_functional.py

@@ -10,6 +10,7 @@ from celery.utils.functional import (
     firstmethod,
     first,
     maybe_list,
+    lookahead,
     mlazy,
     padlist,
     regen,
@@ -78,6 +79,12 @@ class test_utils(Case):
         self.assertIsNone(first(predicate, range(10, 20)))
         self.assertEqual(iterations[0], 10)
 
+    def test_lookahead(self):
+        self.assertEqual(
+            list(lookahead(x for x in range(6))),
+            [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, None)],
+        )
+
     def test_maybe_list(self):
         self.assertEqual(maybe_list(1), [1])
         self.assertEqual(maybe_list([1]), [1])
@@ -98,7 +105,10 @@
 
 class test_regen(Case):
 
-    def test_regen_list(self):
+    def setup(self):
+        self.g = regen(iter(list(range(10))))
+
+    def test_list(self):
         l = [1, 2]
         r = regen(iter(l))
         self.assertIs(regen(l), l)
@@ -109,39 +119,54 @@
         fun, args = r.__reduce__()
         self.assertEqual(fun(*args), l)
 
-    def test_regen_gen(self):
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[7], 7)
-        self.assertEqual(g[6], 6)
-        self.assertEqual(g[5], 5)
-        self.assertEqual(g[4], 4)
-        self.assertEqual(g[3], 3)
-        self.assertEqual(g[2], 2)
-        self.assertEqual(g[1], 1)
-        self.assertEqual(g[0], 0)
-        self.assertEqual(g.data, list(range(10)))
-        self.assertEqual(g[8], 8)
-        self.assertEqual(g[0], 0)
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[0], 0)
-        self.assertEqual(g[1], 1)
-        self.assertEqual(g.data, list(range(10)))
-        g = regen(iter([1]))
-        self.assertEqual(g[0], 1)
+    def test_index(self):
+        self.assertEqual(self.g[7], 7)
+        self.assertEqual(self.g[6], 6)
+        self.assertEqual(self.g[5], 5)
+        self.assertEqual(self.g[4], 4)
+        self.assertEqual(self.g[3], 3)
+        self.assertEqual(self.g[2], 2)
+        self.assertEqual(self.g[1], 1)
+        self.assertEqual(self.g[0], 0)
+        self.assertFalse(self.g.fully_consumed())
+        self.assertEqual(self.g.data, list(range(10)))
+        self.assertTrue(self.g.fully_consumed())
+        self.assertEqual(self.g[8], 8)
+        self.assertEqual(self.g[0], 0)
+        self.assertListEqual(list(iter(self.g)), list(range(10)))
+
+    def test_index_2(self):
+        self.assertEqual(self.g[0], 0)
+        self.assertEqual(self.g[1], 1)
+        self.assertEqual(self.g.data, list(range(10)))
+
+    def test_index_error(self):
         with self.assertRaises(IndexError):
-            g[1]
-        self.assertEqual(g.data, [1])
-
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[-1], 9)
-        self.assertEqual(g[-2], 8)
-        self.assertEqual(g[-3], 7)
-        self.assertEqual(g[-4], 6)
-        self.assertEqual(g[-5], 5)
-        self.assertEqual(g[5], 5)
-        self.assertEqual(g.data, list(range(10)))
-
-        self.assertListEqual(list(iter(g)), list(range(10)))
+            self.g[11]
+        self.assertTrue(self.g.fully_consumed())
+        self.assertListEqual(list(iter(self.g)), list(range(10)))
+
+    def test_negative_index(self):
+        self.assertEqual(self.g[-1], 9)
+        self.assertEqual(self.g[-2], 8)
+        self.assertEqual(self.g[-3], 7)
+        self.assertEqual(self.g[-4], 6)
+        self.assertEqual(self.g[-5], 5)
+        self.assertEqual(self.g[5], 5)
+        self.assertEqual(self.g.data, list(range(10)))
+
+    def test_iter(self):
+        list(iter(self.g))
+        self.assertTrue(self.g.fully_consumed())
+        self.assertListEqual(list(iter(self.g)), list(range(10)))
+
+    def test_repr(self):
+        repr(self.g)
+        self.assertFalse(self.g.fully_consumed())
+
+    def test_bool(self):
+        bool(self.g)
+        self.assertFalse(self.g.fully_consumed())
 
 
 class test_head_from_fun(Case):
+ 0 - 4
celery/utils/abstract.py

@@ -101,10 +101,6 @@ class CallableSignature(CallableTask):  # pragma: no cover
     def subtask_type(self):
         pass
 
-    @abstractproperty
-    def chord_size(self):
-        pass
-
     @abstractproperty
     def immutable(self):
         pass

+ 64 - 10
celery/utils/functional.py

@@ -12,7 +12,7 @@ import sys
 
 from functools import partial
 from inspect import isfunction
-from itertools import chain, islice
+from itertools import islice, tee
 
 from kombu.utils.functional import (
     LRUCache, dictfilter, lazy, maybe_evaluate, memoize,
@@ -20,12 +20,13 @@ from kombu.utils.functional import (
 )
 from vine import promise
 
-from celery.five import UserList, getfullargspec, range
+from celery.five import UserList, getfullargspec, range, zip_longest
 
 __all__ = [
     'LRUCache', 'is_list', 'maybe_list', 'memoize', 'mlazy', 'noop',
     'first', 'firstmethod', 'chunks', 'padlist', 'mattrgetter', 'uniq',
-    'regen', 'dictfilter', 'lazy', 'maybe_evaluate', 'head_from_fun',
+    'lookahead', 'regen', 'dictfilter', 'lazy', 'maybe_evaluate',
+    'head_from_fun',
 ]
 
 IS_PY3 = sys.version_info[0] == 3
@@ -178,6 +179,22 @@ def uniq(it):
     return (seen.add(obj) or obj for obj in it if obj not in seen)
 
 
+def lookahead(it):
+    """Yield pairs of ``(current, next)`` items in ``it``.
+
+    `next` is None if `current` is the last item.
+
+    Example::
+
+        >>> list(lookahead(x for x in range(6)))
+        [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, None)]
+
+    """
+    a, b = tee(it)
+    next(b, None)
+    return zip_longest(a, b, fillvalue=None)
+
+
 def regen(it):
     """``Regen`` takes any iterable, and if the object is an
     generator it will cache the evaluated list on first access,
@@ -194,6 +211,11 @@ class _regen(UserList, list):
         self.__it = it
         self.__index = 0
         self.__consumed = []
+        self.__done = False
+
+    def fully_consumed(self):
+        """Return whether the iterator has been fully consumed."""
+        return self.__done
 
     def __reduce__(self):
         return list, (self.data,)
@@ -201,29 +223,61 @@
     def __length_hint__(self):
         return self.__it.__length_hint__()
 
+    def __len__(self):
+        # CPython iter() calls len() first on lists, so we cannot
+        # have __len__ calling __iter__.
+        if self.__done:
+            return len(self.__consumed)
+        try:
+            return self.__length_hint__()
+        except Exception:
+            return NotImplemented
+
+    def __repr__(self, ellipsis='...', sep=', '):
+        # override list.__repr__ to avoid consuming the generator
+        if self.__done:
+            return repr(self.__consumed)
+        return '[{0}]'.format(sep.join(map(repr, self.__consumed)) + ellipsis)
+
     def __iter__(self):
-        return chain(self.__consumed, self.__it)
+        return iter(self.__consumed) if self.__done else self._iter_cont()
+
+    def _iter_cont(self):
+        append = self.__consumed.append
+        for y in self.__it:
+            append(y)
+            yield y
+        self.__done = True
 
     def __getitem__(self, index):
-        if index < 0:
+        if index < 0 or self.__done:
             return self.data[index]
         try:
             return self.__consumed[index]
         except IndexError:
+            it = iter(self)
             try:
                 for i in range(self.__index, index + 1):
-                    self.__consumed.append(next(self.__it))
+                    next(it)
             except StopIteration:
                 raise IndexError(index)
            else:
                 return self.__consumed[index]
 
+    def __bool__(self, sentinel=object()):
+        # bool for list calls len() which would consume the generator:
+        # override to consume maximum of one item.
+        return (
+            len(self.__consumed) or
+            next(iter(self), sentinel) is not sentinel
+        )
+    __nonzero__ = __bool__  # XXX Py2
+
     @property
     def data(self):
-        try:
-            self.__consumed.extend(list(self.__it))
-        except StopIteration:
-            pass
+        # consume the generator
+        if not self.__done:
+            list(iter(self))
         return self.__consumed
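
A short interactive sketch (assumptions: this branch is importable; the values are made up) of the two helpers the canvas changes lean on: lookahead() pairs every item with its successor so _apply_tasks can tell when it is applying the final header task, and regen() wraps a generator so it can be indexed and re-iterated without being expanded up front.

    from celery.utils.functional import lookahead, regen

    print(list(lookahead(x for x in range(4))))
    # [(0, 1), (1, 2), (2, 3), (3, None)] -- None marks the final item

    g = regen(x * x for x in range(5))
    print(g[2])                  # 4: consumes only items 0..2 of the generator
    print(g.fully_consumed())    # False: the tail has not been evaluated yet
    print(g.data)                # [0, 1, 4, 9, 16]: forces full consumption
    print(g.fully_consumed())    # True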