Browse Source

Add support for delayed expansion of generators to groups and chords.
Prior to this change, if a generator was passed to group or chord, it was completely expanded prior to any tasks being submitted to the backend. After this change, tasks will be submitted to the backend as they are retrieved from the generator, allowing work to begin much sooner for long-running generators.
This is currently a work in progress: only redis is supported.

Chad Dombrova 9 years ago
parent
commit
23abf3576b

+ 8 - 4
celery/backends/base.py

@@ -386,6 +386,9 @@ class Backend(object):
     def on_chord_part_return(self, request, state, result, **kwargs):
         pass
 
+    def set_chord_size(self, group_id, size):
+        pass
+
     def fallback_chord_unlock(self, group_id, body, result=None,
                               countdown=1, **kwargs):
         kwargs['result'] = [r.as_tuple() for r in result]
@@ -394,11 +397,12 @@ class Backend(object):
         )
 
     def apply_chord(self, header, partial_args, group_id, body,
-                    options={}, **kwargs):
+                    options={}, result=None, **kwargs):
+        result = list(result)
         fixed_options = {k: v for k, v in items(options) if k != 'task_id'}
-        result = header(*partial_args, task_id=group_id, **fixed_options or {})
-        self.fallback_chord_unlock(group_id, body, **kwargs)
-        return result
+        res = header(*partial_args, task_id=group_id, **fixed_options or {})
+        self.fallback_chord_unlock(group_id, body, result=result, **kwargs)
+        return res
 
     def current_task_children(self, request=None):
         request = request or getattr(current_task(), 'request', None)

+ 14 - 3
celery/backends/redis.py

@@ -239,6 +239,9 @@ class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
             raise ChordError('Dependency {0} raised {1!r}'.format(tid, retval))
         return retval
 
+    def set_chord_size(self, group_id, chord_size):
+        self.set(self.get_key_for_group(group_id, '.s'), chord_size)
+
     def apply_chord(self, header, partial_args, group_id, body,
                     result=None, options={}, **kwargs):
         # avoids saving the group in the redis db.
@@ -254,17 +257,24 @@ class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
         client = self.client
         jkey = self.get_key_for_group(gid, '.j')
         tkey = self.get_key_for_group(gid, '.t')
+        skey = self.get_key_for_group(gid, '.s')
         result = self.encode_result(result, state)
         with client.pipeline() as pipe:
-            _, readycount, totaldiff, _, _ = pipe                           \
+            _, readycount, totaldiff, total, _, _, _ = pipe                 \
                 .rpush(jkey, self.encode([1, tid, state, result]))          \
                 .llen(jkey)                                                 \
                 .get(tkey)                                                  \
+                .get(skey)                                                  \
                 .expire(jkey, 86400)                                        \
                 .expire(tkey, 86400)                                        \
+                .expire(skey, 86400)                                        \
                 .execute()
 
-        totaldiff = int(totaldiff or 0)
+        if total is None:
+            # chord is not completely submitted yet
+            return
+
+        total = int(total) + int(totaldiff or 0)
 
         try:
             callback = maybe_signature(request.chord, app=app)
@@ -272,10 +282,11 @@ class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
             if readycount == total:
                 decode, unpack = self.decode, self._unpack_chord_result
                 with client.pipeline() as pipe:
-                    resl, _, _ = pipe               \
+                    resl, _, _, _ = pipe            \
                         .lrange(jkey, 0, total)     \
                         .delete(jkey)               \
                         .delete(tkey)               \
+                        .delete(skey)               \
                         .execute()
                 try:
                     callback.delay([unpack(tup, decode) for tup in resl])

+ 56 - 92
celery/canvas.py

@@ -18,18 +18,18 @@ from collections import MutableSequence, deque
 from copy import deepcopy
 from functools import partial as _partial, reduce
 from operator import itemgetter
-from itertools import chain as _chain
+from itertools import chain as _chain, tee
 
 from kombu.utils import cached_property, fxrange, reprcall, uuid
 from vine import barrier
 
 from celery._state import current_app
-from celery.five import python_2_unicode_compatible
+from celery.five import python_2_unicode_compatible, zip
 from celery.local import try_import
 from celery.result import GroupResult
 from celery.utils import abstract
 from celery.utils.functional import (
-    maybe_list, is_list, _regen, regen, chunks as _chunks,
+    maybe_list, is_list, _regen, regen, lookahead, chunks as _chunks,
 )
 from celery.utils.text import truncate
 
