
99% coverage (excluding celery.concurrency.asynpool and experimental backends)

Ask Solem, 9 years ago
commit a28d300463
49 files changed, 1668 insertions(+), 191 deletions(-)
  1. + 3   - 0    .coveragerc
  2. + 6   - 5    celery/apps/worker.py
  3. + 2   - 2    celery/backends/base.py
  4. + 5   - 2    celery/backends/redis.py
  5. + 1   - 1    celery/bin/base.py
  6. + 2   - 2    celery/bin/celery.py
  7. + 3   - 2    celery/canvas.py
  8. + 2   - 2    celery/events/state.py
  9. + 24  - 16   celery/fixups/django.py
 10. + 5   - 4    celery/local.py
 11. + 7   - 5    celery/platforms.py
 12. + 0   - 1    celery/result.py
 13. + 37  - 0    celery/tests/app/test_app.py
 14. + 12  - 0    celery/tests/app/test_beat.py
 15. + 24  - 4    celery/tests/app/test_utils.py
 16. + 18  - 1    celery/tests/backends/test_amqp.py
 17. + 131 - 2    celery/tests/backends/test_base.py
 18. + 33  - 4    celery/tests/backends/test_mongodb.py
 19. + 171 - 48   celery/tests/backends/test_redis.py
 20. + 2   - 0    celery/tests/bin/celery.py
 21. + 58  - 0    celery/tests/bin/test_base.py
 22. + 47  - 0    celery/tests/bin/test_celery.py
 23. + 23  - 2    celery/tests/bin/test_celeryd_detach.py
 24. + 5   - 0    celery/tests/bin/test_multi.py
 25. + 6   - 6    celery/tests/bin/test_worker.py
 26. + 53  - 10   celery/tests/case.py
 27. + 5   - 0    celery/tests/contrib/test_rdb.py
 28. + 68  - 2    celery/tests/events/test_events.py
 29. + 85  - 4    celery/tests/events/test_state.py
 30. + 88  - 0    celery/tests/fixups/test_django.py
 31. + 144 - 2    celery/tests/tasks/test_canvas.py
 32. + 122 - 16   celery/tests/tasks/test_result.py
 33. + 34  - 0    celery/tests/tasks/test_trace.py
 34. + 106 - 1    celery/tests/utils/test_functional.py
 35. + 1   - 0    celery/tests/utils/test_imports.py
 36. + 6   - 0    celery/tests/utils/test_local.py
 37. + 112 - 6    celery/tests/utils/test_platforms.py
 38. + 9   - 0    celery/tests/utils/test_saferepr.py
 39. + 17  - 1    celery/tests/utils/test_timer2.py
 40. + 2   - 2    celery/tests/utils/test_timeutils.py
 41. + 44  - 17   celery/tests/worker/test_autoreload.py
 42. + 25  - 0    celery/tests/worker/test_bootsteps.py
 43. + 47  - 2    celery/tests/worker/test_components.py
 44. + 4   - 0    celery/tests/worker/test_control.py
 45. + 52  - 3    celery/tests/worker/test_worker.py
 46. + 8   - 8    celery/utils/log.py
 47. + 3   - 3    celery/utils/saferepr.py
 48. + 2   - 2    celery/utils/timeutils.py
 49. + 4   - 3    celery/worker/components.py

+ 3 - 0
.coveragerc

@@ -16,3 +16,6 @@ omit =
     *celery/backends/couchdb.py
     *celery/backends/couchbase.py
     *celery/backends/cassandra.py
+    *celery/backends/riak.py
+    *celery/concurrency/asynpool.py
+    *celery/utils/debug.py
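
The omit entries above drop whole modules (experimental backends, the asynpool, debug helpers) from the coverage report, while the many "# pragma: no cover" markers added throughout the rest of this commit rely on coverage.py recognising that comment via its default exclude pattern. A minimal sketch, not from this commit, of what such a marker does:

    def reconnect_delay(retries=None):
        # coverage.py leaves this defensive branch out of the totals because
        # its default exclude pattern matches "pragma: no cover".
        if retries is None:  # pragma: no cover
            retries = 3
        return retries * 2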

+ 6 - 5
celery/apps/worker.py

@@ -147,18 +147,19 @@ class Worker(WorkController):
         trace.setup_worker_optimizations(self.app, self.hostname)
 
     def on_start(self):
+        app = self.app
         if not self._custom_logging and self.redirect_stdouts:
-            self.app.log.redirect_stdouts(self.redirect_stdouts_level)
+            app.log.redirect_stdouts(self.redirect_stdouts_level)
 
         WorkController.on_start(self)
 
         # this signal can be used to e.g. change queues after
         # the -Q option has been applied.
         signals.celeryd_after_setup.send(
-            sender=self.hostname, instance=self, conf=self.app.conf,
+            sender=self.hostname, instance=self, conf=app.conf,
         )
 
-        if not self.app.conf.value_set_for('accept_content'):
+        if not app.conf.value_set_for('accept_content'):  # pragma: no cover
             warnings.warn(CDeprecationWarning(W_PICKLE_DEPRECATED))
 
         if self.purge:
@@ -187,7 +188,7 @@ class Worker(WorkController):
 
     def purge_messages(self):
         count = self.app.control.purge()
-        if count:
+        if count:  # pragma: no cover
             print('purge: Erased {0} {1} from the queue.\n'.format(
                 count, pluralize(count, 'message')))
 
@@ -209,7 +210,7 @@ class Worker(WorkController):
         appr = '{0}:{1:#x}'.format(app.main or '__main__', id(app))
         if not isinstance(app.loader, AppLoader):
             loader = qualname(app.loader)
-            if loader.startswith('celery.loaders'):
+            if loader.startswith('celery.loaders'):  # pragma: no cover
                 loader = loader[14:]
             appr += ' ({0})'.format(loader)
         if self.autoscale:

+ 2 - 2
celery/backends/base.py

@@ -394,7 +394,7 @@ class KeyValueStoreBackend(BaseBackend):
     implements_incr = False
 
     def __init__(self, *args, **kwargs):
-        if hasattr(self.key_t, '__func__'):
+        if hasattr(self.key_t, '__func__'):  # pragma: no cover
             self.key_t = self.key_t.__func__  # remove binding
         self._encode_prefixes()
         super(KeyValueStoreBackend, self).__init__(*args, **kwargs)
@@ -583,7 +583,7 @@ class KeyValueStoreBackend(BaseBackend):
                 )
         val = self.incr(key)
         size = len(deps)
-        if val > size:
+        if val > size:  # pragma: no cover
             logger.warning('Chord counter incremented too many times for %r',
                            gid)
         elif val == size:

+ 5 - 2
celery/backends/redis.py

@@ -39,6 +39,10 @@ REDIS_MISSING = """\
 You need to install the redis library in order to use \
 the Redis result store backend."""
 
+E_LOST = """\
+Connection to Redis lost: Retry (%s/%s) %s.\
+"""
+
 logger = get_logger(__name__)
 error = logger.error
 
@@ -137,8 +141,7 @@ class RedisBackend(KeyValueStoreBackend):
 
     def on_connection_error(self, max_retries, exc, intervals, retries):
         tts = next(intervals)
-        error('Connection to Redis lost: Retry (%s/%s) %s.',
-              retries, max_retries or 'Inf',
+        error(E_LOST, retries, max_retries or 'Inf',
               humanize_seconds(tts, 'in '))
         return tts
 

+ 1 - 1
celery/bin/base.py

@@ -95,7 +95,7 @@ from celery.utils.imports import symbol_by_name, import_from_cwd
 
 try:
     input = raw_input
-except NameError:
+except NameError:  # pragma: no cover
     pass
 
 # always enable DeprecationWarnings, so our users can see them.

+ 2 - 2
celery/bin/celery.py

@@ -740,13 +740,13 @@ class CeleryCommand(Command):
                             # is (maybe) a value for this option
                             rest.extend([value, nxt])
                             index += 1
-                    except IndexError:
+                    except IndexError:  # pragma: no cover
                         rest.append(value)
                         break
                 else:
                     break
                 index += 1
-            if argv[index:]:
+            if argv[index:]:  # pragma: no cover
                 # if there are more arguments left then divide and swap
                 # we assume the first argument in argv[i:] is the command
                 # name.

+ 3 - 2
celery/canvas.py

@@ -261,7 +261,8 @@ class Signature(dict):
     def apply_async(self, args=(), kwargs={}, route_name=None, **options):
         try:
             _apply = self._apply_async
-        except IndexError:  # no tasks for chain, etc to find type
+        except IndexError:  # pragma: no cover
+            # no tasks for chain, etc to find type
             return
         # For callbacks: extra args are prepended to the stored args.
         if args or kwargs or options:
@@ -337,7 +338,7 @@ class Signature(dict):
     def __repr__(self):
         return self.reprcall()
 
-    if JSON_NEEDS_UNICODE_KEYS:
+    if JSON_NEEDS_UNICODE_KEYS:  # pragma: no cover
         def items(self):
             for k, v in dict.items(self):
                 yield k.decode() if isinstance(k, bytes) else k, v

+ 2 - 2
celery/events/state.py

@@ -166,7 +166,7 @@ class Worker(object):
                 if drift > max_drift:
                     _warn_drift(self.hostname, drift,
                                 local_received, timestamp)
-                if local_received:
+                if local_received:  # pragma: no cover
                     hearts = len(heartbeats)
                     if hearts > hbmax - 1:
                         hb_pop(0)
@@ -218,7 +218,7 @@ class Task(object):
         'timestamp', 'runtime', 'traceback', 'exchange', 'routing_key',
         'clock', 'client', 'root_id', 'parent_id',
     )
-    if not PYPY:
+    if not PYPY:  # pragma: no cover
         __slots__ = ('__dict__', '__weakref__')
 
     #: How to merge out of order events.

+ 24 - 16
celery/fixups/django.py

@@ -15,7 +15,7 @@ from celery.exceptions import FixupWarning
 
 if sys.version_info[0] < 3 and not hasattr(sys, 'pypy_version_info'):
     from StringIO import StringIO
-else:
+else:  # pragma: no cover
     from io import StringIO
 
 
@@ -66,12 +66,16 @@ class DjangoFixup(object):
         signals.worker_init.connect(self.on_worker_init)
         return self
 
-    @cached_property
+    @property
     def worker_fixup(self):
         if self._worker_fixup is None:
             self._worker_fixup = DjangoWorkerFixup(self.app)
         return self._worker_fixup
 
+    @worker_fixup.setter
+    def worker_fixup(self, value):
+        self._worker_fixup = value
+
     def on_import_modules(self, **kwargs):
         # call django.setup() before task modules are imported
         self.worker_fixup.validate_models()
@@ -160,36 +164,40 @@ class DjangoWorkerFixup(object):
             _oracle_database_errors
         )
 
-    def validate_models(self):
+    def django_setup(self):
         import django
         try:
             django_setup = django.setup
-        except AttributeError:
+        except AttributeError:  # pragma: no cover
             pass
         else:
             django_setup()
-        s = StringIO()
+
+    def validate_models(self):
+        self.django_setup()
         try:
             from django.core.management.validation import get_validation_errors
         except ImportError:
-            from django.core.management.base import BaseCommand
-            cmd = BaseCommand()
-            try:
-                # since django 1.5
-                from django.core.management.base import OutputWrapper
-                cmd.stdout = OutputWrapper(sys.stdout)
-                cmd.stderr = OutputWrapper(sys.stderr)
-            except ImportError:
-                cmd.stdout, cmd.stderr = sys.stdout, sys.stderr
-
-            cmd.check()
+            self._validate_models_django17()
         else:
+            s = StringIO()
             num_errors = get_validation_errors(s, None)
             if num_errors:
                 raise RuntimeError(
                     'One or more Django models did not validate:\n{0}'.format(
                         s.getvalue()))
 
+    def _validate_models_django17(self):
+        from django.core.management import base
+        print(base)
+        cmd = base.BaseCommand()
+        try:
+            cmd.stdout = base.OutputWrapper(sys.stdout)
+            cmd.stderr = base.OutputWrapper(sys.stderr)
+        except ImportError:  # before django 1.5
+            cmd.stdout, cmd.stderr = sys.stdout, sys.stderr
+        cmd.check()
+
     def install(self):
         signals.beat_embedded_init.connect(self.close_database)
         signals.worker_ready.connect(self.on_worker_ready)
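
Turning worker_fixup from a cached_property into a plain property with a setter lets a test inject its own object before the lazily created DjangoWorkerFixup is built. A minimal sketch, assuming DjangoFixup(app) takes only the app instance as it does in the test suite:

    from celery import Celery
    from celery.fixups.django import DjangoFixup
    from celery.tests.case import Mock

    app = Celery(set_as_current=False)
    f = DjangoFixup(app)                          # constructor signature assumed
    f.worker_fixup = Mock(name='worker_fixup')    # the new setter accepts an override
    f.on_import_modules()                         # delegates to the injected mock
    f.worker_fixup.validate_models.assert_called_with()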

+ 5 - 4
celery/local.py

@@ -99,9 +99,10 @@ class Proxy(object):
         loc = object.__getattribute__(self, '_Proxy__local')
         if not hasattr(loc, '__release_local__'):
             return loc(*self.__args, **self.__kwargs)
-        try:
+        try:  # pragma: no cover
+            # not sure what this is about
             return getattr(loc, self.__name__)
-        except AttributeError:
+        except AttributeError:  # pragma: no cover
             raise RuntimeError('no object bound to {0.__name__}'.format(self))
 
     @property
@@ -286,7 +287,7 @@ class Proxy(object):
     def __reduce__(self):
         return self._get_current_object().__reduce__()
 
-    if not PY3:
+    if not PY3:  # pragma: no cover
         def __cmp__(self, other):
             return cmp(self._get_current_object(), other)  # noqa
 
@@ -361,7 +362,7 @@ class PromiseProxy(Proxy):
                 finally:
                     try:
                         object.__delattr__(self, '__pending__')
-                    except AttributeError:
+                    except AttributeError:  # pragma: no cover
                         pass
             return thing
 

+ 7 - 5
celery/platforms.py

@@ -21,10 +21,6 @@ import warnings
 
 from collections import namedtuple
 
-try:
-    from billiard.process import current_process
-except ImportError:
-    current_process = None
 from billiard.compat import get_fdmax, close_open_fds
 # fileno used to be in this module
 from kombu.utils import maybe_fileno
@@ -34,6 +30,11 @@ from contextlib import contextmanager
 from .local import try_import
 from .five import items, reraise, string_t
 
+try:
+    from billiard.process import current_process
+except ImportError:  # pragma: no cover
+    current_process = None
+
 _setproctitle = try_import('setproctitle')
 resource = try_import('resource')
 pwd = try_import('pwd')
@@ -340,7 +341,8 @@ class DaemonContext(object):
     def _detach(self):
         if os.fork() == 0:      # first child
             os.setsid()         # create new session
-            if os.fork() > 0:   # second child
+            if os.fork() > 0:   # pragma: no cover
+                # second child
                 os._exit(0)
         else:
             os._exit(0)

+ 0 - 1
celery/result.py

@@ -9,7 +9,6 @@
 from __future__ import absolute_import
 
 import time
-import warnings
 
 from collections import OrderedDict, deque
 from contextlib import contextmanager

+ 37 - 0
celery/tests/app/test_app.py

@@ -303,6 +303,43 @@ class test_App(AppCase):
             self.assertEqual(app.conf.broker_url, 'foo://bar')
             self.assertEqual(app.conf.result_backend, 'foo')
 
+    def test_pending_configuration__compat_settings_mixing(self):
+        with self.Celery(broker='foo://bar', backend='foo') as app:
+            app.conf.update(
+                CELERY_ALWAYS_EAGER=4,
+                CELERY_DEFAULT_DELIVERY_MODE=63,
+                CELERYD_AGENT='foo:Barz',
+                worker_consumer='foo:Fooz',
+            )
+            with self.assertRaises(ImproperlyConfigured):
+                self.assertEqual(app.conf.task_always_eager, 4)
+
+    def test_pending_configuration__compat_settings_mixing_new(self):
+        with self.Celery(broker='foo://bar', backend='foo') as app:
+            app.conf.update(
+                task_always_eager=4,
+                task_default_delivery_mode=63,
+                worker_agent='foo:Barz',
+                CELERYD_CONSUMER='foo:Fooz',
+                CELERYD_AUTOSCALER='foo:Xuzzy',
+            )
+            with self.assertRaises(ImproperlyConfigured):
+                self.assertEqual(app.conf.worker_consumer, 'foo:Fooz')
+
+    def test_pending_configuration__compat_settings_mixing_alt(self):
+        with self.Celery(broker='foo://bar', backend='foo') as app:
+            app.conf.update(
+                task_always_eager=4,
+                task_default_delivery_mode=63,
+                worker_agent='foo:Barz',
+                CELERYD_CONSUMER='foo:Fooz',
+                worker_consumer='foo:Fooz',
+                CELERYD_AUTOSCALER='foo:Xuzzy',
+                worker_autoscaler='foo:Xuzzy'
+            )
+            self.assertEqual(app.conf.task_always_eager, 4)
+            self.assertEqual(app.conf.worker_autoscaler, 'foo:Xuzzy')
+
     def test_pending_configuration__setdefault(self):
         with self.Celery(broker='foo://bar') as app:
             app.conf.setdefault('worker_agent', 'foo:Bar')

+ 12 - 0
celery/tests/app/test_beat.py

@@ -83,6 +83,18 @@ class test_ScheduleEntry(AppCase):
         entry = self.create_entry()
         self.assertIn('<Entry:', repr(entry))
 
+    def test_reduce(self):
+        entry = self.create_entry(schedule=timedelta(seconds=10))
+        fun, args = entry.__reduce__()
+        res = fun(*args)
+        self.assertEqual(res.schedule, entry.schedule)
+
+    def test_lt(self):
+        e1 = self.create_entry(schedule=timedelta(seconds=10))
+        e2 = self.create_entry(schedule=timedelta(seconds=2))
+        self.assertLess(e2, e1)
+        self.assertTrue(e1 < object())
+
     def test_update(self):
         entry = self.create_entry()
         self.assertEqual(entry.schedule, timedelta(seconds=10))

+ 24 - 4
celery/tests/app/test_utils.py

@@ -7,10 +7,8 @@ from celery.app.utils import Settings, filter_hidden_settings, bugreport
 from celery.tests.case import AppCase, Mock
 
 
-class TestSettings(AppCase):
-    """
-    Tests of celery.app.utils.Settings
-    """
+class test_Settings(AppCase):
+
     def test_is_mapping(self):
         """Settings should be a collections.Mapping"""
         self.assertTrue(issubclass(Settings, Mapping))
@@ -19,6 +17,28 @@ class TestSettings(AppCase):
         """Settings should be a collections.MutableMapping"""
         self.assertTrue(issubclass(Settings, MutableMapping))
 
+    def test_find(self):
+        self.assertTrue(self.app.conf.find_option('always_eager'))
+
+    def test_get_by_parts(self):
+        self.app.conf.task_do_this_and_that = 303
+        self.assertEqual(
+            self.app.conf.get_by_parts('task', 'do', 'this', 'and', 'that'),
+            303,
+        )
+
+    def test_find_value_for_key(self):
+        self.assertEqual(
+            self.app.conf.find_value_for_key('always_eager'),
+            False,
+        )
+
+    def test_table(self):
+        self.assertTrue(self.app.conf.table(with_defaults=True))
+        self.assertTrue(self.app.conf.table(with_defaults=False))
+        self.assertTrue(self.app.conf.table(censored=False))
+        self.assertTrue(self.app.conf.table(censored=True))
+
 
 class test_filter_hidden_settings(AppCase):
 

+ 18 - 1
celery/tests/backends/test_amqp.py

@@ -33,6 +33,20 @@ class test_AMQPBackend(AppCase):
         opts = dict(dict(serializer='pickle', persistent=True), **opts)
         return AMQPBackend(self.app, **opts)
 
+    def test_destination_for(self):
+        b = self.create_backend()
+        request = Mock()
+        self.assertTupleEqual(
+            b.destination_for('id', request),
+            (b.rkey('id'), request.correlation_id),
+        )
+
+    def test_store_result__no_routing_key(self):
+        b = self.create_backend()
+        b.destination_for = Mock()
+        b.destination_for.return_value = None, None
+        b.store_result('id', None, states.SUCCESS)
+
     def test_mark_as_done(self):
         tb1 = self.create_backend(max_cached_results=1)
         tb2 = self.create_backend(max_cached_results=1)
@@ -268,8 +282,11 @@ class test_AMQPBackend(AppCase):
         with self.app.pool.acquire_channel(block=False) as (_, channel):
             binding = b._create_binding(uuid())
             consumer = b.Consumer(channel, binding, no_ack=True)
+            callback = Mock()
             with self.assertRaises(socket.timeout):
-                b.drain_events(Connection(), consumer, timeout=0.1)
+                b.drain_events(Connection(), consumer, timeout=0.1,
+                               on_interval=callback)
+                callback.assert_called_with()
 
     def test_get_many(self):
         b = self.create_backend(max_cached_results=10)

+ 131 - 2
celery/tests/backends/test_base.py

@@ -5,7 +5,7 @@ import types
 
 from contextlib import contextmanager
 
-from celery.exceptions import ChordError
+from celery.exceptions import ChordError, TimeoutError
 from celery.five import items, range
 from celery.utils import serialization
 from celery.utils.serialization import subclass_exception
@@ -19,11 +19,13 @@ from celery.backends.base import (
     BaseBackend,
     KeyValueStoreBackend,
     DisabledBackend,
+    _nulldict,
 )
 from celery.result import result_from_tuple
 from celery.utils import uuid
+from celery.utils.functional import pass1
 
-from celery.tests.case import AppCase, Mock, SkipTest, patch
+from celery.tests.case import ANY, AppCase, Case, Mock, SkipTest, call, patch
 
 
 class wrapobject(object):
@@ -40,6 +42,15 @@ Impossible = subclass_exception('Impossible', object, 'foo.module')
 Lookalike = subclass_exception('Lookalike', wrapobject, 'foo.module')
 
 
+class test_nulldict(Case):
+
+    def test_nulldict(self):
+        x = _nulldict()
+        x['foo'] = 1
+        x.update(foo=1, bar=2)
+        x.setdefault('foo', 3)
+
+
 class test_serialization(AppCase):
 
     def test_create_exception_cls(self):
@@ -247,6 +258,69 @@ class test_BaseBackend_dict(AppCase):
         self.assertTrue(b.is_cached('foo'))
         self.assertFalse(b.is_cached('false'))
 
+    def test_mark_as_done__chord(self):
+        b = BaseBackend(app=self.app)
+        b._store_result = Mock()
+        request = Mock(name='request')
+        b.on_chord_part_return = Mock()
+        b.mark_as_done('id', 10, request=request)
+        b.on_chord_part_return.assert_called_with(request, states.SUCCESS, 10)
+
+    def test_mark_as_failure__chord(self):
+        b = BaseBackend(app=self.app)
+        b._store_result = Mock()
+        request = Mock(name='request')
+        b.on_chord_part_return = Mock()
+        exc = KeyError()
+        b.mark_as_failure('id', exc, request=request)
+        b.on_chord_part_return.assert_called_with(request, states.FAILURE, exc)
+
+    def test_mark_as_revoked__chord(self):
+        b = BaseBackend(app=self.app)
+        b._store_result = Mock()
+        request = Mock(name='request')
+        b.on_chord_part_return = Mock()
+        b.mark_as_revoked('id', 'revoked', request=request)
+        b.on_chord_part_return.assert_called_with(request, states.REVOKED, ANY)
+
+    def test_chord_error_from_stack_raises(self):
+        b = BaseBackend(app=self.app)
+        exc = KeyError()
+        callback = Mock(name='callback')
+        callback.options = {'link_error': []}
+        task = self.app.tasks[callback.task] = Mock()
+        b.fail_from_current_stack = Mock()
+        group = self.patch('celery.group')
+        group.side_effect = exc
+        b.chord_error_from_stack(callback, exc=ValueError())
+        task.backend.fail_from_current_stack.assert_called_with(
+            callback.id, exc=exc)
+
+    def test_exception_to_python_when_None(self):
+        b = BaseBackend(app=self.app)
+        self.assertIsNone(b.exception_to_python(None))
+
+    def test_wait_for__on_interval(self):
+        self.patch('time.sleep')
+        b = BaseBackend(app=self.app)
+        b._get_task_meta_for = Mock()
+        b._get_task_meta_for.return_value = {'status': states.PENDING}
+        callback = Mock(name='callback')
+        with self.assertRaises(TimeoutError):
+            b.wait_for(task_id='1', on_interval=callback, timeout=1)
+        callback.assert_called_with()
+
+        b._get_task_meta_for.return_value = {'status': states.SUCCESS}
+        b.wait_for(task_id='1', timeout=None)
+
+    def test_get_children(self):
+        b = BaseBackend(app=self.app)
+        b._get_task_meta_for = Mock()
+        b._get_task_meta_for.return_value = {}
+        self.assertIsNone(b.get_children('id'))
+        b._get_task_meta_for.return_value = {'children': 3}
+        self.assertEqual(b.get_children('id'), 3)
+
 
 class test_KeyValueStoreBackend(AppCase):
 
@@ -282,6 +356,17 @@ class test_KeyValueStoreBackend(AppCase):
             self.assertEqual(i, 9)
             self.assertTrue(list(self.b.get_many(list(ids))))
 
+            self.b._cache.clear()
+            callback = Mock(name='callback')
+            it = self.b.get_many(list(ids), on_message=callback)
+            for i, (got_id, got_state) in enumerate(it):
+                self.assertEqual(got_state['result'], ids[got_id])
+            self.assertEqual(i, 9)
+            self.assertTrue(list(self.b.get_many(list(ids))))
+            callback.assert_has_calls([
+                call(ANY) for id in ids
+            ])
+
     def test_get_many_times_out(self):
         tasks = [uuid() for _ in range(4)]
         self.b._cache[tasks[1]] = {'status': 'PENDING'}
@@ -302,6 +387,50 @@ class test_KeyValueStoreBackend(AppCase):
             self.b.on_chord_part_return(task.request, state, result),
         )
 
+    @patch('celery.backends.base.GroupResult')
+    @patch('celery.backends.base.maybe_signature')
+    def test_chord_part_return_restore_raises(self, maybe_signature,
+                                              GroupResult):
+        self.b.implements_incr = True
+        GroupResult.restore.side_effect = KeyError()
+        self.b.chord_error_from_stack = Mock()
+        callback = Mock(name='callback')
+        request = Mock(name='request')
+        request.group = 'gid'
+        maybe_signature.return_value = callback
+        self.b.on_chord_part_return(request, states.SUCCESS, 10)
+        self.b.chord_error_from_stack.assert_called_with(
+            callback, ANY,
+        )
+
+    @patch('celery.backends.base.GroupResult')
+    @patch('celery.backends.base.maybe_signature')
+    def test_chord_part_return_restore_empty(self, maybe_signature,
+                                             GroupResult):
+        self.b.implements_incr = True
+        GroupResult.restore.return_value = None
+        self.b.chord_error_from_stack = Mock()
+        callback = Mock(name='callback')
+        request = Mock(name='request')
+        request.group = 'gid'
+        maybe_signature.return_value = callback
+        self.b.on_chord_part_return(request, states.SUCCESS, 10)
+        self.b.chord_error_from_stack.assert_called_with(
+            callback, ANY,
+        )
+
+    def test_filter_ready(self):
+        self.b.decode_result = Mock()
+        self.b.decode_result.side_effect = pass1
+        self.assertEqual(
+            len(list(self.b._filter_ready([
+                (1, {'status': states.RETRY}),
+                (2, {'status': states.FAILURE}),
+                (3, {'status': states.SUCCESS}),
+            ]))),
+            2,
+        )
+
     @contextmanager
     def _chord_part_context(self, b):
 

+ 33 - 4
celery/tests/backends/test_mongodb.py

@@ -4,10 +4,14 @@ import datetime
 
 from pickle import loads, dumps
 
+from kombu.exceptions import EncodeError
+
 from celery import uuid
 from celery import states
 from celery.backends import mongodb as module
-from celery.backends.mongodb import MongoBackend, Bunch, pymongo
+from celery.backends.mongodb import (
+    InvalidDocument, MongoBackend, Bunch, pymongo,
+)
 from celery.exceptions import ImproperlyConfigured
 from celery.tests.case import (
     AppCase, MagicMock, Mock, SkipTest, ANY,
@@ -123,13 +127,14 @@ class test_MongoBackend(AppCase):
         self.assertEqual(mb.password, 'celerypassword')
         self.assertEqual(mb.database_name, 'another_db')
 
+        mb = MongoBackend(app=self.app, url='mongodb://')
+
     @depends_on_current_app
     def test_reduce(self):
         x = MongoBackend(app=self.app)
         self.assertTrue(loads(dumps(x)))
 
     def test_get_connection_connection_exists(self):
-
         with patch('pymongo.MongoClient') as mock_Connection:
             self.backend._connection = sentinel._connection
 
@@ -139,7 +144,6 @@ class test_MongoBackend(AppCase):
             self.assertFalse(mock_Connection.called)
 
     def test_get_connection_no_connection_host(self):
-
         with patch('pymongo.MongoClient') as mock_Connection:
             self.backend._connection = None
             self.backend.host = MONGODB_HOST
@@ -154,7 +158,6 @@ class test_MongoBackend(AppCase):
             self.assertEqual(sentinel.connection, connection)
 
     def test_get_connection_no_connection_mongodb_uri(self):
-
         with patch('pymongo.MongoClient') as mock_Connection:
             mongodb_uri = 'mongodb://%s:%d' % (MONGODB_HOST, MONGODB_PORT)
             self.backend._connection = None
@@ -230,6 +233,11 @@ class test_MongoBackend(AppCase):
         mock_collection.save.assert_called_once_with(ANY)
         self.assertEqual(sentinel.result, ret_val)
 
+        mock_collection.save.side_effect = InvalidDocument()
+        with self.assertRaises(EncodeError):
+            self.backend._store_result(
+                sentinel.task_id, sentinel.result, sentinel.status)
+
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_get_task_meta_for(self, mock_get_database):
         datetime.datetime = self._reset['datetime']
@@ -315,6 +323,9 @@ class test_MongoBackend(AppCase):
             list(ret_val.keys()),
         )
 
+        mock_collection.find_one.return_value = None
+        self.backend._restore_group(sentinel.taskset_id)
+
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_delete_group(self, mock_get_database):
         self.backend.taskmeta_collection = MONGODB_COLLECTION
@@ -387,3 +398,21 @@ class test_MongoBackend(AppCase):
             self.assertDictEqual(options, {
                 'maxPoolSize': self.backend.max_pool_size
             })
+
+
+class test_MongoBackend_no_mock(AppCase):
+
+    def test_encode_decode(self):
+        backend = MongoBackend(app=self.app)
+        data = {'foo': 1}
+        self.assertTrue(backend.decode(backend.encode(data)))
+        backend.serializer = 'bson'
+        self.assertEquals(backend.encode(data), data)
+        self.assertEquals(backend.decode(data), data)
+
+    def test_de(self):
+        backend = MongoBackend(app=self.app)
+        data = {'foo': 1}
+        self.assertTrue(backend.encode(data))
+        backend.serializer = 'bson'
+        self.assertEquals(backend.encode(data), data)

+ 171 - 48
celery/tests/backends/test_redis.py

@@ -2,21 +2,32 @@ from __future__ import absolute_import
 
 from datetime import timedelta
 
+from contextlib import contextmanager
 from pickle import loads, dumps
 
 from celery import signature
 from celery import states
-from celery import group
 from celery import uuid
+from celery.canvas import Signature
 from celery.datastructures import AttributeDict
-from celery.exceptions import ImproperlyConfigured
+from celery.exceptions import ChordError, ImproperlyConfigured
 
 from celery.tests.case import (
-    AppCase, Mock, MockCallbacks, SkipTest,
+    ANY, AppCase, ContextMock, Mock, MockCallbacks, SkipTest,
     call, depends_on_current_app, patch,
 )
 
 
+def raise_on_second_call(mock, exc, *retval):
+
+    def on_first_call(*args, **kwargs):
+        mock.side_effect = exc
+        return mock.return_value
+    mock.side_effect = on_first_call
+    if retval:
+        mock.return_value, = retval
+
+
 class Connection(object):
     connected = True
 
@@ -121,8 +132,14 @@ class test_RedisBackend(AppCase):
 
         return _RedisBackend
 
+    def get_E_LOST(self):
+        from celery.backends.redis import E_LOST
+        return E_LOST
+
     def setup(self):
         self.Backend = self.get_backend()
+        self.E_LOST = self.get_E_LOST()
+        self.b = self.Backend(app=self.app)
 
     @depends_on_current_app
     def test_reduce(self):
@@ -184,6 +201,70 @@ class test_RedisBackend(AppCase):
         })
         self.Backend(app=self.app)
 
+    @patch('celery.backends.redis.error')
+    def test_on_connection_error(self, error):
+        intervals = iter([10, 20, 30])
+        exc = KeyError()
+        self.assertEqual(
+            self.b.on_connection_error(None, exc, intervals, 1), 10,
+        )
+        error.assert_called_with(self.E_LOST, 1, 'Inf', 'in 10.00 seconds')
+        self.assertEqual(
+            self.b.on_connection_error(10, exc, intervals, 2), 20,
+        )
+        error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')
+        self.assertEqual(
+            self.b.on_connection_error(10, exc, intervals, 3), 30,
+        )
+        error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')
+
+    def test_incr(self):
+        self.b.client = Mock(name='client')
+        self.b.incr('foo')
+        self.b.client.incr.assert_called_with('foo')
+
+    def test_expire(self):
+        self.b.client = Mock(name='client')
+        self.b.expire('foo', 300)
+        self.b.client.expire.assert_called_with('foo', 300)
+
+    def test_apply_chord(self):
+        header = Mock(name='header')
+        header.results = [Mock(name='t1'), Mock(name='t2')]
+        print(self.b.apply_chord,)
+        self.b.apply_chord(
+            header, (1, 2), 'gid', None,
+            options={'max_retries': 10},
+        )
+        header.assert_called_with(1, 2, max_retries=10, task_id='gid')
+
+    def test_unpack_chord_result(self):
+        self.b.exception_to_python = Mock(name='etp')
+        decode = Mock(name='decode')
+        exc = KeyError()
+        tup = decode.return_value = (1, 'id1', states.FAILURE, exc)
+        with self.assertRaises(ChordError):
+            self.b._unpack_chord_result(tup, decode)
+        decode.assert_called_with(tup)
+        self.b.exception_to_python.assert_called_with(exc)
+
+        exc = ValueError()
+        tup = decode.return_value = (2, 'id2', states.RETRY, exc)
+        ret = self.b._unpack_chord_result(tup, decode)
+        self.b.exception_to_python.assert_called_with(exc)
+        self.assertIs(ret, self.b.exception_to_python())
+
+    def test_on_chord_part_return_no_gid_or_tid(self):
+        request = Mock(name='request')
+        request.id = request.group = None
+        self.assertIsNone(self.b.on_chord_part_return(request, 'SUCCESS', 10))
+
+    def test_ConnectionPool(self):
+        self.b.redis = Mock(name='redis')
+        self.assertIsNone(self.b._ConnectionPool)
+        self.assertIs(self.b.ConnectionPool, self.b.redis.ConnectionPool)
+        self.assertIs(self.b.ConnectionPool, self.b.redis.ConnectionPool)
+
     def test_expires_defaults_to_config(self):
         self.app.conf.result_expires = 10
         b = self.Backend(expires=None, app=self.app)
@@ -210,68 +291,110 @@ class test_RedisBackend(AppCase):
         b = self.Backend(expires=timedelta(minutes=1), app=self.app)
         self.assertEqual(b.expires, 60)
 
-    def test_apply_chord(self):
-        self.Backend(app=self.app).apply_chord(
-            group(app=self.app), (), 'group_id', {},
-            result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
-        )
-
     def test_mget(self):
-        b = self.Backend(app=self.app)
-        self.assertTrue(b.mget(['a', 'b', 'c']))
-        b.client.mget.assert_called_with(['a', 'b', 'c'])
+        self.assertTrue(self.b.mget(['a', 'b', 'c']))
+        self.b.client.mget.assert_called_with(['a', 'b', 'c'])
 
     def test_set_no_expire(self):
-        b = self.Backend(app=self.app)
-        b.expires = None
-        b.set('foo', 'bar')
+        self.b.expires = None
+        self.b.set('foo', 'bar')
+
+    def create_task(self):
+        tid = uuid()
+        task = Mock(name='task-{0}'.format(tid))
+        task.name = 'foobarbaz'
+        self.app.tasks['foobarbaz'] = task
+        task.request.chord = signature(task)
+        task.request.id = tid
+        task.request.chord['chord_size'] = 10
+        task.request.group = 'group_id'
+        return task
 
     @patch('celery.result.GroupResult.restore')
     def test_on_chord_part_return(self, restore):
-        b = self.Backend(app=self.app)
-
-        def create_task():
-            tid = uuid()
-            task = Mock(name='task-{0}'.format(tid))
-            task.name = 'foobarbaz'
-            self.app.tasks['foobarbaz'] = task
-            task.request.chord = signature(task)
-            task.request.id = tid
-            task.request.chord['chord_size'] = 10
-            task.request.group = 'group_id'
-            return task
-
-        tasks = [create_task() for i in range(10)]
+        tasks = [self.create_task() for i in range(10)]
 
         for i in range(10):
-            b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
-            self.assertTrue(b.client.rpush.call_count)
-            b.client.rpush.reset_mock()
-        self.assertTrue(b.client.lrange.call_count)
-        jkey = b.get_key_for_group('group_id', '.j')
-        tkey = b.get_key_for_group('group_id', '.t')
-        b.client.delete.assert_has_calls([call(jkey), call(tkey)])
-        b.client.expire.assert_has_calls([
+            self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
+            self.assertTrue(self.b.client.rpush.call_count)
+            self.b.client.rpush.reset_mock()
+        self.assertTrue(self.b.client.lrange.call_count)
+        jkey = self.b.get_key_for_group('group_id', '.j')
+        tkey = self.b.get_key_for_group('group_id', '.t')
+        self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
+        self.b.client.expire.assert_has_calls([
             call(jkey, 86400), call(tkey, 86400),
         ])
 
+    def test_on_chord_part_return__success(self):
+        with self.chord_context(2) as (_, request, callback):
+            self.b.on_chord_part_return(request, states.SUCCESS, 10)
+            self.assertFalse(callback.delay.called)
+            self.b.on_chord_part_return(request, states.SUCCESS, 20)
+            callback.delay.assert_called_with([10, 20])
+
+    def test_on_chord_part_return__callback_raises(self):
+        with self.chord_context(1) as (_, request, callback):
+            callback.delay.side_effect = KeyError(10)
+            task = self.app._tasks['add'] = Mock(name='add_task')
+            self.b.on_chord_part_return(request, states.SUCCESS, 10)
+            task.backend.fail_from_current_stack.assert_called_with(
+                callback.id, exc=ANY,
+            )
+
+    def test_on_chord_part_return__ChordError(self):
+        with self.chord_context(1) as (_, request, callback):
+            self.b.client.pipeline = ContextMock()
+            raise_on_second_call(self.b.client.pipeline, ChordError())
+            self.b.client.pipeline.return_value.rpush().llen().get().expire(
+            ).expire().execute.return_value = (1, 1, 0, 4, 5)
+            task = self.app._tasks['add'] = Mock(name='add_task')
+            self.b.on_chord_part_return(request, states.SUCCESS, 10)
+            task.backend.fail_from_current_stack.assert_called_with(
+                callback.id, exc=ANY,
+            )
+
+    def test_on_chord_part_return__other_error(self):
+        with self.chord_context(1) as (_, request, callback):
+            self.b.client.pipeline = ContextMock()
+            raise_on_second_call(self.b.client.pipeline, RuntimeError())
+            self.b.client.pipeline.return_value.rpush().llen().get().expire(
+            ).expire().execute.return_value = (1, 1, 0, 4, 5)
+            task = self.app._tasks['add'] = Mock(name='add_task')
+            self.b.on_chord_part_return(request, states.SUCCESS, 10)
+            task.backend.fail_from_current_stack.assert_called_with(
+                callback.id, exc=ANY,
+            )
+
+    @contextmanager
+    def chord_context(self, size=1):
+        with patch('celery.backends.redis.maybe_signature') as ms:
+            tasks = [self.create_task() for i in range(size)]
+            request = Mock(name='request')
+            request.id = 'id1'
+            request.group = 'gid1'
+            callback = ms.return_value = Signature('add')
+            callback.id = 'id1'
+            callback['chord_size'] = size
+            callback.delay = Mock(name='callback.delay')
+            yield tasks, request, callback
+
     def test_process_cleanup(self):
-        self.Backend(app=self.app).process_cleanup()
+        self.b.process_cleanup()
 
     def test_get_set_forget(self):
-        b = self.Backend(app=self.app)
         tid = uuid()
-        b.store_result(tid, 42, states.SUCCESS)
-        self.assertEqual(b.get_status(tid), states.SUCCESS)
-        self.assertEqual(b.get_result(tid), 42)
-        b.forget(tid)
-        self.assertEqual(b.get_status(tid), states.PENDING)
+        self.b.store_result(tid, 42, states.SUCCESS)
+        self.assertEqual(self.b.get_status(tid), states.SUCCESS)
+        self.assertEqual(self.b.get_result(tid), 42)
+        self.b.forget(tid)
+        self.assertEqual(self.b.get_status(tid), states.PENDING)
 
     def test_set_expires(self):
-        b = self.Backend(expires=512, app=self.app)
+        self.b = self.Backend(expires=512, app=self.app)
         tid = uuid()
-        key = b.get_key_for_task(tid)
-        b.store_result(tid, 42, states.SUCCESS)
-        b.client.expire.assert_called_with(
+        key = self.b.get_key_for_task(tid)
+        self.b.store_result(tid, 42, states.SUCCESS)
+        self.b.client.expire.assert_called_with(
             key, 512,
         )
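
The new chord error tests above use the raise_on_second_call() helper defined at the top of this module: the patched mock behaves normally on its first call and raises the given exception on every call after that. A minimal usage sketch:

    from celery.tests.case import Mock

    pipe = Mock(name='pipeline')
    raise_on_second_call(pipe, RuntimeError('connection lost'), 42)
    pipe()   # first call returns 42 (the optional forced return value)
    pipe()   # second call raises RuntimeError('connection lost')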

+ 2 - 0
celery/tests/bin/celery.py

@@ -0,0 +1,2 @@
+from __future__ import absolute_import, unicode_literals
+# here for a test

+ 58 - 0
celery/tests/bin/test_base.py

@@ -236,11 +236,23 @@ class test_Command(AppCase):
         self.assertTrue(cmd.find_app('celery.tests.bin.proj.app'))
         self.assertTrue(cmd.find_app('celery.tests.bin.proj'))
         self.assertTrue(cmd.find_app('celery.tests.bin.proj:hello'))
+        self.assertTrue(cmd.find_app('celery.tests.bin.proj.hello'))
         self.assertTrue(cmd.find_app('celery.tests.bin.proj.app:app'))
+        self.assertTrue(cmd.find_app('celery.tests.bin.proj.app.app'))
+        with self.assertRaises(AttributeError):
+            cmd.find_app('celery.tests.bin')
 
         with self.assertRaises(AttributeError):
             cmd.find_app(__name__)
 
+    @patch('celery.bin.base.input')
+    def test_ask(self, input):
+        cmd = MockCommand(app=self.app)
+        input.return_value = 'yes'
+        self.assertEqual(cmd.ask('q', ('yes', 'no'), 'no'), 'yes')
+        input.return_value = 'nop'
+        self.assertEqual(cmd.ask('q', ('yes', 'no'), 'no'), 'no')
+
     def test_host_format(self):
         cmd = MockCommand(app=self.app)
         with patch('socket.gethostname') as hn:
@@ -291,6 +303,52 @@ class test_Command(AppCase):
         self.assertEqual(cmd.app.conf.worker_prefetch_multiplier, 100)
         self.assertListEqual(rest, ['--loglevel=INFO'])
 
+        cmd.app = None
+        cmd.get_app = Mock(name='get_app')
+        cmd.get_app.return_value = self.app
+        self.app.user_options['preload'] = [
+            Option('--foo', action='store_true'),
+        ]
+        cmd.setup_app_from_commandline(argv=[
+            '--foo', '--loglevel=INFO', '--',
+            'broker.url=amqp://broker.example.com',
+            '.prefetch_multiplier=100'])
+        self.assertIs(cmd.app, cmd.get_app())
+
+    def test_preparse_options__required_short(self):
+        cmd = MockCommand(app=self.app)
+        with self.assertRaises(ValueError):
+            cmd.preparse_options(
+                ['a', '-f'], [Option('-f', action='store')])
+
+    def test_preparse_options__longopt_whitespace(self):
+        cmd = MockCommand(app=self.app)
+        cmd.preparse_options(
+            ['a', '--foo', 'val'], [Option('--foo', action='store')])
+
+    def test_preparse_options__shortopt_store_true(self):
+        cmd = MockCommand(app=self.app)
+        cmd.preparse_options(
+            ['a', '--foo'], [Option('--foo', action='store_true')])
+
+    def test_get_default_app(self):
+        self.patch('celery._state.get_current_app')
+        cmd = MockCommand(app=self.app)
+        from celery._state import get_current_app
+        self.assertIs(cmd._get_default_app(), get_current_app())
+
+    def test_set_colored(self):
+        cmd = MockCommand(app=self.app)
+        cmd.colored = 'foo'
+        self.assertEqual(cmd.colored, 'foo')
+
+    def test_set_no_color(self):
+        cmd = MockCommand(app=self.app)
+        cmd.no_color = False
+        _ = cmd.colored  # noqa
+        cmd.no_color = True
+        self.assertFalse(cmd.colored.enabled)
+
     def test_find_app(self):
         cmd = MockCommand(app=self.app)
         with patch('celery.bin.base.symbol_by_name') as sbn:

+ 47 - 0
celery/tests/bin/test_celery.py

@@ -9,6 +9,7 @@ from kombu.utils.json import dumps
 from celery import __main__
 from celery.platforms import EX_FAILURE, EX_USAGE, EX_OK
 from celery.bin.base import Error
+from celery.bin import celery as mod
 from celery.bin.celery import (
     Command,
     list_,
@@ -179,6 +180,13 @@ class test_purge(AppCase):
         a.run(force=True)
         self.assertIn('100 messages', out.getvalue())
 
+        a.out = Mock(name='out')
+        a.ask = Mock(name='ask')
+        a.run(force=False)
+        a.ask.assert_called_with(a.warn_prompt, ('yes', 'no'), 'no')
+        a.ask.return_value = 'yes'
+        a.run(force=False)
+
 
 class test_result(AppCase):
 
@@ -303,6 +311,20 @@ class test_CeleryCommand(AppCase):
             x = CeleryCommand(app=self.app)
             x.load_extension_commands()
 
+    def test_load_extensions_commands(self):
+        with patch('celery.bin.celery.Extensions') as Ext:
+            prev, mod.command_classes = list(mod.command_classes), Mock()
+            try:
+                ext = Ext.return_value = Mock(name='Extension')
+                ext.load.return_value = ['foo', 'bar']
+                x = CeleryCommand(app=self.app)
+                x.load_extension_commands()
+                mod.command_classes.append.assert_called_with(
+                    ('Extensions', ['foo', 'bar'], 'magenta'),
+                )
+            finally:
+                mod.command_classes = prev
+
     def test_determine_exit_status(self):
         self.assertEqual(determine_exit_status('true'), EX_OK)
         self.assertEqual(determine_exit_status(''), EX_FAILURE)
@@ -327,6 +349,15 @@ class test_CeleryCommand(AppCase):
             ['foo', '--foo=1'],
         )
 
+    def test_register_command(self):
+        prev, CeleryCommand.commands = dict(CeleryCommand.commands), {}
+        try:
+            fun = Mock(name='fun')
+            CeleryCommand.register_command(fun, name='foo')
+            self.assertIs(CeleryCommand.commands['foo'], fun)
+        finally:
+            CeleryCommand.commands = prev
+
     def test_handle_argv(self):
         x = CeleryCommand(app=self.app)
         x.execute = Mock()
@@ -457,6 +488,10 @@ class test_inspect(AppCase):
         callback({'foo': {'ok': 'pong'}})
         self.assertIn('OK', out.getvalue())
 
+        with patch('celery.bin.celery.json.dumps') as dumps:
+            i.run('ping', json=True)
+            self.assertTrue(dumps.called)
+
         instance = real.return_value = Mock()
         instance.ping.return_value = None
         with self.assertRaises(Error):
@@ -468,6 +503,18 @@ class test_inspect(AppCase):
         i.say_chat('<-', 'hello')
         self.assertFalse(out.getvalue())
 
+    def test_objgraph(self):
+        i = inspect(app=self.app)
+        i.call = Mock(name='call')
+        i.objgraph('Message', foo=1)
+        i.call.assert_called_with('objgraph', 'Message', foo=1)
+
+    def test_conf(self):
+        i = inspect(app=self.app)
+        i.call = Mock(name='call')
+        i.conf(with_defaults=True, foo=1)
+        i.call.assert_called_with('conf', True, foo=1)
+
 
 class test_control(AppCase):
 

+ 23 - 2
celery/tests/bin/test_celeryd_detach.py

@@ -30,6 +30,11 @@ if not IS_WINDOWS:
             )
             execv.assert_called_with('/bin/boo', ['/bin/boo', 'a', 'b', 'c'])
 
+            r = detach('/bin/boo', ['a', 'b', 'c'],
+                       logfile='/var/log', pidfile='/var/pid',
+                       executable='/bin/foo', app=self.app)
+            execv.assert_called_with('/bin/foo', ['/bin/foo', 'a', 'b', 'c'])
+
             execv.side_effect = Exception('foo')
             r = detach('/bin/boo', ['a', 'b', 'c'],
                        logfile='/var/log', pidfile='/var/pid', app=self.app)
@@ -38,17 +43,33 @@ if not IS_WINDOWS:
             setup_logs.assert_called_with('ERROR', '/var/log')
             self.assertEqual(r, 1)
 
+            self.patch('celery.current_app')
+            from celery import current_app
+            r = detach('/bin/boo', ['a', 'b', 'c'],
+                       logfile='/var/log', pidfile='/var/pid', app=None)
+            current_app.log.setup_logging_subsystem.assert_called_with(
+                'ERROR', '/var/log',
+            )
+
 
 class test_PartialOptionParser(AppCase):
 
     def test_parser(self):
         x = detached_celeryd(self.app)
         p = x.Parser('celeryd_detach')
-        options, values = p.parse_args(['--logfile=foo', '--fake', '--enable',
-                                        'a', 'b', '-c1', '-d', '2'])
+        options, values = p.parse_args([
+            '--logfile=foo', '--fake', '--enable',
+            'a', 'b', '-c1', '-d', '2',
+        ])
         self.assertEqual(options.logfile, 'foo')
         self.assertEqual(values, ['a', 'b'])
         self.assertEqual(p.leftovers, ['--enable', '-c1', '-d', '2'])
+        options, values = p.parse_args([
+            '--fake', '--enable',
+            '--pidfile=/var/pid/foo.pid',
+            'a', 'b', '-c1', '-d', '2',
+        ])
+        self.assertEqual(options.pidfile, '/var/pid/foo.pid')
 
         with override_stdouts():
             with self.assertRaises(SystemExit):

+ 5 - 0
celery/tests/bin/test_multi.py

@@ -165,6 +165,11 @@ class test_MultiTool(AppCase):
         self.t.note('hello world')
         self.assertFalse(self.fh.getvalue())
 
+    def test_carp(self):
+        self.t.say = Mock()
+        self.t.carp('foo')
+        self.t.say.assert_called_with('foo', True, self.t.stderr)
+
     def test_info(self):
         self.t.verbose = True
         self.t.info('hello info')

+ 6 - 6
celery/tests/bin/test_worker.py

@@ -185,13 +185,13 @@ class test_Worker(WorkerAppCase):
 
         prev_loader = self.app.loader
         worker = self.Worker(app=self.app, queues='foo,bar,baz,xuzzy,do,re,mi')
-        self.app.loader = Mock()
-        self.app.loader.__module__ = 'acme.baked_beans'
-        self.assertTrue(worker.startup_info())
+        with patch('celery.apps.worker.qualname') as qualname:
+            qualname.return_value = 'acme.backed_beans.Loader'
+            self.assertTrue(worker.startup_info())
 
-        self.app.loader = Mock()
-        self.app.loader.__module__ = 'celery.loaders.foo'
-        self.assertTrue(worker.startup_info())
+        with patch('celery.apps.worker.qualname') as qualname:
+            qualname.return_value = 'celery.loaders.Loader'
+            self.assertTrue(worker.startup_info())
 
         from celery.loaders.app import AppLoader
         self.app.loader = AppLoader(app=self.app)

+ 53 - 10
celery/tests/case.py

@@ -315,10 +315,44 @@ class Case(unittest.TestCase):
         self.addCleanup(manager.stop)
         return patched
 
-    def mock_modules(self, *modules):
-        manager = mock_module(*modules)
-        manager.__enter__()
-        self.addCleanup(partial(manager.__exit__, None, None, None))
+    def mock_modules(self, *mods):
+        modules = []
+        for mod in mods:
+            mod = mod.split('.')
+            modules.extend(reversed([
+                '.'.join(mod[:-i] if i else mod) for i in range(len(mod))
+            ]))
+        modules = sorted(set(modules))
+        return self.wrap_context(mock_module(*modules))
+
+    def on_nth_call_do(self, mock, side_effect, n=1):
+
+        def on_call(*args, **kwargs):
+            if mock.call_count >= n:
+                mock.side_effect = side_effect
+            return mock.return_value
+        mock.side_effect = on_call
+        return mock
+
+    def on_nth_call_return(self, mock, retval, n=1):
+
+        def on_call(*args, **kwargs):
+            if mock.call_count >= n:
+                mock.return_value = retval
+            return mock.return_value
+        mock.side_effect = on_call
+        return mock
+
+    def mask_modules(self, *modules):
+        self.wrap_context(mask_modules(*modules))
+
+    def wrap_context(self, context):
+        ret = context.__enter__()
+        self.addCleanup(partial(context.__exit__, None, None, None))
+        return ret
+
+    def mock_environ(self, env_name, env_value):
+        return self.wrap_context(mock_environ(env_name, env_value))
 
     def assertWarns(self, expected_warning):
         return _AssertWarnsContext(expected_warning, self, None)
@@ -543,19 +577,28 @@ def wrap_logger(logger, loglevel=logging.ERROR):
         logger.handlers = old_handlers
 
 
+@contextmanager
+def mock_environ(env_name, env_value):
+    sentinel = object()
+    prev_val = os.environ.get(env_name, sentinel)
+    os.environ[env_name] = env_value
+    try:
+        yield env_value
+    finally:
+        if prev_val is sentinel:
+            os.environ.pop(env_name, None)
+        else:
+            os.environ[env_name] = prev_val
+
+
 def with_environ(env_name, env_value):
 
     def _envpatched(fun):
 
         @wraps(fun)
         def _patch_environ(*args, **kwargs):
-            prev_val = os.environ.get(env_name)
-            os.environ[env_name] = env_value
-            try:
+            with mock_environ(env_name, env_value):
                 return fun(*args, **kwargs)
-            finally:
-                os.environ[env_name] = prev_val or ''
-
         return _patch_environ
     return _envpatched
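
The new Case.on_nth_call_do() helper added above swaps in a side effect only once the mock has been called n times; the call that reaches the threshold still returns normally, and later calls get the side effect. A minimal sketch, assuming a Case-based test:

    from celery.tests.case import Case, Mock

    class test_on_nth_call_do(Case):

        def test_switches_after_first_call(self):
            m = Mock(return_value=1)
            self.on_nth_call_do(m, KeyError('nope'))  # default n=1
            self.assertEqual(m(), 1)                  # first call still succeeds
            with self.assertRaises(KeyError):
                m()                                   # later calls raise the side effect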
 

+ 5 - 0
celery/tests/contrib/test_rdb.py

@@ -58,6 +58,11 @@ class test_Rdb(AppCase):
 
             # _close_session
             rdb._close_session()
+            rdb.active = True
+            rdb._handle = None
+            rdb._client = None
+            rdb._sock = None
+            rdb._close_session()
 
             # do_continue
             rdb.set_continue = Mock()

+ 68 - 2
celery/tests/events/test_events.py

@@ -2,11 +2,13 @@ from __future__ import absolute_import
 
 import socket
 
-from celery.events import Event
-from celery.tests.case import AppCase, Mock
+from celery.events import CLIENT_CLOCK_SKEW, Event
+
+from celery.tests.case import AppCase, Mock, call
 
 
 class MockProducer(object):
+
     raise_on_publish = False
 
     def __init__(self, *args, **kwargs):
@@ -93,6 +95,44 @@ class test_EventDispatcher(AppCase):
 
         eventer.flush()
 
+    def test_send_buffer_group(self):
+        buf_received = [None]
+        producer = MockProducer()
+        producer.connection = self.app.connection()
+        connection = Mock()
+        connection.transport.driver_type = 'amqp'
+        eventer = self.app.events.Dispatcher(
+            connection, enabled=False,
+            buffer_group={'task'}, buffer_limit=2,
+        )
+        eventer.producer = producer
+        eventer.enabled = True
+        eventer._publish = Mock(name='_publish')
+
+        def on_eventer_publish(events, *args, **kwargs):
+            buf_received[0] = list(events)
+        eventer._publish.side_effect = on_eventer_publish
+        self.assertFalse(eventer._group_buffer['task'])
+        eventer.on_send_buffered = Mock(name='on_send_buffered')
+        eventer.send('task-received', uuid=1)
+        prev_buffer = eventer._group_buffer['task']
+        self.assertTrue(eventer._group_buffer['task'])
+        eventer.on_send_buffered.assert_called_with()
+        eventer.send('task-received', uuid=1)
+        self.assertFalse(eventer._group_buffer['task'])
+        eventer._publish.assert_has_calls(
+            call([], eventer.producer, 'task.multi'),
+        )
+        # clear in place
+        self.assertIs(eventer._group_buffer['task'], prev_buffer)
+        self.assertEqual(len(buf_received[0]), 2)
+        eventer.on_send_buffered = None
+        eventer.send('task-received', uuid=1)
+
+    def test_flush_no_groups_no_errors(self):
+        eventer = self.app.events.Dispatcher(Mock())
+        eventer.flush(errors=False, groups=False)
+
     def test_enter_exit(self):
         with self.app.connection() as conn:
             d = self.app.events.Dispatcher(conn)
@@ -174,6 +214,10 @@ class test_EventReceiver(AppCase):
         r._receive(message, object())
         self.assertTrue(got_event[0])
 
+    def test_accept_argument(self):
+        r = self.app.events.Receiver(Mock(), accept={'app/foo'})
+        self.assertEqual(r.accept, {'app/foo'})
+
     def test_catch_all_event(self):
 
         message = {'type': 'world-war'}
@@ -217,6 +261,28 @@ class test_EventReceiver(AppCase):
         self.assertFalse(ts_adjust.called)
         r.adjust_clock.assert_called_with(313)
 
+    def test_event_from_message_clock_from_client(self):
+        r = self.app.events.Receiver(Mock(), node_id='celery.tests')
+        r.clock.value = 302
+        r.adjust_clock = Mock()
+
+        body = {'type': 'task-sent'}
+        r.event_from_message(
+            body, localize=False, adjust_timestamp=Mock(),
+        )
+        self.assertEqual(body['clock'], r.clock.value + CLIENT_CLOCK_SKEW)
+
+    def test_receive_multi(self):
+        r = self.app.events.Receiver(Mock(name='connection'))
+        r.process = Mock(name='process')
+        efm = r.event_from_message = Mock(name='event_from_message')
+
+        def on_efm(*args):
+            return args
+        efm.side_effect = on_efm
+        r._receive([1, 2, 3], Mock())
+        r.process.assert_has_calls([call(1), call(2), call(3)])
+
     def test_itercapture_limit(self):
         connection = self.app.connection()
         channel = connection.channel()
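
The test_send_buffer_group hunk above relies on a common mock idiom: a side_effect callable is attached to the stubbed _publish so the (mutable) event buffer can be snapshotted at call time. A minimal, self-contained sketch of that pattern, shown here with the standard-library unittest.mock (the tests themselves import Mock via celery.tests.case); the publish name is illustrative only:

    from unittest.mock import Mock

    received = []
    publish = Mock(name='publish')
    # side_effect runs on every call, letting us copy the buffer before it is cleared
    publish.side_effect = lambda events, *args, **kwargs: received.append(list(events))

    publish(['task-received', 'task-started'], 'producer', 'task.multi')
    assert received == [['task-received', 'task-started']]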

+ 85 - 4
celery/tests/events/test_state.py

@@ -10,11 +10,12 @@ from itertools import count
 from celery import states
 from celery.events import Event
 from celery.events.state import (
+    HEARTBEAT_EXPIRE_WINDOW,
+    HEARTBEAT_DRIFT_MAX,
     State,
     Worker,
     Task,
-    HEARTBEAT_EXPIRE_WINDOW,
-    HEARTBEAT_DRIFT_MAX,
+    heartbeat_expires,
 )
 from celery.five import range
 from celery.utils import uuid
@@ -104,6 +105,7 @@ class ev_task_states(replay):
                   traceback='line 1 at main', hostname='utest1'),
             Event('task-succeeded', uuid=tid, result='4',
                   runtime=0.1234, hostname='utest1'),
+            Event('foo-bar'),
         ]
 
 
@@ -181,6 +183,12 @@ class test_Worker(AppCase):
             hash(Worker(hostname='foo')), hash(Worker(hostname='bar')),
         )
 
+    def test_heartbeat_expires__Decimal(self):
+        self.assertEqual(
+            heartbeat_expires(Decimal(344313.37), freq=60, expire_window=200),
+            344433.37,
+        )
+
     def test_compatible_with_Decimal(self):
         w = Worker('george@vandelay.com')
         timestamp, local_received = Decimal(_float_to_decimal(time())), time()
@@ -192,6 +200,39 @@ class test_Worker(AppCase):
         })
         self.assertTrue(w.alive)
 
+    def test_eq_ne_other(self):
+        self.assertEqual(Worker('a@b.com'), Worker('a@b.com'))
+        self.assertNotEqual(Worker('a@b.com'), Worker('b@b.com'))
+        self.assertNotEqual(Worker('a@b.com'), object())
+
+    def test_reduce_direct(self):
+        w = Worker('george@vandelay.com')
+        w.event('worker-online', 10.0, 13.0, fields={
+            'hostname': 'george@vandelay.com',
+            'timestamp': 10.0,
+            'local_received': 13.0,
+            'freq': 60,
+        })
+        fun, args = w.__reduce__()
+        w2 = fun(*args)
+        self.assertEqual(w2.hostname, w.hostname)
+        self.assertEqual(w2.pid, w.pid)
+        self.assertEqual(w2.freq, w.freq)
+        self.assertEqual(w2.heartbeats, w.heartbeats)
+        self.assertEqual(w2.clock, w.clock)
+        self.assertEqual(w2.active, w.active)
+        self.assertEqual(w2.processed, w.processed)
+        self.assertEqual(w2.loadavg, w.loadavg)
+        self.assertEqual(w2.sw_ident, w.sw_ident)
+
+    def test_update(self):
+        w = Worker('george@vandelay.com')
+        w.update({'idx': '301'}, foo=1, clock=30, bah='foo')
+        self.assertEqual(w.idx, '301')
+        self.assertEqual(w.foo, 1)
+        self.assertEqual(w.clock, 30)
+        self.assertEqual(w.bah, 'foo')
+
     def test_survives_missing_timestamp(self):
         worker = Worker(hostname='foo')
         worker.event('heartbeat')
@@ -263,6 +304,12 @@ class test_Task(AppCase):
                          sorted(task.info(['args', 'kwargs']).keys()))
         self.assertFalse(list(task.info('foo')))
 
+    def test_reduce_direct(self):
+        task = Task(uuid='uuid', name='tasks.add', args='(2, 2)')
+        fun, args = task.__reduce__()
+        task2 = fun(*args)
+        self.assertEqual(task, task2)
+
     def test_ready(self):
         task = Task(uuid='abcdefg',
                     name='tasks.add')
@@ -341,6 +388,39 @@ class test_State(AppCase):
         self.assertEqual(now[1][0], tC)
         self.assertEqual(now[2][0], tA)
 
+    def test_get_or_create_task(self):
+        state = State()
+        task, created = state.get_or_create_task('id1')
+        self.assertEqual(task.uuid, 'id1')
+        self.assertTrue(created)
+        task2, created2 = state.get_or_create_task('id1')
+        self.assertIs(task2, task)
+        self.assertFalse(created2)
+
+    def test_get_or_create_worker(self):
+        state = State()
+        worker, created = state.get_or_create_worker('george@vandelay.com')
+        self.assertEqual(worker.hostname, 'george@vandelay.com')
+        self.assertTrue(created)
+        worker2, created2 = state.get_or_create_worker('george@vandelay.com')
+        self.assertIs(worker2, worker)
+        self.assertFalse(created2)
+
+    def test_get_or_create_worker__with_defaults(self):
+        state = State()
+        worker, created = state.get_or_create_worker(
+            'george@vandelay.com', pid=30,
+        )
+        self.assertEqual(worker.hostname, 'george@vandelay.com')
+        self.assertEqual(worker.pid, 30)
+        self.assertTrue(created)
+        worker2, created2 = state.get_or_create_worker(
+            'george@vandelay.com', pid=40,
+        )
+        self.assertIs(worker2, worker)
+        self.assertEqual(worker2.pid, 40)
+        self.assertFalse(created2)
+
     def test_worker_online_offline(self):
         r = ev_worker_online_offline(State())
         next(r)
@@ -478,10 +558,11 @@ class test_State(AppCase):
         r.play()
         self.assertEqual(sorted(r.state.task_types()), ['task1', 'task2'])
 
-    def test_tasks_by_timestamp(self):
+    def test_tasks_by_time(self):
         r = ev_snapshot(State())
         r.play()
-        self.assertEqual(len(list(r.state.tasks_by_timestamp())), 20)
+        self.assertEqual(len(list(r.state.tasks_by_time())), 20)
+        self.assertEqual(len(list(r.state.tasks_by_time(reverse=False))), 20)
 
     def test_tasks_by_type(self):
         r = ev_snapshot(State())
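
Several of the new state tests above (test_reduce_direct on both Worker and Task) use the __reduce__ round-trip idiom: unpack the (callable, args) pair and rebuild the object with fun(*args). A minimal sketch of that idiom with a hypothetical Point class, not Celery code:

    class Point(object):
        """Toy object whose __reduce__ mirrors the round-trip tested above."""

        def __init__(self, x, y):
            self.x, self.y = x, y

        def __reduce__(self):
            # pickle-compatible: the callable and the args needed to rebuild it
            return Point, (self.x, self.y)

        def __eq__(self, other):
            return (self.x, self.y) == (other.x, other.y)

    fun, args = Point(1, 2).__reduce__()
    assert fun(*args) == Point(1, 2)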

+ 88 - 0
celery/tests/fixups/test_django.py

@@ -31,6 +31,45 @@ class FixupCase(AppCase):
 class test_DjangoFixup(FixupCase):
     Fixup = DjangoFixup
 
+    def test_setting_default_app(self):
+        from celery.fixups import django
+        prev, django.default_app = django.default_app, None
+        try:
+            app = Mock(name='app')
+            DjangoFixup(app)
+            app.set_default.assert_called_with()
+        finally:
+            django.default_app = prev
+
+    @patch('celery.fixups.django.DjangoWorkerFixup')
+    def test_worker_fixup_property(self, DjangoWorkerFixup):
+        f = DjangoFixup(self.app)
+        f._worker_fixup = None
+        self.assertIs(f.worker_fixup, DjangoWorkerFixup())
+        self.assertIs(f.worker_fixup, DjangoWorkerFixup())
+
+    def test_on_import_modules(self):
+        f = DjangoFixup(self.app)
+        f.worker_fixup = Mock(name='worker_fixup')
+        f.on_import_modules()
+        f.worker_fixup.validate_models.assert_called_with()
+
+    def test_autodiscover_tasks_pre17(self):
+        self.mask_modules('django.apps')
+        f = DjangoFixup(self.app)
+        f._settings = Mock(name='_settings')
+        self.assertIs(f.autodiscover_tasks(), f._settings.INSTALLED_APPS)
+
+    @patch('django.apps.apps', create=True)
+    def test_autodiscover_tasks(self, apps):
+        f = DjangoFixup(self.app)
+        configs = [Mock(name='c1'), Mock(name='c2')]
+        apps.get_app_configs.return_value = configs
+        self.assertEqual(
+            f.autodiscover_tasks(),
+            [c.name for c in configs],
+        )
+
     def test_fixup(self):
         with patch('celery.fixups.django.DjangoFixup') as Fixup:
             with patch.dict(os.environ, DJANGO_SETTINGS_MODULE=''):
@@ -149,6 +188,11 @@ class test_DjangoWorkerFixup(FixupCase):
                             f._db.connection = None
                             f.on_worker_process_init()
 
+                            f.validate_models = Mock(name='validate_models')
+                            self.mock_environ('FORKED_BY_MULTIPROCESSING', '1')
+                            f.on_worker_process_init()
+                            f.validate_models.assert_called_with()
+
     def test_on_task_prerun(self):
         task = Mock()
         with self.fixup_context(self.app) as (f, _, _):
@@ -204,6 +248,13 @@ class test_DjangoWorkerFixup(FixupCase):
                 _close.assert_called_with()
                 self.assertEqual(f._db_recycles, 1)
 
+    def test_close_database__django16(self):
+        with self.fixup_context(self.app) as (f, _, _):
+            f._db.connections = Mock(name='db.connections')
+            f._db.connections.all.side_effect = AttributeError()
+            f._close_database()
+            f._db.close_old_connections.assert_called_with()
+
     def test__close_database(self):
         with self.fixup_context(self.app) as (f, _, _):
             conns = [Mock(), Mock(), Mock()]
@@ -245,6 +296,43 @@ class test_DjangoWorkerFixup(FixupCase):
                 f._settings.DEBUG = True
                 f.on_worker_ready()
 
+    def test_validate_models(self):
+        self.patch('celery.fixups.django.symbol_by_name')
+        self.patch('celery.fixups.django.import_module')
+        f = self.Fixup(self.app)
+        self.mock_modules('django.core.management.validation')
+        f.django_setup = Mock(name='django.setup')
+        from django.core.management.validation import get_validation_errors
+        get_validation_errors.return_value = 0
+        f.validate_models()
+        f.django_setup.assert_called_with()
+        get_validation_errors.return_value = 3
+        with self.assertRaises(RuntimeError):
+            f.validate_models()
+
+        self.mask_modules('django.core.management.validation')
+        f._validate_models_django17 = Mock(name='validate17')
+        f.validate_models()
+        f._validate_models_django17.assert_called_with()
+
+    def test_validate_models_django17(self):
+        self.patch('celery.fixups.django.symbol_by_name')
+        self.patch('celery.fixups.django.import_module')
+        self.mock_modules('django.core.management.base')
+        from django.core.management import base
+        f = self.Fixup(self.app)
+        f._validate_models_django17()
+        base.BaseCommand.assert_called_with()
+        base.BaseCommand().check.assert_called_with()
+
+    def test_django_setup(self):
+        self.patch('celery.fixups.django.symbol_by_name')
+        self.patch('celery.fixups.django.import_module')
+        django, = self.mock_modules('django')
+        f = self.Fixup(self.app)
+        f.django_setup()
+        django.setup.assert_called_with()
+
     def test_mysql_errors(self):
         with patch_modules('MySQLdb'):
             import MySQLdb as mod

+ 144 - 2
celery/tests/tasks/test_canvas.py

@@ -12,10 +12,13 @@ from celery.canvas import (
     chunks,
     _maybe_group,
     maybe_signature,
+    maybe_unroll_group,
 )
 from celery.result import EagerResult
 
-from celery.tests.case import AppCase, ContextMock, Mock
+from celery.tests.case import (
+    AppCase, ContextMock, MagicMock, Mock, depends_on_current_app,
+)
 
 SIG = Signature({'task': 'TASK',
                  'args': ('A1',),
@@ -24,6 +27,18 @@ SIG = Signature({'task': 'TASK',
                  'subtask_type': ''})
 
 
+class test_maybe_unroll_group(AppCase):
+
+    def test_when_no_len_and_no_length_hint(self):
+        g = MagicMock(name='group')
+        g.tasks.__len__.side_effect = TypeError()
+        g.tasks.__length_hint__ = Mock()
+        g.tasks.__length_hint__.return_value = 0
+        self.assertIs(maybe_unroll_group(g), g)
+        g.tasks.__length_hint__.side_effect = AttributeError()
+        self.assertIs(maybe_unroll_group(g), g)
+
+
 class CanvasCase(AppCase):
 
     def setup(self):
@@ -60,6 +75,12 @@ class test_Signature(CanvasCase):
         self.assertEqual(SIG.options, {'task_id': 'TASK_ID'})
         self.assertEqual(SIG.subtask_type, '')
 
+    def test_call(self):
+        x = Signature('foo', (1, 2), {'arg1': 33}, app=self.app)
+        x.type = Mock(name='type')
+        x(3, 4, arg2=66)
+        x.type.assert_called_with(3, 4, 1, 2, arg1=33, arg2=66)
+
     def test_link_on_scalar(self):
         x = Signature('TASK', link=Signature('B'))
         self.assertTrue(x.options['link'])
@@ -68,6 +89,16 @@ class test_Signature(CanvasCase):
         self.assertIn(Signature('B'), x.options['link'])
         self.assertIn(Signature('C'), x.options['link'])
 
+    def test_json(self):
+        x = Signature('TASK', link=Signature('B', app=self.app), app=self.app)
+        self.assertDictEqual(x.__json__(), dict(x))
+
+    @depends_on_current_app
+    def test_reduce(self):
+        x = Signature('TASK', (2, 4), app=self.app)
+        fun, args = x.__reduce__()
+        self.assertEqual(fun(*args), x)
+
     def test_replace(self):
         x = Signature('TASK', ('A'), {})
         self.assertTupleEqual(x.replace(args=('B',)).args, ('B',))
@@ -255,6 +286,35 @@ class test_chain(CanvasCase):
         self.assertEqual(tasks[-4].parent_id, tasks[-3].id)
         self.assertEqual(tasks[-4].root_id, 'root')
 
+    def test_splices_chains(self):
+        c = chain(
+            self.add.s(5, 5),
+            chain(self.add.s(6), self.add.s(7), self.add.s(8), app=self.app),
+            app=self.app,
+        )
+        c.freeze()
+        tasks, _ = c._frozen
+        self.assertEqual(len(tasks), 4)
+
+    def test_from_dict_no_tasks(self):
+        self.assertTrue(chain.from_dict(
+            dict(chain(app=self.app)), app=self.app))
+
+    @depends_on_current_app
+    def test_app_falls_back_to_default(self):
+        from celery._state import current_app
+        self.assertIs(chain().app, current_app)
+
+    def test_handles_dicts(self):
+        c = chain(
+            self.add.s(5, 5), dict(self.add.s(8)), app=self.app,
+        )
+        c.freeze()
+        tasks, _ = c._frozen
+        for task in tasks:
+            self.assertIsInstance(task, Signature)
+            self.assertIs(task.app, self.app)
+
     def test_group_to_chord(self):
         c = (
             self.add.s(5) |
@@ -316,7 +376,7 @@ class test_chain(CanvasCase):
         def s(*args, **kwargs):
             return static(self.add, args, kwargs, type=self.add, app=self.app)
 
-        c = s(2, 2) | s(4, 4) | s(8, 8)
+        c = s(2, 2) | s(4) | s(8)
         r1 = c.apply_async(task_id='some_id')
         self.assertEqual(r1.id, 'some_id')
 
@@ -423,6 +483,11 @@ class test_group(CanvasCase):
         self.assertIsInstance(signature(x), group)
         self.assertIsInstance(signature(dict(x)), group)
 
+    def test_group_with_group_argument(self):
+        g1 = group(self.add.s(2, 2), self.add.s(4, 4), app=self.app)
+        g2 = group(g1, app=self.app)
+        self.assertIs(g2.tasks, g1.tasks)
+
     def test_maybe_group_sig(self):
         self.assertListEqual(
             _maybe_group(self.add.s(2, 2), self.app), [self.add.s(2, 2)],
@@ -437,6 +502,35 @@ class test_group(CanvasCase):
         x = group([self.add.s(4, 4), self.add.s(8, 8)])
         x.apply_async()
 
+    def test_prepare_with_dict(self):
+        x = group([self.add.s(4, 4), dict(self.add.s(8, 8))], app=self.app)
+        x.apply_async()
+
+    def test_group_in_group(self):
+        g1 = group(self.add.s(2, 2), self.add.s(4, 4), app=self.app)
+        g2 = group(self.add.s(8, 8), g1, self.add.s(16, 16), app=self.app)
+        g2.apply_async()
+
+    def test_set_immutable(self):
+        g1 = group(Mock(name='t1'), Mock(name='t2'), app=self.app)
+        g1.set_immutable(True)
+        for task in g1.tasks:
+            task.set_immutable.assert_called_with(True)
+
+    def test_link(self):
+        g1 = group(Mock(name='t1'), Mock(name='t2'), app=self.app)
+        sig = Mock(name='sig')
+        g1.link(sig)
+        g1.tasks[0].link.assert_called_with(sig.clone().set(immutable=True))
+
+    def test_link_error(self):
+        g1 = group(Mock(name='t1'), Mock(name='t2'), app=self.app)
+        sig = Mock(name='sig')
+        g1.link_error(sig)
+        g1.tasks[0].link_error.assert_called_with(
+            sig.clone().set(immutable=True),
+        )
+
     def test_apply_empty(self):
         x = group(app=self.app)
         x.apply()
@@ -500,6 +594,41 @@ class test_chord(CanvasCase):
         z = y.clone()
         self.assertIsNone(z.kwargs.get('body'))
 
+    def test_argument_is_group(self):
+        x = chord(group(self.add.s(2, 2), self.add.s(4, 4), app=self.app))
+        self.assertTrue(x.tasks)
+
+    def test_set_parent_id(self):
+        x = chord(group(self.add.s(2, 2)))
+        x.tasks = [self.add.s(2, 2)]
+        x.set_parent_id('pid')
+
+    def test_app_when_app(self):
+        app = Mock(name='app')
+        x = chord([self.add.s(4, 4)], app=app)
+        self.assertIs(x.app, app)
+
+    def test_app_when_app_in_task(self):
+        t1 = Mock(name='t1')
+        t2 = Mock(name='t2')
+        x = chord([t1, self.add.s(4, 4)])
+        self.assertIs(x.app, x.tasks[0].app)
+        t1.app = None
+        x = chord([t1], body=t2)
+        self.assertIs(x.app, t2._app)
+
+    @depends_on_current_app
+    def test_app_fallback_to_current(self):
+        from celery._state import current_app
+        t1 = Mock(name='t1')
+        t1.app = t1._app = None
+        x = chord([t1], body=t1)
+        self.assertIs(x.app, current_app)
+
+    def test_set_immutable(self):
+        x = chord([Mock(name='t1'), Mock(name='t2')], app=self.app)
+        x.set_immutable(True)
+
     def test_links_to_body(self):
         x = chord([self.add.s(2, 2), self.add.s(4, 4)], body=self.mul.s(4))
         x.link(self.div.s(2))
@@ -519,6 +648,12 @@ class test_chord(CanvasCase):
         x.kwargs['body'] = None
         self.assertIn('without body', repr(x))
 
+    def test_freeze_tasks_is_not_group(self):
+        x = chord([self.add.s(2, 2)], body=self.add.s(), app=self.app)
+        x.freeze()
+        x.tasks = [self.add.s(2, 2)]
+        x.freeze()
+
 
 class test_maybe_signature(CanvasCase):
 
@@ -530,6 +665,13 @@ class test_maybe_signature(CanvasCase):
             maybe_signature(dict(self.add.s()), app=self.app), Signature,
         )
 
+    def test_is_list(self):
+        sigs = [dict(self.add.s(2, 2)), dict(self.add.s(4, 4))]
+        sigs = maybe_signature(sigs, app=self.app)
+        for sig in sigs:
+            self.assertIsInstance(sig, Signature)
+            self.assertIs(sig.app, self.app)
+
     def test_when_sig(self):
         s = self.add.s()
         self.assertIs(maybe_signature(s, app=self.app), s)
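
test_when_no_len_and_no_length_hint above probes the __length_hint__ fallback that maybe_unroll_group depends on when a group's tasks expose no __len__. A short sketch of the underlying protocol, assuming Python 3.4+ where operator.length_hint is available; the NoLen class is illustrative only:

    import operator

    class NoLen(object):
        # no __len__, only a hint -- mirrors the mocked group.tasks above
        def __length_hint__(self):
            return 0

    assert operator.length_hint(NoLen(), -1) == 0     # the hint is used
    assert operator.length_hint(object(), -1) == -1   # no __len__, no hint: default wins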

+ 122 - 16
celery/tests/tasks/test_result.py

@@ -3,18 +3,23 @@ from __future__ import absolute_import
 from contextlib import contextmanager
 
 from celery import states
-from celery.exceptions import IncompleteStream, TimeoutError
+from celery.exceptions import (
+    ImproperlyConfigured, IncompleteStream, TimeoutError,
+)
 from celery.five import range
 from celery.result import (
     AsyncResult,
     EagerResult,
+    ResultSet,
     result_from_tuple,
     assert_will_not_block,
 )
 from celery.utils import uuid
 from celery.utils.serialization import pickle
 
-from celery.tests.case import AppCase, Mock, depends_on_current_app, patch
+from celery.tests.case import (
+    AppCase, Mock, call, depends_on_current_app, patch,
+)
 
 
 def mock_task(name, state, result):
@@ -66,12 +71,22 @@ class test_AsyncResult(AppCase):
         task_join_will_block.return_value = False
         assert_will_not_block()
 
+    def test_without_id(self):
+        with self.assertRaises(ValueError):
+            AsyncResult(None, app=self.app)
+
     def test_compat_properties(self):
         x = self.app.AsyncResult('1')
         self.assertEqual(x.task_id, x.id)
         x.task_id = '2'
         self.assertEqual(x.id, '2')
 
+    @depends_on_current_app
+    def test_reduce_direct(self):
+        x = AsyncResult('1', app=self.app)
+        fun, args = x.__reduce__()
+        self.assertEqual(fun(*args), x)
+
     def test_children(self):
         x = self.app.AsyncResult('1')
         children = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
@@ -167,6 +182,15 @@ class test_AsyncResult(AppCase):
         a2 = self.app.AsyncResult('uuid')
         self.assertEqual(pickle.loads(pickle.dumps(a2)).id, 'uuid')
 
+    def test_maybe_set_cache_empty(self):
+        self.app.AsyncResult('uuid')._maybe_set_cache(None)
+
+    def test_set_cache__children(self):
+        r1 = self.app.AsyncResult('id1')
+        r2 = self.app.AsyncResult('id2')
+        r1._set_cache({'children': [r2.as_tuple()]})
+        self.assertIn(r2, r1.children)
+
     def test_successful(self):
         ok_res = self.app.AsyncResult(self.task1['id'])
         nok_res = self.app.AsyncResult(self.task3['id'])
@@ -224,13 +248,22 @@ class test_AsyncResult(AppCase):
         pending_res = self.app.AsyncResult(uuid())
         self.assertFalse(pending_res.traceback)
 
+    def test_get__backend_gives_None(self):
+        res = self.app.AsyncResult(self.task1['id'])
+        res.backend.wait_for = Mock(name='wait_for')
+        res.backend.wait_for.return_value = None
+        self.assertIsNone(res.get())
+
     def test_get(self):
         ok_res = self.app.AsyncResult(self.task1['id'])
         ok2_res = self.app.AsyncResult(self.task2['id'])
         nok_res = self.app.AsyncResult(self.task3['id'])
         nok2_res = self.app.AsyncResult(self.task4['id'])
 
-        self.assertEqual(ok_res.get(), 'the')
+        callback = Mock(name='callback')
+
+        self.assertEqual(ok_res.get(callback=callback), 'the')
+        callback.assert_called_with(ok_res.id, 'the')
         self.assertEqual(ok2_res.get(), 'quick')
         with self.assertRaises(KeyError):
             nok_res.get()
@@ -238,6 +271,21 @@ class test_AsyncResult(AppCase):
         self.assertIsInstance(nok2_res.result, KeyError)
         self.assertEqual(ok_res.info, 'the')
 
+    def test_eq_ne(self):
+        r1 = self.app.AsyncResult(self.task1['id'])
+        r2 = self.app.AsyncResult(self.task1['id'])
+        r3 = self.app.AsyncResult(self.task2['id'])
+        self.assertEqual(r1, r2)
+        self.assertNotEqual(r1, r3)
+        self.assertEqual(r1, r2.id)
+        self.assertNotEqual(r1, r3.id)
+
+    @depends_on_current_app
+    def test_reduce_restore(self):
+        r1 = self.app.AsyncResult(self.task1['id'])
+        fun, args = r1.__reduce__()
+        self.assertEqual(fun(*args), r1)
+
     def test_get_timeout(self):
         res = self.app.AsyncResult(self.task4['id'])  # has RETRY state
         with self.assertRaises(TimeoutError):
@@ -288,6 +336,29 @@ class test_ResultSet(AppCase):
         x.get()
         self.assertTrue(x.join_native.called)
 
+    def test_eq_ne(self):
+        g1 = self.app.ResultSet(
+            self.app.AsyncResult('id1'),
+            self.app.AsyncResult('id2'),
+        )
+        g2 = self.app.ResultSet(
+            self.app.AsyncResult('id1'),
+            self.app.AsyncResult('id2'),
+        )
+        g3 = self.app.ResultSet(
+            self.app.AsyncResult('id3'),
+            self.app.AsyncResult('id1'),
+        )
+        self.assertEqual(g1, g2)
+        self.assertNotEqual(g1, g3)
+        self.assertNotEqual(g1, object())
+
+    def test_takes_app_from_first_task(self):
+        x = ResultSet([self.app.AsyncResult('id1')])
+        self.assertIs(x.app, x.results[0].app)
+        x.app = self.app
+        self.assertIs(x.app, self.app)
+
     def test_get_empty(self):
         x = self.app.ResultSet([])
         self.assertIsNone(x.supports_native_join)
@@ -432,6 +503,24 @@ class test_GroupResult(AppCase):
         ts2 = self.app.GroupResult(uuid(), [self.app.AsyncResult(uuid())])
         self.assertEqual(pickle.loads(pickle.dumps(ts2)), ts2)
 
+    @depends_on_current_app
+    def test_reduce(self):
+        ts = self.app.GroupResult(uuid(), [self.app.AsyncResult(uuid())])
+        fun, args = ts.__reduce__()
+        ts2 = fun(*args)
+        self.assertEqual(ts2.id, ts.id)
+        self.assertEqual(ts, ts2)
+
+    def test_eq_ne(self):
+        ts = self.app.GroupResult(uuid(), [self.app.AsyncResult(uuid())])
+        ts2 = self.app.GroupResult(ts.id, ts.results)
+        ts3 = self.app.GroupResult(uuid(), [self.app.AsyncResult(uuid())])
+        ts4 = self.app.GroupResult(ts.id, [self.app.AsyncResult(uuid())])
+        self.assertEqual(ts, ts2)
+        self.assertNotEqual(ts, ts3)
+        self.assertNotEqual(ts, ts4)
+        self.assertNotEqual(ts, object())
+
     def test_len(self):
         self.assertEqual(len(self.ts), self.size)
 
@@ -439,7 +528,7 @@ class test_GroupResult(AppCase):
         self.assertFalse(self.ts == 1)
 
     @depends_on_current_app
-    def test_reduce(self):
+    def test_pickleable(self):
         self.assertTrue(pickle.loads(pickle.dumps(self.ts)))
 
     def test_iterate_raises(self):
@@ -471,8 +560,8 @@ class test_GroupResult(AppCase):
         ts.save()
         with self.assertRaises(AttributeError):
             ts.save(backend=object())
-        self.assertEqual(self.app.GroupResult.restore(ts.id).subtasks,
-                         ts.subtasks)
+        self.assertEqual(self.app.GroupResult.restore(ts.id).results,
+                         ts.results)
         ts.delete()
         self.assertIsNone(self.app.GroupResult.restore(ts.id))
         with self.assertRaises(AttributeError):
@@ -480,13 +569,18 @@ class test_GroupResult(AppCase):
 
     def test_join_native(self):
         backend = SimpleBackend()
-        subtasks = [self.app.AsyncResult(uuid(), backend=backend)
-                    for i in range(10)]
-        ts = self.app.GroupResult(uuid(), subtasks)
+        results = [self.app.AsyncResult(uuid(), backend=backend)
+                   for i in range(10)]
+        ts = self.app.GroupResult(uuid(), results)
         ts.app.backend = backend
-        backend.ids = [subtask.id for subtask in subtasks]
+        backend.ids = [result.id for result in results]
         res = ts.join_native()
         self.assertEqual(res, list(range(10)))
+        callback = Mock(name='callback')
+        self.assertFalse(ts.join_native(callback=callback))
+        callback.assert_has_calls([
+            call(r.id, i) for i, r in enumerate(ts.results)
+        ])
 
     def test_join_native_raises(self):
         ts = self.app.GroupResult(uuid(), [self.app.AsyncResult(uuid())])
@@ -518,11 +612,11 @@ class test_GroupResult(AppCase):
 
     def test_iter_native(self):
         backend = SimpleBackend()
-        subtasks = [self.app.AsyncResult(uuid(), backend=backend)
-                    for i in range(10)]
-        ts = self.app.GroupResult(uuid(), subtasks)
+        results = [self.app.AsyncResult(uuid(), backend=backend)
+                   for i in range(10)]
+        ts = self.app.GroupResult(uuid(), results)
         ts.app.backend = backend
-        backend.ids = [subtask.id for subtask in subtasks]
+        backend.ids = [result.id for result in results]
         self.assertEqual(len(list(ts.iter_native())), 10)
 
     def test_iterate_yields(self):
@@ -555,6 +649,9 @@ class test_GroupResult(AppCase):
         ar4.get = Mock()
         ts2 = self.app.GroupResult(uuid(), [ar4])
         self.assertTrue(ts2.join(timeout=0.1))
+        callback = Mock(name='callback')
+        self.assertFalse(ts2.join(timeout=0.1, callback=callback))
+        callback.assert_called_with(ar4.id, ar4.get())
 
     def test_iter_native_when_empty_group(self):
         ts = self.app.GroupResult(uuid(), [])
@@ -579,6 +676,15 @@ class test_GroupResult(AppCase):
     def test_failed(self):
         self.assertFalse(self.ts.failed())
 
+    def test_maybe_reraise(self):
+        self.ts.results = [Mock(name='r1')]
+        self.ts.maybe_reraise()
+        self.ts.results[0].maybe_reraise.assert_called_with()
+
+    def test_join__on_message(self):
+        with self.assertRaises(ImproperlyConfigured):
+            self.ts.join(on_message=Mock())
+
     def test_waiting(self):
         self.assertFalse(self.ts.waiting())
 
@@ -603,11 +709,11 @@ class test_failed_AsyncResult(test_GroupResult):
     def setup(self):
         self.app.conf.result_serializer = 'pickle'
         self.size = 11
-        subtasks = make_mock_group(self.app, 10)
+        results = make_mock_group(self.app, 10)
         failed = mock_task('ts11', states.FAILURE, KeyError('Baz'))
         save_result(self.app, failed)
         failed_res = self.app.AsyncResult(failed['id'])
-        self.ts = self.app.GroupResult(uuid(), subtasks + [failed_res])
+        self.ts = self.app.GroupResult(uuid(), results + [failed_res])
 
     def test_completed_count(self):
         self.assertEqual(self.ts.completed_count(), len(self.ts) - 1)

+ 34 - 0
celery/tests/tasks/test_trace.py

@@ -16,6 +16,8 @@ from celery.app.trace import (
     log_policy_expected,
     log_policy_unexpected,
     trace_task,
+    _trace_task_ret,
+    _fast_trace_task,
     setup_worker_optimizations,
     reset_worker_optimizations,
 )
@@ -178,6 +180,11 @@ class test_trace(TraceCase):
         retval, info = self.trace(rejecting, (), {})
         self.assertEqual(info.state, states.REJECTED)
 
+    def test_backend_cleanup_raises(self):
+        self.add.backend.process_cleanup = Mock()
+        self.add.backend.process_cleanup.side_effect = RuntimeError()
+        self.trace(self.add, (2, 2), {})
+
     @patch('celery.canvas.maybe_signature')
     def test_callbacks__scalar(self, maybe_signature):
         sig = Mock(name='sig')
@@ -188,6 +195,18 @@ class test_trace(TraceCase):
             (4,), parent_id='id-1', root_id='root',
         )
 
+    @patch('celery.canvas.maybe_signature')
+    def test_chain_proto2(self, maybe_signature):
+        sig = Mock(name='sig')
+        sig2 = Mock(name='sig2')
+        request = {'chain': [sig2, sig], 'root_id': 'root'}
+        maybe_signature.return_value = sig
+        retval, _ = self.trace(self.add, (2, 2), {}, request=request)
+        sig.apply_async.assert_called_with(
+            (4, ), parent_id='id-1', root_id='root',
+            chain=[sig2],
+        )
+
     @patch('celery.canvas.maybe_signature')
     def test_callbacks__EncodeError(self, maybe_signature):
         sig = Mock(name='sig')
@@ -253,6 +272,21 @@ class test_trace(TraceCase):
         self.assertEqual(info.state, states.FAILURE)
         self.assertIs(info.retval, exc)
 
+    def test_trace_task_ret__no_content_type(self):
+        _trace_task_ret(
+            self.add.name, 'id1', {}, ((2, 2), {}), None, None,
+            app=self.app,
+        )
+
+    def test_fast_trace_task__no_content_type(self):
+        self.app.tasks[self.add.name].__trace__ = build_tracer(
+            self.add.name, self.add, app=self.app,
+        )
+        _fast_trace_task(
+            self.add.name, 'id1', {}, ((2, 2), {}), None, None,
+            app=self.app, _loc=[self.app.tasks, {}, 'hostname']
+        )
+
     def test_trace_exception_propagate(self):
         with self.assertRaises(KeyError):
             self.trace(self.raises, (KeyError('foo'),), {}, propagate=True)

+ 106 - 1
celery/tests/utils/test_functional.py

@@ -3,21 +3,37 @@ from __future__ import absolute_import
 import pickle
 import sys
 
+from itertools import count
+
 from kombu.utils.functional import lazy
 
 from celery.five import THREAD_TIMEOUT_MAX, items, range, nextfun
 from celery.utils.functional import (
+    DummyContext,
     LRUCache,
+    head_from_fun,
     firstmethod,
     first,
+    maybe_list,
+    memoize,
     mlazy,
     padlist,
-    maybe_list,
+    regen,
 )
 
 from celery.tests.case import Case, SkipTest
 
 
+class test_DummyContext(Case):
+
+    def test_context(self):
+        with DummyContext():
+            pass
+        with self.assertRaises(KeyError):
+            with DummyContext():
+                raise KeyError()
+
+
 class test_LRUCache(Case):
 
     def test_expires(self):
@@ -176,6 +192,24 @@ class test_utils(Case):
         self.assertIsNone(maybe_list(None))
 
 
+class test_memoize(Case):
+
+    def test_memoize(self):
+        counter = count(1)
+
+        @memoize(maxsize=2)
+        def x(i):
+            return next(counter)
+
+        self.assertEqual(x(1), 1)
+        self.assertEqual(x(1), 1)
+        self.assertEqual(x(2), 2)
+        self.assertEqual(x(3), 3)
+        self.assertEqual(x(1), 4)
+        x.clear()
+        self.assertEqual(x(3), 5)
+
+
 class test_mlazy(Case):
 
     def test_is_memoized(self):
@@ -186,3 +220,74 @@ class test_mlazy(Case):
         self.assertTrue(p.evaluated)
         self.assertEqual(p(), 20)
         self.assertEqual(repr(p), '20')
+
+
+class test_regen(Case):
+
+    def test_regen_list(self):
+        l = [1, 2]
+        r = regen(iter(l))
+        self.assertIs(regen(l), l)
+        self.assertEqual(r, l)
+        self.assertEqual(r, l)
+        self.assertEqual(r.__length_hint__(), 0)
+
+        fun, args = r.__reduce__()
+        self.assertEqual(fun(*args), l)
+
+    def test_regen_gen(self):
+        g = regen(iter(list(range(10))))
+        self.assertEqual(g[7], 7)
+        self.assertEqual(g[6], 6)
+        self.assertEqual(g[5], 5)
+        self.assertEqual(g[4], 4)
+        self.assertEqual(g[3], 3)
+        self.assertEqual(g[2], 2)
+        self.assertEqual(g[1], 1)
+        self.assertEqual(g[0], 0)
+        self.assertEqual(g.data, list(range(10)))
+        self.assertEqual(g[8], 8)
+        self.assertEqual(g[0], 0)
+        g = regen(iter(list(range(10))))
+        self.assertEqual(g[0], 0)
+        self.assertEqual(g[1], 1)
+        self.assertEqual(g.data, list(range(10)))
+        g = regen(iter([1]))
+        self.assertEqual(g[0], 1)
+        with self.assertRaises(IndexError):
+            g[1]
+        self.assertEqual(g.data, [1])
+
+        g = regen(iter(list(range(10))))
+        self.assertEqual(g[-1], 9)
+        self.assertEqual(g[-2], 8)
+        self.assertEqual(g[-3], 7)
+        self.assertEqual(g[-4], 6)
+        self.assertEqual(g[-5], 5)
+        self.assertEqual(g[5], 5)
+        self.assertEqual(g.data, list(range(10)))
+
+        self.assertListEqual(list(iter(g)), list(range(10)))
+
+
+class test_head_from_fun(Case):
+
+    def test_from_cls(self):
+        class X(object):
+            def __call__(x, y, kwarg=1):
+                pass
+
+        g = head_from_fun(X())
+        with self.assertRaises(TypeError):
+            g(1)
+        g(1, 2)
+        g(1, 2, kwarg=3)
+
+    def test_from_fun(self):
+        def f(x, y, kwarg=1):
+            pass
+        g = head_from_fun(f)
+        with self.assertRaises(TypeError):
+            g(1)
+        g(1, 2)
+        g(1, 2, kwarg=3)
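
The regen tests above exercise a wrapper that consumes a generator lazily while still allowing the sequence to be iterated again later. A much-reduced sketch of the caching idea (not Celery's regen, which additionally supports indexing, __length_hint__ and __reduce__):

    class Replay(object):
        """Cache items as they are pulled so the iterable can be replayed."""

        def __init__(self, it):
            self._it, self._cache = iter(it), []

        def __iter__(self):
            for item in self._cache:      # serve what has been seen already
                yield item
            for item in self._it:         # then keep draining the source
                self._cache.append(item)
                yield item

    r = Replay(iter(range(3)))
    assert list(r) == [0, 1, 2]
    assert list(r) == [0, 1, 2]   # second pass is served from the cache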

+ 1 - 0
celery/tests/utils/test_imports.py

@@ -19,6 +19,7 @@ class test_import_utils(Case):
         imp.return_value = None
         with self.assertRaises(NotAPackage):
             find_module('foo.bar.baz', imp=imp)
+        self.assertTrue(find_module('celery.worker.request'))
 
     def test_qualname(self):
         Class = type('Fox', (object,), {'__module__': 'quick.brown'})

+ 6 - 0
celery/tests/utils/test_local.py

@@ -31,6 +31,12 @@ class test_Proxy(Case):
         self.assertEqual(Proxy.__module__, 'celery.local')
         self.assertIsInstance(Proxy.__doc__, str)
 
+    def test_doc(self):
+        def real():
+            pass
+        x = Proxy(real, __doc__='foo')
+        self.assertEqual(x.__doc__, 'foo')
+
     def test_name(self):
 
         def real():

+ 112 - 6
celery/tests/utils/test_platforms.py

@@ -12,7 +12,9 @@ from celery.five import open_fqdn
 from celery.platforms import (
     get_fdmax,
     ignore_errno,
+    check_privileges,
     set_process_title,
+    set_mp_process_title,
     signals,
     maybe_drop_privileges,
     setuid,
@@ -61,9 +63,14 @@ class test_fd_by_path(Case):
 
     def test_finds(self):
         test_file = tempfile.NamedTemporaryFile()
-        keep = fd_by_path([test_file.name])
-        self.assertEqual(keep, [test_file.file.fileno()])
-        test_file.close()
+        try:
+            keep = fd_by_path([test_file.name])
+            self.assertEqual(keep, [test_file.file.fileno()])
+            with patch('os.open') as _open:
+                _open.side_effect = OSError()
+                self.assertFalse(fd_by_path([test_file.name]))
+        finally:
+            test_file.close()
 
 
 class test_close_open_fds(Case):
@@ -99,13 +106,27 @@ class test_ignore_errno(Case):
 
 class test_set_process_title(Case):
 
-    def when_no_setps(self):
-        prev = platforms._setproctitle = platforms._setproctitle, None
+    def test_no_setps(self):
+        prev, platforms._setproctitle = platforms._setproctitle, None
         try:
             set_process_title('foo')
         finally:
             platforms._setproctitle = prev
 
+    @patch('celery.platforms.set_process_title')
+    @patch('celery.platforms.current_process')
+    def test_mp_no_hostname(self, current_process, set_process_title):
+        current_process().name = 'Foo'
+        set_mp_process_title('foo', info='hello')
+        set_process_title.assert_called_with('foo:Foo', info='hello')
+
+    @patch('celery.platforms.set_process_title')
+    @patch('celery.platforms.current_process')
+    def test_mp_hostname(self, current_process, set_process_title):
+        current_process().name = 'Foo'
+        set_mp_process_title('foo', hostname='a@q.com', info='hello')
+        set_process_title.assert_called_with('foo: a@q.com:Foo', info='hello')
+
 
 class test_Signals(Case):
 
@@ -146,6 +167,11 @@ class test_Signals(Case):
         signals.ignore('SIGTERM')
         set.assert_called_with(signals.signum('TERM'), signals.ignored)
 
+    @patch('signal.signal')
+    def test_reset(self, set):
+        signals.reset('SIGINT')
+        set.assert_called_with(signals.signum('INT'), signals.default)
+
     @patch('signal.signal')
     def test_setitem(self, set):
         def handle(*args):
@@ -180,13 +206,27 @@ if not platforms.IS_WINDOWS:
 
     class test_maybe_drop_privileges(Case):
 
+        def test_on_windows(self):
+            prev, sys.platform = sys.platform, 'win32'
+            try:
+                maybe_drop_privileges()
+            finally:
+                sys.platform = prev
+
+        @patch('os.getegid')
+        @patch('os.getgid')
+        @patch('os.geteuid')
+        @patch('os.getuid')
         @patch('celery.platforms.parse_uid')
         @patch('pwd.getpwuid')
         @patch('celery.platforms.setgid')
         @patch('celery.platforms.setuid')
         @patch('celery.platforms.initgroups')
         def test_with_uid(self, initgroups, setuid, setgid,
-                          getpwuid, parse_uid):
+                          getpwuid, parse_uid, getuid, geteuid,
+                          getgid, getegid):
+            geteuid.return_value = 10
+            getuid.return_value = 10
 
             class pw_struct(object):
                 pw_gid = 50001
@@ -204,6 +244,40 @@ if not platforms.IS_WINDOWS:
             initgroups.assert_called_with(5001, 50001)
             setuid.assert_has_calls([call(5001), call(0)])
 
+            setuid.side_effect = raise_on_second_call
+
+            def to_root_on_second_call(mock, first):
+                return_value = [first]
+
+                def on_first_call(*args, **kwargs):
+                    ret, return_value[0] = return_value[0], 0
+                    return ret
+                mock.side_effect = on_first_call
+            to_root_on_second_call(geteuid, 10)
+            to_root_on_second_call(getuid, 10)
+            with self.assertRaises(AssertionError):
+                maybe_drop_privileges(uid='user')
+
+            getuid.return_value = getuid.side_effect = None
+            geteuid.return_value = geteuid.side_effect = None
+            getegid.return_value = 0
+            getgid.return_value = 0
+            setuid.side_effect = raise_on_second_call
+            with self.assertRaises(AssertionError):
+                maybe_drop_privileges(gid='group')
+
+            getuid.reset_mock()
+            geteuid.reset_mock()
+            setuid.reset_mock()
+            getuid.side_effect = geteuid.side_effect = None
+
+            def raise_on_second_call(*args, **kwargs):
+                setuid.side_effect = OSError()
+                setuid.side_effect.errno = errno.ENOENT
+            setuid.side_effect = raise_on_second_call
+            with self.assertRaises(OSError):
+                maybe_drop_privileges(uid='user')
+
         @patch('celery.platforms.parse_uid')
         @patch('celery.platforms.parse_gid')
         @patch('celery.platforms.setgid')
@@ -421,6 +495,20 @@ if not platforms.IS_WINDOWS:
                 pass
             x.after_chdir.assert_called_with()
 
+            x = DaemonContext(workdir='/opt/workdir', umask="0755")
+            self.assertEqual(x.umask, 493)
+            x = DaemonContext(workdir='/opt/workdir', umask="493")
+            self.assertEqual(x.umask, 493)
+
+            x.redirect_to_null(None)
+
+            with patch('celery.platforms.mputil') as mputil:
+                x = DaemonContext(after_forkers=True)
+                x.open()
+                mputil._run_after_forkers.assert_called_with()
+                x = DaemonContext(after_forkers=False)
+                x.open()
+
     class test_Pidfile(Case):
 
         @patch('celery.platforms.Pidfile')
@@ -711,3 +799,21 @@ if not platforms.IS_WINDOWS:
             with self.assertRaises(OSError):
                 setgroups(list(range(400)))
             getgroups.assert_called_with()
+
+
+class test_check_privileges(Case):
+
+    def test_suspicious(self):
+        class Obj(object):
+            fchown = 13
+        prev, platforms.os = platforms.os, Obj()
+        try:
+            with self.assertRaises(AssertionError):
+                check_privileges({'pickle'})
+        finally:
+            platforms.os = prev
+        prev, platforms.os = platforms.os, object()
+        try:
+            check_privileges({'pickle'})
+        finally:
+            platforms.os = prev

+ 9 - 0
celery/tests/utils/test_saferepr.py

@@ -148,6 +148,15 @@ class test_saferepr(Case):
             saferepr(D_D_TEXT, 100).endswith("...', ...}}")
         )
 
+    def test_maxlevels(self):
+        saferepr(D_ALL, maxlevels=1)
+
+    def test_recursion(self):
+        d = {1: 2, 3: {4: 5}}
+        d[3][6] = d
+        res = saferepr(d)
+        self.assertIn('Recursion on', res)
+
     def test_same_as_repr(self):
         # Simple objects, small containers and classes that overwrite __repr__
         # For those the result should be the same as repr().

+ 17 - 1
celery/tests/utils/test_timer2.py

@@ -5,7 +5,7 @@ import time
 
 import celery.utils.timer2 as timer2
 
-from celery.tests.case import Case, Mock, patch
+from celery.tests.case import Case, Mock, patch, call
 from kombu.tests.case import redirect_stdouts
 
 
@@ -98,6 +98,11 @@ class test_Timer(Case):
         t.start = Mock()
         t.ensure_started()
         self.assertFalse(t.start.called)
+        t.running = False
+        t.on_start = Mock()
+        t.ensure_started()
+        t.on_start.assert_called_with(t)
+        t.start.assert_called_with()
 
     def test_call_repeatedly(self):
         t = timer2.Timer()
@@ -136,6 +141,17 @@ class test_Timer(Case):
         t.schedule.apply_entry(fun)
         self.assertTrue(logger.error.called)
 
+    @patch('celery.utils.timer2.sleep')
+    def test_on_tick(self, sleep):
+        on_tick = Mock(name='on_tick')
+        t = timer2.Timer(on_tick=on_tick)
+        ne = t._next_entry = Mock(name='_next_entry')
+        ne.return_value = 3.33
+        self.on_nth_call_do(ne, t._is_shutdown.set, 3)
+        t.run()
+        sleep.assert_called_with(3.33)
+        on_tick.assert_has_calls([call(3.33), call(3.33), call(3.33)])
+
     @redirect_stdouts
     def test_apply_entry_error_not_handled(self, stdout, stderr):
         t = timer2.Timer()
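
The on_tick assertion above uses assert_has_calls, which verifies that a sequence of calls appears, in order, among everything the mock recorded. A standalone sketch of that API (standard mock behaviour, nothing Celery-specific):

    from unittest.mock import Mock, call

    tick = Mock(name='on_tick')
    for delay in (3.33, 3.33, 3.33):
        tick(delay)

    # passes: the three calls occur consecutively in tick.mock_calls
    tick.assert_has_calls([call(3.33), call(3.33), call(3.33)])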

+ 2 - 2
celery/tests/utils/test_timeutils.py

@@ -248,6 +248,6 @@ class test_utcoffset(Case):
     def test_utcoffset(self):
         with patch('celery.utils.timeutils._time') as _time:
             _time.daylight = True
-            self.assertIsNotNone(utcoffset())
+            self.assertIsNotNone(utcoffset(time=_time))
             _time.daylight = False
-            self.assertIsNotNone(utcoffset())
+            self.assertIsNotNone(utcoffset(time=_time))

+ 44 - 17
celery/tests/worker/test_autoreload.py

@@ -18,7 +18,7 @@ from celery.worker.autoreload import (
     Autoreloader,
 )
 
-from celery.tests.case import AppCase, Case, Mock, SkipTest, patch, mock_open
+from celery.tests.case import AppCase, Case, Mock, patch, mock_open
 
 
 class test_WorkerComponent(AppCase):
@@ -75,6 +75,7 @@ class test_BaseMonitor(Case):
         x._on_change = Mock()
         x.on_change('foo')
         x._on_change.assert_called_with('foo')
+        x.on_event_loop_close(Mock())
 
 
 class test_StatMonitor(Case):
@@ -99,6 +100,12 @@ class test_StatMonitor(Case):
         stat.side_effect = OSError()
         x.start()
 
+    def test_register_with_event_loop(self):
+        hub = Mock(name='hub')
+        x = StatMonitor(['a'])
+        x.register_with_event_loop(hub)
+        hub.call_repeatedly.assert_called_with(2.0, x.find_changes)
+
     @patch('os.stat')
     def test_mtime_stat_raises(self, stat):
         stat.side_effect = ValueError()
@@ -122,10 +129,8 @@ class test_KQueueMonitor(Case):
         close.side_effect.errno = errno.EBADF
         x.stop()
 
-    def test_register_with_event_loop(self):
-        from kombu.utils import eventio
-        if eventio.kqueue is None:
-            raise SkipTest('version of kombu does not work with pypy')
+    @patch('kombu.utils.eventio.kqueue', create=True)
+    def test_register_with_event_loop(self, kqueue):
         x = KQueueMonitor(['a', 'b'])
         hub = Mock(name='hub')
         x.add_events = Mock(name='add_events()')
@@ -136,6 +141,15 @@ class test_KQueueMonitor(Case):
             x.handle_event,
         )
 
+    def test_register_with_event_loop_no_kqueue(self):
+        from kombu.utils import eventio
+        prev, eventio.kqueue = eventio.kqueue, None
+        try:
+            x = KQueueMonitor(['a'])
+            x.register_with_event_loop(Mock())
+        finally:
+            eventio.kqueue = prev
+
     def test_on_event_loop_close(self):
         x = KQueueMonitor(['a', 'b'])
         x.close = Mock()
@@ -201,21 +215,34 @@ class test_InotifyMonitor(Case):
 
     @patch('celery.worker.autoreload.pyinotify')
     def test_start(self, inotify):
-            x = InotifyMonitor(['a'])
-            inotify.IN_MODIFY = 1
-            inotify.IN_ATTRIB = 2
+        x = InotifyMonitor(['a'])
+        inotify.IN_MODIFY = 1
+        inotify.IN_ATTRIB = 2
+        x.start()
+
+        inotify.WatchManager.side_effect = ValueError()
+        with self.assertRaises(ValueError):
             x.start()
+        x.stop()
 
-            inotify.WatchManager.side_effect = ValueError()
-            with self.assertRaises(ValueError):
-                x.start()
-            x.stop()
+        x._on_change = None
+        x.process_(Mock())
+        x._on_change = Mock()
+        x.process_(Mock())
+        self.assertTrue(x._on_change.called)
 
-            x._on_change = None
-            x.process_(Mock())
-            x._on_change = Mock()
-            x.process_(Mock())
-            self.assertTrue(x._on_change.called)
+        x.create_notifier = Mock()
+        x._wm = Mock()
+        hub = Mock()
+        x.register_with_event_loop(hub)
+        x.create_notifier.assert_called_with()
+        hub.add_reader.assert_called_with(x._wm.get_fd(), x.on_readable)
+
+        x.on_event_loop_close(hub)
+        x._notifier = Mock()
+        x.on_readable()
+        x._notifier.read_events.assert_called_with()
+        x._notifier.process_events.assert_called_with()
 
 
 class test_default_implementation(Case):

+ 25 - 0
celery/tests/worker/test_bootsteps.py

@@ -148,6 +148,12 @@ class test_ConsumerStep(AppCase):
         step = Step(self)
         step.start(self)
 
+    def test_close_no_consumer_channel(self):
+        step = bootsteps.ConsumerStep(Mock())
+        step.consumers = [Mock()]
+        step.consumers[0].channel = None
+        step._close(Mock())
+
 
 class test_StartStopStep(AppCase):
 
@@ -177,6 +183,11 @@ class test_StartStopStep(AppCase):
         x.obj = None
         self.assertIsNone(x.start(self))
 
+    def test_terminate__no_obj(self):
+        x = self.Def(self)
+        x.obj = None
+        x.terminate(Mock())
+
     def test_include_when_disabled(self):
         x = self.Def(self)
         x.enabled = False
@@ -237,6 +248,20 @@ class test_Blueprint(AppCase):
         parent.steps = [None, None, None]
         blueprint.send_all(parent, 'close', 'Closing', reverse=False)
 
+    def test_send_all_raises(self):
+        parent = Mock()
+        blueprint = self.Blueprint(app=self.app)
+        parent.steps = [Mock()]
+        parent.steps[0].foo.side_effect = KeyError()
+        blueprint.send_all(parent, 'foo', propagate=False)
+        with self.assertRaises(KeyError):
+            blueprint.send_all(parent, 'foo', propagate=True)
+
+    def test_stop_state_in_TERMINATE(self):
+        blueprint = self.Blueprint(app=self.app)
+        blueprint.state = bootsteps.TERMINATE
+        blueprint.stop(Mock())
+
     def test_join_raises_IGNORE_ERRORS(self):
         prev, bootsteps.IGNORE_ERRORS = bootsteps.IGNORE_ERRORS, (KeyError,)
         try:

+ 47 - 2
celery/tests/worker/test_components.py

@@ -4,10 +4,46 @@ from __future__ import absolute_import
 # here to complete coverage.  Should move everything to this module at some
 # point [-ask]
 
+from celery.exceptions import ImproperlyConfigured
 from celery.platforms import IS_WINDOWS
-from celery.worker.components import Pool
+from celery.worker.components import Beat, Hub, Pool, Timer
 
-from celery.tests.case import AppCase, Mock, SkipTest
+from celery.tests.case import AppCase, Mock, SkipTest, patch
+
+
+class test_Timer(AppCase):
+
+    def test_create__eventloop(self):
+        w = Mock(name='w')
+        w.use_eventloop = True
+        Timer(w).create(w)
+        self.assertFalse(w.timer.queue)
+
+
+class test_Hub(AppCase):
+
+    def setup(self):
+        self.w = Mock(name='w')
+        self.hub = Hub(self.w)
+        self.w.hub = Mock(name='w.hub')
+
+    @patch('celery.worker.components.set_event_loop')
+    @patch('celery.worker.components.get_event_loop')
+    def test_create(self, get_event_loop, set_event_loop):
+        self.hub._patch_thread_primitives = Mock(name='ptp')
+        self.assertIs(self.hub.create(self.w), self.hub)
+        self.hub._patch_thread_primitives.assert_called_with(self.w)
+
+    def test_start(self):
+        self.hub.start(self.w)
+
+    def test_stop(self):
+        self.hub.stop(self.w)
+        self.w.hub.close.assert_called_with()
+
+    def test_terminate(self):
+        self.hub.terminate(self.w)
+        self.w.hub.close.assert_called_with()
 
 
 class test_Pool(AppCase):
@@ -46,3 +82,12 @@ class test_Pool(AppCase):
 
         self.assertEqual(
             comp.instantiate.call_args[1]['max_memory_per_child'], 32)
+
+
+class test_Beat(AppCase):
+
+    def test_create__green(self):
+        w = Mock(name='w')
+        w.pool_cls.__module__ = 'foo_gevent'
+        with self.assertRaises(ImproperlyConfigured):
+            Beat(w).create(w)

+ 4 - 0
celery/tests/worker/test_control.py

@@ -562,6 +562,10 @@ class test_ControlPanel(AppCase):
         consumer.update_strategies.assert_called_with()
         self.assertFalse(_reload.called)
         self.assertFalse(_import.called)
+        consumer.controller.pool.restart.side_effect = NotImplementedError()
+        panel.handle('pool_restart', {'reloader': _reload})
+        consumer.controller.consumer = None
+        panel.handle('pool_restart', {'reloader': _reload})
 
     def test_pool_restart_import_modules(self):
         consumer = Consumer(self.app)

+ 52 - 3
celery/tests/worker/test_worker.py

@@ -12,10 +12,11 @@ from kombu import Connection
 from kombu.common import QoS, ignore_errors
 from kombu.transport.base import Message
 
-from celery.bootsteps import RUN, CLOSE, StartStopStep
+from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
 from celery.concurrency.base import BasePool
 from celery.exceptions import (
-    WorkerShutdown, WorkerTerminate, TaskRevokedError, InvalidTaskError,
+    WorkerShutdown, WorkerTerminate, TaskRevokedError,
+    InvalidTaskError, ImproperlyConfigured,
 )
 from celery.five import Empty, range, Queue as FastQueue
 from celery.platforms import EX_FAILURE
@@ -828,6 +829,17 @@ class test_WorkController(AppCase):
             worker_direct(self.worker.hostname),
         )
 
+    def test_setup_queues__missing_queue(self):
+        self.app.amqp.queues.select = Mock(name='select')
+        self.app.amqp.queues.deselect = Mock(name='deselect')
+        self.app.amqp.queues.select.side_effect = KeyError()
+        self.app.amqp.queues.deselect.side_effect = KeyError()
+        with self.assertRaises(ImproperlyConfigured):
+            self.worker.setup_queues("x,y", exclude="foo,bar")
+        self.app.amqp.queues.select = Mock(name='select')
+        with self.assertRaises(ImproperlyConfigured):
+            self.worker.setup_queues("x,y", exclude="foo,bar")
+
     def test_send_worker_shutdown(self):
         with patch('celery.signals.worker_shutdown') as ws:
             self.worker._send_worker_shutdown()
@@ -1031,6 +1043,23 @@ class test_WorkController(AppCase):
         worker.consumer.close.side_effect = AttributeError()
         worker.signal_consumer_close()
 
+    def test_rusage__no_resource(self):
+        from celery import worker
+        prev, worker.resource = worker.resource, None
+        try:
+            self.worker.pool = Mock(name='pool')
+            with self.assertRaises(NotImplementedError):
+                self.worker.rusage()
+            self.worker.stats()
+        finally:
+            worker.resource = prev
+
+    def test_repr(self):
+        self.assertTrue(repr(self.worker))
+
+    def test_str(self):
+        self.assertEqual(str(self.worker), self.worker.hostname)
+
     def test_start__stop(self):
         worker = self.worker
         worker.blueprint.shutdown_complete.set()
@@ -1046,7 +1075,7 @@ class test_WorkController(AppCase):
         for w in worker.steps:
             self.assertTrue(w.start.call_count)
         worker.consumer = Mock()
-        worker.stop()
+        worker.stop(exitcode=3)
         for stopstep in worker.steps:
             self.assertTrue(stopstep.close.call_count)
             self.assertTrue(stopstep.stop.call_count)
@@ -1061,6 +1090,24 @@ class test_WorkController(AppCase):
         worker.start()
         worker.stop()
 
+    def test_start__KeyboardInterrupt(self):
+        worker = self.worker
+        worker.blueprint = Mock(name='blueprint')
+        worker.blueprint.start.side_effect = KeyboardInterrupt()
+        worker.stop = Mock(name='stop')
+        worker.start()
+        worker.stop.assert_called_with(exitcode=EX_FAILURE)
+
+    def test_register_with_event_loop(self):
+        worker = self.worker
+        hub = Mock(name='hub')
+        worker.blueprint = Mock(name='blueprint')
+        worker.register_with_event_loop(hub)
+        worker.blueprint.send_all.assert_called_with(
+            worker, 'register_with_event_loop', args=(hub,),
+            description='hub.register',
+        )
+
     def test_step_raises(self):
         worker = self.worker
         step = Mock()
@@ -1087,6 +1134,8 @@ class test_WorkController(AppCase):
         worker.terminate()
         for step in worker.steps:
             self.assertTrue(step.terminate.call_count)
+        worker.blueprint.state = TERMINATE
+        worker.terminate()
 
     def test_Hub_crate(self):
         w = Mock()

+ 8 - 8
celery/utils/log.py

@@ -59,7 +59,7 @@ def iter_open_logger_fds():
         try:
             for handler in logger.handlers:
                 try:
-                    if handler not in seen:
+                    if handler not in seen:  # pragma: no cover
                         yield handler.stream
                         seen.add(handler)
                 except AttributeError:
@@ -91,7 +91,7 @@ def logger_isa(l, p, max=1000):
             this = this.parent
             if not this:
                 break
-    else:
+    else:  # pragma: no cover
         raise RuntimeError('Logger hierarchy exceeds {0}'.format(max))
     return False
 
@@ -99,7 +99,7 @@ def logger_isa(l, p, max=1000):
 def get_logger(name):
     l = _get_logger(name)
     if logging.root not in (l, l.parent) and l is not base_logger:
-        if not logger_isa(l, base_logger):
+        if not logger_isa(l, base_logger):  # pragma: no cover
             l.parent = base_logger
     return l
 task_logger = get_logger('celery.task')
@@ -154,7 +154,7 @@ class ColorFormatter(logging.Formatter):
                     if isinstance(msg, string_t):
                         return text_t(color(safe_str(msg)))
                     return safe_str(color(msg))
-                except UnicodeDecodeError:
+                except UnicodeDecodeError:  # pragma: no cover
                     return safe_str(msg)  # skip colors
             except Exception as exc:
                 prev_msg, record.exc_info, record.msg = (
@@ -258,7 +258,7 @@ class LoggingProxy(object):
 def get_multiprocessing_logger():
     try:
         from billiard import util
-    except ImportError:
+    except ImportError:  # pragma: no cover
             pass
     else:
         return util.get_logger()
@@ -267,17 +267,17 @@ def get_multiprocessing_logger():
 def reset_multiprocessing_logger():
     try:
         from billiard import util
-    except ImportError:
+    except ImportError:  # pragma: no cover
         pass
     else:
-        if hasattr(util, '_logger'):
+        if hasattr(util, '_logger'):  # pragma: no cover
             util._logger = None
 
 
 def current_process():
     try:
         from billiard import process
-    except ImportError:
+    except ImportError:  # pragma: no cover
         pass
     else:
         return process.current_process()
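
Most of the log.py changes only add "# pragma: no cover" markers to branches that cannot execute in the test environment, such as the ImportError fallbacks around billiard. A rough sketch of the pattern, assuming coverage.py's default exclusion marker (the function name is illustrative):

    def get_optional_logger():
        try:
            from billiard import util
        except ImportError:  # pragma: no cover
            # unreachable wherever billiard is installed, so excluded
            # from the coverage total rather than counted as a miss
            return None
        return util.get_logger()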

+ 3 - 3
celery/utils/saferepr.py

@@ -36,7 +36,7 @@ __all__ = ['saferepr', 'reprstream']
 
 IS_PY3 = sys.version_info[0] == 3
 
-if IS_PY3:
+if IS_PY3:  # pragma: no cover
     range_t = (range, )
 else:
     class range_t(object):  # noqa
@@ -110,7 +110,7 @@ def _saferepr(o, maxlen=None, maxlevels=3, seen=None):
             val = saferepr(token.value, maxlen, maxlevels)
         elif isinstance(token, _quoted):
             val = token.value
-            if IS_PY3 and isinstance(val, bytes):
+            if IS_PY3 and isinstance(val, bytes):  # pragma: no cover
                 val = "b'%s'" % (bytes_to_str(truncate_bytes(val, maxlen)),)
             else:
                 val = "'%s'" % (truncate(val, maxlen),)
@@ -163,7 +163,7 @@ def reprstream(stack, seen=None, maxlevels=3, level=0, isinstance=isinstance):
                 yield text_t(val), it
             elif isinstance(val, chars_t):
                 yield _quoted(val), it
-            elif isinstance(val, range_t):
+            elif isinstance(val, range_t):  # pragma: no cover
                 yield repr(val), it
             else:
                 if isinstance(val, set_t):
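
The Py3-only branch marked above renders bytes values as a b'...' literal after truncating them. A standalone approximation of that behaviour (quote_value and the slicing-based truncation are simplified stand-ins for saferepr's truncate_bytes/bytes_to_str helpers):

    def quote_value(val, maxlen=20):
        # bytes must be decoded before being embedded in the repr text
        if isinstance(val, bytes):
            return "b'%s'" % (val[:maxlen].decode('utf-8', 'replace'),)
        return "'%s'" % (val[:maxlen],)

    assert quote_value(b'hello') == "b'hello'"
    assert quote_value('hello') == "'hello'"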

+ 2 - 2
celery/utils/timeutils.py

@@ -86,7 +86,7 @@ class LocalTimezone(tzinfo):
     def tzname(self, dt):
         return _time.tzname[self._isdst(dt)]
 
-    if PY3:
+    if PY3:  # pragma: no cover
 
         def fromutc(self, dt):
             # The base tzinfo class no longer implements a DST
@@ -122,7 +122,7 @@ class _Zone(object):
             dt = make_aware(dt, orig or self.utc)
         return localize(dt, self.tz_or_local(local))
 
-    if PY33:
+    if PY33:  # pragma: no cover
 
         def to_system(self, dt):
             # tz=None is a special case since Python 3.3, and will
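
The PY3/PY33 guards above define methods only on interpreters that support the newer astimezone semantics; on older versions the attribute never exists. A hedged sketch of the same shape (Zone and to_system are illustrative names, not celery's):

    import sys

    class Zone(object):
        if sys.version_info >= (3, 3):  # pragma: no cover
            def to_system(self, dt):
                # Python 3.3+ interprets tz=None as the system local zone
                return dt.astimezone(tz=None)

    # on older interpreters the method simply is not there:
    assert hasattr(Zone, 'to_system') == (sys.version_info >= (3, 3))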

+ 4 - 3
celery/worker/components.py

@@ -92,7 +92,7 @@ class Hub(bootsteps.StartStopStep):
         # multiprocessing's ApplyResult uses this lock.
         try:
             from billiard import pool
-        except ImportError:
+        except ImportError:  # pragma: no cover
             pass
         else:
             pool.Lock = DummyLock
@@ -137,8 +137,9 @@ class Pool(bootsteps.StartStopStep):
         if w.pool:
             w.pool.terminate()
 
-    def create(self, w, semaphore=None, max_restarts=None):
-        if w.app.conf.worker_pool in ('eventlet', 'gevent'):
+    def create(self, w, semaphore=None, max_restarts=None,
+               green_pools={'eventlet', 'gevent'}):
+        if w.app.conf.worker_pool in green_pools:  # pragma: no cover
             warnings.warn(UserWarning(W_POOL_SETTING))
         threaded = not w.use_eventloop or IS_WINDOWS
         procs = w.min_concurrency
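
The new green_pools={'eventlet', 'gevent'} default binds the set once at definition time; it is only ever read, so the usual mutable-default pitfall does not apply, and callers (or tests) can still override it per call. A small standalone sketch of the same idea (create_pool and the warning text are invented here, not celery's W_POOL_SETTING):

    import warnings

    def create_pool(pool_name, green_pools={'eventlet', 'gevent'}):
        # the default set is built once when the def executes and never mutated
        if pool_name in green_pools:
            warnings.warn(UserWarning('pool size is managed by the green pool'))
        return pool_name

    create_pool('prefork')                      # no warning
    create_pool('eventlet')                     # warns
    create_pool('eventlet', green_pools=set())  # override silences the warning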