@@ -114,12 +114,6 @@ def task_name_from(task):
     return getattr(task, 'name', task)
 
 
-def _upgrade(fields, sig):
-    """Used by custom signatures in .from_dict, to keep common fields."""
-    sig.update(chord_size=fields.get('chord_size'))
-    return sig
-
-
 @abstract.CallableSignature.register
 @python_2_unicode_compatible
 class Signature(dict):
@@ -178,13 +172,14 @@ class Signature(dict):
         else:
             self._type = task
 
-        init(self,
-             task=task_name, args=tuple(args or ()),
-             kwargs=kwargs or {},
-             options=dict(options or {}, **ex),
-             subtask_type=subtask_type,
-             immutable=immutable,
-             chord_size=None)
+        init(
+            self,
+            task=task_name, args=tuple(args or ()),
+            kwargs=kwargs or {},
+            options=dict(options or {}, **ex),
+            subtask_type=subtask_type,
+            immutable=immutable,
+        )
 
     def __call__(self, *partial_args, **partial_kwargs):
         args, kwargs, _ = self._merge(partial_args, partial_kwargs, None)
@@ -216,7 +211,6 @@ class Signature(dict):
         s = Signature.from_dict({'task': self.task, 'args': tuple(args),
                                  'kwargs': kwargs, 'options': deepcopy(opts),
                                  'subtask_type': self.subtask_type,
-                                 'chord_size': self.chord_size,
                                  'immutable': self.immutable}, app=self._app)
         s._type = self._type
         return s
@@ -410,7 +404,6 @@ class Signature(dict):
     kwargs = _getitem_property('kwargs')
     options = _getitem_property('options')
     subtask_type = _getitem_property('subtask_type')
-    chord_size = _getitem_property('chord_size')
     immutable = _getitem_property('immutable')
 
 
@@ -594,7 +587,7 @@ class chain(Signature):
                 tasks = d['kwargs']['tasks'] = list(tasks)
             # First task must be signature object to get app
             tasks[0] = maybe_signature(tasks[0], app=app)
-        return _upgrade(d, chain(*tasks, app=app, **d['options']))
+        return chain(*tasks, app=app, **d['options'])
 
     @property
     def app(self):
@@ -630,9 +623,7 @@ class _basemap(Signature):
 
     @classmethod
     def from_dict(cls, d, app=None):
-        return _upgrade(
-            d, cls(*cls._unpack_args(d['kwargs']), app=app, **d['options']),
-        )
+        return cls(*cls._unpack_args(d['kwargs']), app=app, **d['options'])
 
 
 @Signature.register_type
@@ -670,10 +661,7 @@ class chunks(Signature):
 
     @classmethod
     def from_dict(self, d, app=None):
-        return _upgrade(
-            d, chunks(*self._unpack_args(
-                d['kwargs']), app=app, **d['options']),
-        )
+        return chunks(*self._unpack_args(d['kwargs']), app=app, **d['options'])
 
     def apply_async(self, args=(), kwargs={}, **opts):
         return self.group().apply_async(
@@ -705,7 +693,7 @@ def _maybe_group(tasks, app):
     elif isinstance(tasks, abstract.CallableSignature):
         tasks = [tasks]
     else:
-        tasks = [signature(t, app=app) for t in tasks]
+        tasks = regen(signature(t, app=app) for t in tasks)
     return tasks
 
 
@@ -728,45 +716,29 @@ class group(Signature):
 
     @classmethod
     def from_dict(self, d, app=None):
-        return _upgrade(
-            d, group(d['kwargs']['tasks'], app=app, **d['options']),
-        )
+        return group(d['kwargs']['tasks'], app=app, **d['options'])
 
     def __len__(self):
         return len(self.tasks)
 
-    def _prepared(self, tasks, partial_args, group_id, root_id, app,
-                  CallableSignature=abstract.CallableSignature,
-                  from_dict=Signature.from_dict,
-                  isinstance=isinstance, tuple=tuple):
-        for task in tasks:
-            if isinstance(task, CallableSignature):
-                # local sigs are always of type Signature, and we
-                # clone them to make sure we do not modify the originals.
-                task = task.clone()
-            else:
-                # serialized sigs must be converted to Signature.
-                task = from_dict(task, app=app)
-            if isinstance(task, group):
-                # needs yield_from :(
-                unroll = task._prepared(
-                    task.tasks, partial_args, group_id, root_id, app,
-                )
-                for taskN, resN in unroll:
-                    yield taskN, resN
-            else:
-                if partial_args and not task.immutable:
-                    task.args = tuple(partial_args) + tuple(task.args)
-                yield task, task.freeze(group_id=group_id, root_id=root_id)
+    def _prepared(self, tasks, partial_args, group_id, root_id, app):
+        for task in self._unroll_tasks(tasks, app=app, clone=True):
+            if partial_args and not task.immutable:
+                task.args = tuple(partial_args) + tuple(task.args)
+            yield task, task.freeze(group_id=group_id, root_id=root_id)
 
     def _apply_tasks(self, tasks, producer=None, app=None, p=None,
                      add_to_parent=None, chord=None, **options):
         app = app or self.app
         with app.producer_or_acquire(producer) as producer:
-            for sig, res in tasks:
+            for count, (curr_task, next_task) in enumerate(lookahead(tasks)):
+                sig, res = curr_task
+                ichord = sig.options.get('chord') or chord
+                if ichord is not None and next_task is None:
+                    # last task in chord: set the length *before* applying
+                    app.backend.set_chord_size(sig.options['group_id'], count)
                 sig.apply_async(producer=producer, add_to_parent=False,
-                                chord=sig.options.get('chord') or chord,
-                                **options)
+                                chord=ichord, **options)
 
                 # adding callback to result, such that it will gradually
                 # fulfill the barrier.
@@ -846,17 +818,21 @@ class group(Signature):
     def __call__(self, *partial_args, **options):
         return self.apply_async(partial_args, **options)
 
-    def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id):
-        stack = deque(self.tasks)
-        while stack:
-            task = maybe_signature(stack.popleft(), app=self._app).clone()
+    def _unroll_tasks(self, tasks, app=None, clone=True):
+        for task in tasks:
+            task = maybe_signature(task, app=app or self._app, clone=clone)
             if isinstance(task, group):
-                stack.extendleft(task.tasks)
+                it = task._unroll_tasks(task.tasks, app=app, clone=clone)
+                for subtask in it:
+                    yield subtask
             else:
-                new_tasks.append(task)
-                yield task.freeze(group_id=group_id,
-                                  chord=chord, root_id=root_id,
-                                  parent_id=parent_id)
+                yield task
+
+    def _freeze_tasks(self, tasks, group_id, chord, root_id, parent_id):
+        for task in tasks:
+            yield task.freeze(group_id=group_id,
+                              chord=chord, root_id=root_id,
+                              parent_id=parent_id)
 
     def freeze(self, _id=None, group_id=None, chord=None,
                root_id=None, parent_id=None):
@@ -871,16 +847,11 @@ class group(Signature):
             opts['chord'] = chord
         root_id = opts.setdefault('root_id', root_id)
         parent_id = opts.setdefault('parent_id', parent_id)
-        new_tasks = []
-        # Need to unroll subgroups early so that chord gets the
-        # right result instance for chord_unlock etc.
-        results = list(self._freeze_unroll(
-            new_tasks, group_id, chord, root_id, parent_id,
-        ))
-        if isinstance(self.tasks, MutableSequence):
-            self.tasks[:] = new_tasks
-        else:
-            self.tasks = new_tasks
+        tasks1, tasks2 = tee(self._unroll_tasks(self.tasks, clone=True))
+        results = regen(
+            self._freeze_tasks(tasks1, group_id, chord, root_id, parent_id))
+        # if self.tasks is consumed, it will also populate the group results
+        self.tasks = regen(x[0] for x in zip(tasks2, results))
         return self.app.GroupResult(gid, results)
     _freeze = freeze
 
@@ -942,7 +913,7 @@ class chord(Signature):
     @classmethod
     def from_dict(self, d, app=None):
         args, d['kwargs'] = self._unpack_args(**d['kwargs'])
-        return _upgrade(d, self(*args, app=app, **d))
+        return self(*args, app=app, **d)
 
     @staticmethod
     def _unpack_args(header=None, body=None, **kwargs):
@@ -990,17 +961,8 @@ class chord(Signature):
             args=(tasks.apply(args, kwargs).get(propagate=propagate),),
         )
 
-    def _traverse_tasks(self, tasks, value=None):
-        stack = deque(list(tasks))
-        while stack:
-            task = stack.popleft()
-            if isinstance(task, group):
-                stack.extend(task.tasks)
-            else:
-                yield task if value is None else value
-
     def __length_hint__(self):
-        return sum(self._traverse_tasks(self.tasks, 1))
+        return len(list(self._unroll_tasks(self.tasks, clone=False)))
 
     def run(self, header, body, partial_args, app=None, interval=None,
             countdown=1, max_retries=None, eager=False,
@@ -1008,7 +970,6 @@ class chord(Signature):
         app = app or self._get_app(body)
         group_id = uuid()
         root_id = body.options.get('root_id')
-        body.chord_size = self.__length_hint__()
         options = dict(self.options, **options) if options else self.options
         if options:
             options.pop('task_id', None)
@@ -1070,13 +1031,16 @@ def signature(varies, *args, **kwargs):
 subtask = signature   # XXX compat
 
 
-def maybe_signature(d, app=None):
+def maybe_signature(d, app=None, clone=False):
     if d is not None:
-        if (isinstance(d, dict) and
-                not isinstance(d, abstract.CallableSignature)):
-            d = signature(d)
+        if isinstance(d, dict):
+            if not isinstance(d, abstract.CallableSignature):
+                d = signature(d)
+            elif clone:
+                d = d.clone()
+
         if app is not None:
             d._app = app
-        return d
+    return d
 
 maybe_subtask = maybe_signature  # XXX compat

+ 2 - 0
celery/tests/app/test_builtins.py

@@ -115,6 +115,7 @@ class test_group(BuiltinsCase):
     def test_task(self, current_worker_task):
         g, result = self.mock_group(self.add.s(2), self.add.s(4))
         self.task(g.tasks, result, result.id, (2,)).results
+        print('TASKS: %r' % (g.tasks,))
         g.tasks[0].clone().apply_async.assert_called_with(
             group_id=result.id, producer=self.app.producer_or_acquire(),
             add_to_parent=False,
@@ -178,5 +179,6 @@ class test_chord(BuiltinsCase):
     def test_apply_eager_with_arguments(self):
         self.app.conf.task_always_eager = True
         x = chord([self.add.s(i) for i in range(10)], body=self.xsum.s())
+        print(list(x.tasks))
         r = x.apply_async([1])
         self.assertEqual(r.get(), 55)

+ 12 - 8
celery/tests/backends/test_base.py

@@ -8,10 +8,13 @@ from contextlib import contextmanager
 from celery.exceptions import ChordError, TimeoutError
 from celery.five import items, bytes_if_py2, range
 from celery.utils import serialization
-from celery.utils.serialization import subclass_exception
-from celery.utils.serialization import find_pickleable_exception as fnpe
-from celery.utils.serialization import UnpickleableExceptionWrapper
-from celery.utils.serialization import get_pickleable_exception as gpe
+from celery.utils.serialization import (
+    subclass_exception,
+    find_pickleable_exception as fnpe,
+    UnpickleableExceptionWrapper,
+    get_pickleable_exception as gpe,
+)
+from celery.utils.functional import pass1, regen
 
 from celery import states
 from celery import group, uuid
@@ -22,7 +25,6 @@ from celery.backends.base import (
     _nulldict,
 )
 from celery.result import result_from_tuple
-from celery.utils.functional import pass1
 
 from celery.tests.case import ANY, AppCase, Case, Mock, call, patch, skip
 
@@ -84,7 +86,7 @@ class test_BaseBackend_interface(AppCase):
         self.app.tasks[unlock] = Mock()
         self.b.apply_chord(
             group(app=self.app), (), 'dakj221', None,
-            result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
+            result=regen(self.app.AsyncResult(x) for x in [1, 2, 3]),
         )
         self.assertTrue(self.app.tasks[unlock].apply_async.call_count)
 
@@ -524,12 +526,14 @@ class test_KeyValueStoreBackend(AppCase):
     def test_chord_apply_fallback(self):
         self.b.implements_incr = False
         self.b.fallback_chord_unlock = Mock()
+        res = regen(x for x in range(10))
         self.b.apply_chord(
             group(app=self.app), (), 'group_id', 'body',
-            result='result', foo=1,
+            result=res, foo=1,
         )
+        self.assertTrue(res.fully_consumed())
         self.b.fallback_chord_unlock.assert_called_with(
-            'group_id', 'body', result='result', foo=1,
+            'group_id', 'body', result=res, foo=1,
         )
 
     def test_get_missing_meta(self):

+ 3 - 1
celery/tests/backends/test_cache.py

@@ -12,6 +12,7 @@ from celery import group, signature, uuid
 from celery.backends.cache import CacheBackend, DummyClient, backends
 from celery.exceptions import ImproperlyConfigured
 from celery.five import items, bytes_if_py2, string, text_t
+from celery.utils.functional import regen
 
 from celery.tests.case import AppCase, Mock, mock, patch, skip
 
@@ -67,7 +68,8 @@ class test_CacheBackend(AppCase):
 
     def test_apply_chord(self):
         tb = CacheBackend(backend='memory://', app=self.app)
-        gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
+        gid = uuid()
+        res = regen(self.app.AsyncResult(uuid()) for _ in range(3))
         tb.apply_chord(group(app=self.app), (), gid, {}, result=res)
 
     @patch('celery.result.GroupResult.restore')

+ 28 - 10
celery/tests/backends/test_redis.py

@@ -278,6 +278,12 @@ class test_RedisBackend(AppCase):
         b.add_to_chord(gid, 'sig')
         b.client.incr.assert_called_with(b.get_key_for_group(gid, '.t'), 1)
 
+    def test_set_chord_size(self):
+        b = self.Backend('redis://', app=self.app)
+        gid = uuid()
+        b.set_chord_size(gid, 10)
+        b.client.set.assert_called_with(b.get_key_for_group(gid, '.s'), 10)
+
     def test_expires_is_None(self):
         b = self.Backend(expires=None, app=self.app)
         self.assertEqual(
@@ -304,7 +310,6 @@ class test_RedisBackend(AppCase):
         self.app.tasks['foobarbaz'] = task
         task.request.chord = signature(task)
         task.request.id = tid
-        task.request.chord['chord_size'] = 10
         task.request.group = 'group_id'
         return task
 
@@ -312,6 +317,8 @@ class test_RedisBackend(AppCase):
     def test_on_chord_part_return(self, restore):
         tasks = [self.create_task() for i in range(10)]
 
+        self.b.set_chord_size('group_id', 10)
+
         for i in range(10):
             self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
             self.assertTrue(self.b.client.rpush.call_count)
@@ -319,20 +326,27 @@ class test_RedisBackend(AppCase):
         self.assertTrue(self.b.client.lrange.call_count)
         jkey = self.b.get_key_for_group('group_id', '.j')
         tkey = self.b.get_key_for_group('group_id', '.t')
-        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
+        skey = self.b.get_key_for_group('group_id', '.s')
+        self.b.client.delete.assert_has_calls([
+            call(jkey), call(tkey), call(skey),
+        ])
         self.b.client.expire.assert_has_calls([
-            call(jkey, 86400), call(tkey, 86400),
+            call(jkey, 86400), call(tkey, 86400), call(skey, 86400),
         ])
 
     def test_on_chord_part_return__success(self):
-        with self.chord_context(2) as (_, request, callback):
+        with self.chord_context(3) as (_, request, callback):
             self.b.on_chord_part_return(request, states.SUCCESS, 10)
             callback.delay.assert_not_called()
+            self.b.set_chord_size('gid1', 3)
             self.b.on_chord_part_return(request, states.SUCCESS, 20)
-            callback.delay.assert_called_with([10, 20])
+            callback.delay.assert_not_called()
+            self.b.on_chord_part_return(request, states.SUCCESS, 30)
+            callback.delay.assert_called_with([10, 20, 30])
 
     def test_on_chord_part_return__callback_raises(self):
         with self.chord_context(1) as (_, request, callback):
+            self.b.set_chord_size('gid1', 1)
             callback.delay.side_effect = KeyError(10)
             task = self.app._tasks['add'] = Mock(name='add_task')
             self.b.on_chord_part_return(request, states.SUCCESS, 10)
@@ -342,22 +356,27 @@ class test_RedisBackend(AppCase):
 
     def test_on_chord_part_return__ChordError(self):
         with self.chord_context(1) as (_, request, callback):
+            self.b.set_chord_size('gid1', 1)
             self.b.client.pipeline = ContextMock()
             raise_on_second_call(self.b.client.pipeline, ChordError())
-            self.b.client.pipeline.return_value.rpush().llen().get().expire(
-            ).expire().execute.return_value = (1, 1, 0, 4, 5)
+            self._set_pipeline_ret(1, 1, 0, 1, 4, 5, 6)
             task = self.app._tasks['add'] = Mock(name='add_task')
             self.b.on_chord_part_return(request, states.SUCCESS, 10)
             task.backend.fail_from_current_stack.assert_called_with(
                 callback.id, exc=ANY,
             )
 
+    def _set_pipeline_ret(self, *args):
+            self.b.client.pipeline.return_value.rpush().llen().get().get()\
+                .expire().expire().expire()\
+                .execute.return_value = args
+
     def test_on_chord_part_return__other_error(self):
         with self.chord_context(1) as (_, request, callback):
+            self.b.set_chord_size('gid1', 1)
             self.b.client.pipeline = ContextMock()
             raise_on_second_call(self.b.client.pipeline, RuntimeError())
-            self.b.client.pipeline.return_value.rpush().llen().get().expire(
-            ).expire().execute.return_value = (1, 1, 0, 4, 5)
+            self._set_pipeline_ret(1, 1, 0, 1, 4, 5, 6)
             task = self.app._tasks['add'] = Mock(name='add_task')
             self.b.on_chord_part_return(request, states.SUCCESS, 10)
             task.backend.fail_from_current_stack.assert_called_with(
@@ -373,7 +392,6 @@ class test_RedisBackend(AppCase):
             request.group = 'gid1'
             callback = ms.return_value = Signature('add')
             callback.id = 'id1'
-            callback['chord_size'] = size
             callback.delay = Mock(name='callback.delay')
             yield tasks, request, callback
 

+ 19 - 0
celery/tests/tasks/test_canvas.py

@@ -15,6 +15,7 @@ from celery.canvas import (
     maybe_unroll_group,
 )
 from celery.result import EagerResult
+from celery.utils.functional import _regen
 
 from celery.tests.case import (
     AppCase, ContextMock, MagicMock, Mock, depends_on_current_app,
@@ -328,6 +329,8 @@ class test_chain(CanvasCase):
         c._use_link = True
         tasks, results = c.prepare_steps((), c.tasks)
 
+        print(tasks[-2].tasks)
+
         self.assertEqual(tasks[-1].args[0], 5)
         self.assertIsInstance(tasks[-2], chord)
         self.assertEqual(len(tasks[-2].tasks), 5)
@@ -580,6 +583,14 @@ class test_group(CanvasCase):
         g = group([self.add.s(i, i) for i in range(10)])
         self.assertListEqual(list(iter(g)), g.tasks)
 
+    def test_maintains_generator(self):
+        g = group(self.add.s(x, x) for x in range(3))
+        self.assertIsInstance(g.tasks, _regen)
+        self.assertFalse(g.tasks.fully_consumed())
+        g.freeze()
+        self.assertIsInstance(g.tasks, _regen)
+        self.assertFalse(g.tasks.fully_consumed())
+
 
 class test_chord(CanvasCase):
 
@@ -656,6 +667,14 @@ class test_chord(CanvasCase):
         x.tasks = [self.add.s(2, 2)]
         x.freeze()
 
+    def test_maintains_generator(self):
+        x = chord((self.add.s(i, i) for i in range(3)), body=self.mul.s(4))
+        self.assertIsInstance(x.tasks, _regen)
+        self.assertFalse(x.tasks.fully_consumed())
+        x.freeze()
+        self.assertIsInstance(x.tasks, group)
+        self.assertFalse(x.tasks.tasks.fully_consumed())
+
 
 class test_maybe_signature(CanvasCase):
 

+ 58 - 33
celery/tests/utils/test_functional.py

@@ -10,6 +10,7 @@ from celery.utils.functional import (
     firstmethod,
     first,
     maybe_list,
+    lookahead,
     mlazy,
     padlist,
     regen,
@@ -78,6 +79,12 @@ class test_utils(Case):
         self.assertIsNone(first(predicate, range(10, 20)))
         self.assertEqual(iterations[0], 10)
 
+    def test_lookahead(self):
+        self.assertEqual(
+            list(lookahead(x for x in range(6))),
+            [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, None)],
+        )
+
     def test_maybe_list(self):
         self.assertEqual(maybe_list(1), [1])
         self.assertEqual(maybe_list([1]), [1])
@@ -98,7 +105,10 @@ class test_mlazy(Case):
 
 class test_regen(Case):
 
-    def test_regen_list(self):
+    def setup(self):
+        self.g = regen(iter(list(range(10))))
+
+    def test_list(self):
         l = [1, 2]
         r = regen(iter(l))
         self.assertIs(regen(l), l)
@@ -109,39 +119,54 @@ class test_regen(Case):
         fun, args = r.__reduce__()
         self.assertEqual(fun(*args), l)
 
-    def test_regen_gen(self):
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[7], 7)
-        self.assertEqual(g[6], 6)
-        self.assertEqual(g[5], 5)
-        self.assertEqual(g[4], 4)
-        self.assertEqual(g[3], 3)
-        self.assertEqual(g[2], 2)
-        self.assertEqual(g[1], 1)
-        self.assertEqual(g[0], 0)
-        self.assertEqual(g.data, list(range(10)))
-        self.assertEqual(g[8], 8)
-        self.assertEqual(g[0], 0)
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[0], 0)
-        self.assertEqual(g[1], 1)
-        self.assertEqual(g.data, list(range(10)))
-        g = regen(iter([1]))
-        self.assertEqual(g[0], 1)
+    def test_index(self):
+        self.assertEqual(self.g[7], 7)
+        self.assertEqual(self.g[6], 6)
+        self.assertEqual(self.g[5], 5)
+        self.assertEqual(self.g[4], 4)
+        self.assertEqual(self.g[3], 3)
+        self.assertEqual(self.g[2], 2)
+        self.assertEqual(self.g[1], 1)
+        self.assertEqual(self.g[0], 0)
+        self.assertFalse(self.g.fully_consumed())
+        self.assertEqual(self.g.data, list(range(10)))
+        self.assertTrue(self.g.fully_consumed())
+        self.assertEqual(self.g[8], 8)
+        self.assertEqual(self.g[0], 0)
+        self.assertListEqual(list(iter(self.g)), list(range(10)))
+
+    def test_index_2(self):
+        self.assertEqual(self.g[0], 0)
+        self.assertEqual(self.g[1], 1)
+        self.assertEqual(self.g.data, list(range(10)))
+
+    def test_index_error(self):
         with self.assertRaises(IndexError):
-            g[1]
-        self.assertEqual(g.data, [1])
-
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[-1], 9)
-        self.assertEqual(g[-2], 8)
-        self.assertEqual(g[-3], 7)
-        self.assertEqual(g[-4], 6)
-        self.assertEqual(g[-5], 5)
-        self.assertEqual(g[5], 5)
-        self.assertEqual(g.data, list(range(10)))
-
-        self.assertListEqual(list(iter(g)), list(range(10)))
+            self.g[11]
+        self.assertTrue(self.g.fully_consumed())
+        self.assertListEqual(list(iter(self.g)), list(range(10)))
+
+    def test_negative_index(self):
+        self.assertEqual(self.g[-1], 9)
+        self.assertEqual(self.g[-2], 8)
+        self.assertEqual(self.g[-3], 7)
+        self.assertEqual(self.g[-4], 6)
+        self.assertEqual(self.g[-5], 5)
+        self.assertEqual(self.g[5], 5)
+        self.assertEqual(self.g.data, list(range(10)))
+
+    def test_iter(self):
+        list(iter(self.g))
+        self.assertTrue(self.g.fully_consumed())
+        self.assertListEqual(list(iter(self.g)), list(range(10)))
+
+    def test_repr(self):
+        repr(self.g)
+        self.assertFalse(self.g.fully_consumed())
+
+    def test_bool(self):
+        bool(self.g)
+        self.assertFalse(self.g.fully_consumed())
 
 
 class test_head_from_fun(Case):

+ 0 - 4
celery/utils/abstract.py

@@ -101,10 +101,6 @@ class CallableSignature(CallableTask):  # pragma: no cover
     def subtask_type(self):
         pass
 
-    @abstractproperty
-    def chord_size(self):
-        pass
-
     @abstractproperty
     def immutable(self):
         pass

+ 64 - 10
celery/utils/functional.py

@@ -12,7 +12,7 @@ import sys
 
 from functools import partial
 from inspect import isfunction
-from itertools import chain, islice
+from itertools import islice, tee
 
 from kombu.utils.functional import (
     LRUCache, dictfilter, lazy, maybe_evaluate, memoize,
@@ -20,12 +20,13 @@ from kombu.utils.functional import (
 )
 from vine import promise
 
-from celery.five import UserList, getfullargspec, range
+from celery.five import UserList, getfullargspec, range, zip_longest
 
 __all__ = [
     'LRUCache', 'is_list', 'maybe_list', 'memoize', 'mlazy', 'noop',
     'first', 'firstmethod', 'chunks', 'padlist', 'mattrgetter', 'uniq',
-    'regen', 'dictfilter', 'lazy', 'maybe_evaluate', 'head_from_fun',
+    'lookahead', 'regen', 'dictfilter', 'lazy', 'maybe_evaluate',
+    'head_from_fun',
 ]
 
 IS_PY3 = sys.version_info[0] == 3
@@ -178,6 +179,22 @@ def uniq(it):
     return (seen.add(obj) or obj for obj in it if obj not in seen)
 
 
+def lookahead(it):
+    """Yield pairs of ``(current, next)`` items in ``it``.
+
+    `next` is None if `current` is the last item.
+
+    Example::
+
+        >>> list(lookahead(x for x in range(6)))
+        [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, None)]
+
+    """
+    a, b = tee(it)
+    next(b, None)
+    return zip_longest(a, b, fillvalue=None)
+
+
 def regen(it):
     """``Regen`` takes any iterable, and if the object is an
     generator it will cache the evaluated list on first access,
@@ -194,6 +211,11 @@ class _regen(UserList, list):
         self.__it = it
         self.__index = 0
         self.__consumed = []
+        self.__done = False
+
+    def fully_consumed(self):
+        """Return whether the iterator has been fully consumed."""
+        return self.__done
 
     def __reduce__(self):
         return list, (self.data,)
@@ -201,29 +223,61 @@ class _regen(UserList, list):
     def __length_hint__(self):
         return self.__it.__length_hint__()
 
+    def __len__(self):
+        # CPython iter() calls len() first on lists, so we cannot
+        # have __len__ calling __iter__.
+        if self.__done:
+            return len(self.__consumed)
+        try:
+            return self.__length_hint__()
+        except Exception:
+            return NotImplemented
+
+    def __repr__(self, ellipsis='...', sep=', '):
+        # override list.__repr__ to avoid consuming the generator
+        if self.__done:
+            return repr(self.__consumed)
+        return '[{0}]'.format(sep.join(map(repr, self.__consumed)) + ellipsis)
+
     def __iter__(self):
-        return chain(self.__consumed, self.__it)
+        return iter(self.__consumed) if self.__done else self._iter_cont()
+
+    def _iter_cont(self):
+        append = self.__consumed.append
+        for y in self.__it:
+            append(y)
+            yield y
+        self.__done = True
 
     def __getitem__(self, index):
-        if index < 0:
+        if index < 0 or self.__done:
             return self.data[index]
         try:
             return self.__consumed[index]
         except IndexError:
+            it = iter(self)
             try:
                 for i in range(self.__index, index + 1):
-                    self.__consumed.append(next(self.__it))
+                    next(it)
             except StopIteration:
                 raise IndexError(index)
             else:
                 return self.__consumed[index]
 
+    def __bool__(self, sentinel=object()):
+        # bool for list calls len() which would consume the generator:
+        # override to consume maximum of one item.
+        return (
+            len(self.__consumed) or
+            next(iter(self), sentinel) is not sentinel
+        )
+    __nonzero__ = __bool__  # XXX Py2
+
     @property
     def data(self):
-        try:
-            self.__consumed.extend(list(self.__it))
-        except StopIteration:
-            pass
+        # consume the generator
+        if not self.__done:
+            list(iter(self))
         return self.__consumed