Browse Source

92% coverage

Ask Solem 9 years ago
parent
commit
0c801b9070
40 changed files with 1177 additions and 245 deletions
  1. 13 2
      .coveragerc
  2. 1 2
      celery/app/base.py
  3. 1 1
      celery/app/defaults.py
  4. 1 1
      celery/app/task.py
  5. 2 2
      celery/app/trace.py
  6. 1 1
      celery/app/utils.py
  7. 1 1
      celery/backends/cache.py
  8. 1 1
      celery/backends/database/__init__.py
  9. 2 1
      celery/backends/database/session.py
  10. 1 1
      celery/events/dumper.py
  11. 10 8
      celery/schedules.py
  12. 118 0
      celery/tests/app/test_app.py
  13. 16 0
      celery/tests/app/test_loaders.py
  14. 1 0
      celery/tests/app/test_log.py
  15. 7 1
      celery/tests/app/test_routes.py
  16. 72 2
      celery/tests/app/test_schedules.py
  17. 70 1
      celery/tests/backends/test_database.py
  18. 3 0
      celery/tests/backends/test_rpc.py
  19. 7 6
      celery/tests/case.py
  20. 46 1
      celery/tests/concurrency/test_concurrency.py
  21. 68 43
      celery/tests/concurrency/test_eventlet.py
  22. 62 66
      celery/tests/concurrency/test_gevent.py
  23. 98 16
      celery/tests/concurrency/test_prefork.py
  24. 11 10
      celery/tests/fixups/test_django.py
  25. 5 0
      celery/tests/security/test_certificate.py
  26. 10 0
      celery/tests/security/test_security.py
  27. 61 2
      celery/tests/tasks/test_tasks.py
  28. 113 2
      celery/tests/tasks/test_trace.py
  29. 98 0
      celery/tests/utils/test_debug.py
  30. 31 1
      celery/tests/utils/test_mail.py
  31. 5 0
      celery/tests/utils/test_text.py
  32. 75 1
      celery/tests/utils/test_utils.py
  33. 1 1
      celery/tests/worker/test_autoscale.py
  34. 49 0
      celery/tests/worker/test_consumer.py
  35. 73 8
      celery/tests/worker/test_control.py
  36. 34 1
      celery/tests/worker/test_loops.py
  37. 1 46
      celery/tests/worker/test_worker.py
  38. 2 2
      celery/utils/abstract.py
  39. 4 4
      celery/utils/debug.py
  40. 2 10
      celery/worker/control.py

+ 13 - 2
.coveragerc

@@ -2,6 +2,17 @@
 branch = 1
 branch = 1
 cover_pylib = 0
 cover_pylib = 0
 include=*celery/*
 include=*celery/*
-omit = celery.utils.debug,celery.tests.*,celery.bin.graph;
+omit = celery.tests.*
 [report]
 [report]
-omit = */python?.?/*,*/site-packages/*,*/pypy/*
+omit =
+    */python?.?/*
+    */site-packages/*
+    */pypy/*
+    */celery/bin/graph.py
+    *celery/bin/logtool.py
+    *celery/task/base.py
+    *celery/five.py
+    *celery/contrib/sphinx.py
+    *celery/backends/couchdb.py
+    *celery/backends/couchbase.py
+    *celery/backends/cassandra.py

+ 1 - 2
celery/app/base.py

@@ -19,7 +19,7 @@ from functools import wraps
 from amqp import starpromise
 from amqp import starpromise
 try:
 try:
     from billiard.util import register_after_fork
     from billiard.util import register_after_fork
-except ImportError:
+except ImportError:  # pragma: no cover
     register_after_fork = None
     register_after_fork = None
 from kombu.clocks import LamportClock
 from kombu.clocks import LamportClock
 from kombu.common import oid_from
 from kombu.common import oid_from
@@ -771,7 +771,6 @@ class Celery(object):
     def select_queues(self, queues=None):
     def select_queues(self, queues=None):
         """Select a subset of queues, where queues must be a list of queue
         """Select a subset of queues, where queues must be a list of queue
         names to keep."""
         names to keep."""
-
         return self.amqp.queues.select(queues)
         return self.amqp.queues.select(queues)
 
 
     def either(self, default_key, *values):
     def either(self, default_key, *values):

+ 1 - 1
celery/app/defaults.py

@@ -335,7 +335,7 @@ SETTING_KEYS = set(keys(DEFAULTS))
 _OLD_SETTING_KEYS = set(keys(_TO_NEW_KEY))
 _OLD_SETTING_KEYS = set(keys(_TO_NEW_KEY))
 
 
 
 
-def find_deprecated_settings(source):
+def find_deprecated_settings(source):  # pragma: no cover
     from celery.utils import warn_deprecated
     from celery.utils import warn_deprecated
     for name, opt in flatten(NAMESPACES):
     for name, opt in flatten(NAMESPACES):
         if (opt.deprecate_by or opt.remove_by) and getattr(source, name, None):
         if (opt.deprecate_by or opt.remove_by) and getattr(source, name, None):

+ 1 - 1
celery/app/task.py

@@ -477,7 +477,7 @@ class Task(object):
         """
         """
         try:
         try:
             check_arguments = self.__header__
             check_arguments = self.__header__
-        except AttributeError:
+        except AttributeError:  # pragma: no cover
             pass
             pass
         else:
         else:
             check_arguments(*(args or ()), **(kwargs or {}))
             check_arguments(*(args or ()), **(kwargs or {}))

+ 2 - 2
celery/app/trace.py

@@ -390,12 +390,12 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                                     else:
                                     else:
                                         sigs.append(sig)
                                         sigs.append(sig)
                                 for group_ in groups:
                                 for group_ in groups:
-                                    group.apply_async(
+                                    group_.apply_async(
                                         (retval,),
                                         (retval,),
                                         parent_id=uuid, root_id=root_id,
                                         parent_id=uuid, root_id=root_id,
                                     )
                                     )
                                 if sigs:
                                 if sigs:
-                                    group(sigs).apply_async(
+                                    group(sigs, app=app).apply_async(
                                         (retval,),
                                         (retval,),
                                         parent_id=uuid, root_id=root_id,
                                         parent_id=uuid, root_id=root_id,
                                     )
                                     )

+ 1 - 1
celery/app/utils.py

@@ -141,7 +141,7 @@ class Settings(ConfigurationView):
         return filt({
         return filt({
             k: v for k, v in items(
             k: v for k, v in items(
                 self if with_defaults else self.without_defaults())
                 self if with_defaults else self.without_defaults())
-            if k.isupper() and not k.startswith('_')
+            if not k.startswith('_')
         })
         })
 
 
     def humanize(self, with_defaults=False, censored=True):
     def humanize(self, with_defaults=False, censored=True):

+ 1 - 1
celery/backends/cache.py

@@ -45,7 +45,7 @@ def import_best_memcache():
                 import memcache  # noqa
                 import memcache  # noqa
             except ImportError:
             except ImportError:
                 raise ImproperlyConfigured(REQUIRES_BACKEND)
                 raise ImproperlyConfigured(REQUIRES_BACKEND)
-        if PY3:
+        if PY3:  # pragma: no cover
             memcache_key_t = bytes_to_str
             memcache_key_t = bytes_to_str
         _imp[0] = (is_pylibmc, memcache, memcache_key_t)
         _imp[0] = (is_pylibmc, memcache, memcache_key_t)
     return _imp[0]
     return _imp[0]

+ 1 - 1
celery/backends/database/__init__.py

@@ -25,7 +25,7 @@ from .session import SessionManager
 try:
 try:
     from sqlalchemy.exc import DatabaseError, InvalidRequestError
     from sqlalchemy.exc import DatabaseError, InvalidRequestError
     from sqlalchemy.orm.exc import StaleDataError
     from sqlalchemy.orm.exc import StaleDataError
-except ImportError:
+except ImportError:  # pragma: no cover
     raise ImproperlyConfigured(
     raise ImproperlyConfigured(
         'The database result backend requires SQLAlchemy to be installed.'
         'The database result backend requires SQLAlchemy to be installed.'
         'See http://pypi.python.org/pypi/SQLAlchemy')
         'See http://pypi.python.org/pypi/SQLAlchemy')

+ 2 - 1
celery/backends/database/session.py

@@ -10,7 +10,7 @@ from __future__ import absolute_import
 
 
 try:
 try:
     from billiard.util import register_after_fork
     from billiard.util import register_after_fork
-except ImportError:
+except ImportError:  # pragma: no cover
     register_after_fork = None
     register_after_fork = None
 
 
 from sqlalchemy import create_engine
 from sqlalchemy import create_engine
@@ -24,6 +24,7 @@ __all__ = ['SessionManager']
 
 
 
 
 class SessionManager(object):
 class SessionManager(object):
+
     def __init__(self):
     def __init__(self):
         self._engines = {}
         self._engines = {}
         self._sessions = {}
         self._sessions = {}

+ 1 - 1
celery/events/dumper.py

@@ -48,7 +48,7 @@ class Dumper(object):
         # need to flush so that output can be piped.
         # need to flush so that output can be piped.
         try:
         try:
             self.out.flush()
             self.out.flush()
-        except AttributeError:
+        except AttributeError:  # pragma: no cover
             pass
             pass
 
 
     def on_event(self, ev):
     def on_event(self, ev):

+ 10 - 8
celery/schedules.py

@@ -589,7 +589,10 @@ class crontab(schedule):
         return NotImplemented
         return NotImplemented
 
 
     def __ne__(self, other):
     def __ne__(self, other):
-        return not self.__eq__(other)
+        res = self.__eq__(other)
+        if res is NotImplemented:
+            return True
+        return not res
 
 
 
 
 def maybe_schedule(s, relative=False, app=None):
 def maybe_schedule(s, relative=False, app=None):
@@ -691,12 +694,8 @@ class solar(schedule):
         self.method = self._methods[event]
         self.method = self._methods[event]
         self.use_center = self._use_center_l[event]
         self.use_center = self._use_center_l[event]
 
 
-    def now(self):
-        return (self.nowfun or self.app.now)()
-
     def __reduce__(self):
     def __reduce__(self):
-        return (self.__class__, (
-            self.event, self.lat, self.lon), None)
+        return self.__class__, (self.event, self.lat, self.lon)
 
 
     def __repr__(self):
     def __repr__(self):
         return '<solar: {0} at latitude {1}, longitude: {2}>'.format(
         return '<solar: {0} at latitude {1}, longitude: {2}>'.format(
@@ -715,7 +714,7 @@ class solar(schedule):
                 self.ephem.Sun(),
                 self.ephem.Sun(),
                 start=last_run_at_utc, use_center=self.use_center,
                 start=last_run_at_utc, use_center=self.use_center,
             )
             )
-        except self.ephem.CircumpolarError:
+        except self.ephem.CircumpolarError:  # pragma: no cover
             """Sun will not rise/set today. Check again tomorrow
             """Sun will not rise/set today. Check again tomorrow
             (specifically, after the next anti-transit)."""
             (specifically, after the next anti-transit)."""
             next_utc = (
             next_utc = (
@@ -750,4 +749,7 @@ class solar(schedule):
         return NotImplemented
         return NotImplemented
 
 
     def __ne__(self, other):
     def __ne__(self, other):
-        return not self.__eq__(other)
+        res = self.__eq__(other)
+        if res is NotImplemented:
+            return True
+        return not res

+ 118 - 0
celery/tests/app/test_app.py

@@ -9,6 +9,7 @@ from pickle import loads, dumps
 
 
 from amqp import promise
 from amqp import promise
 
 
+from celery import Celery
 from celery import shared_task, current_app
 from celery import shared_task, current_app
 from celery import app as _app
 from celery import app as _app
 from celery import _state
 from celery import _state
@@ -19,12 +20,14 @@ from celery.five import items, keys
 from celery.loaders.base import BaseLoader, unconfigured
 from celery.loaders.base import BaseLoader, unconfigured
 from celery.platforms import pyimplementation
 from celery.platforms import pyimplementation
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
+from celery.utils.timeutils import timezone
 
 
 from celery.tests.case import (
 from celery.tests.case import (
     CELERY_TEST_CONFIG,
     CELERY_TEST_CONFIG,
     AppCase,
     AppCase,
     Mock,
     Mock,
     Case,
     Case,
+    ContextMock,
     depends_on_current_app,
     depends_on_current_app,
     mask_modules,
     mask_modules,
     patch,
     patch,
@@ -128,6 +131,12 @@ class test_App(AppCase):
             task = app.task(fun)
             task = app.task(fun)
             self.assertEqual(task.name, app.main + '.fun')
             self.assertEqual(task.name, app.main + '.fun')
 
 
+    def test_task_too_many_args(self):
+        with self.assertRaises(TypeError):
+            self.app.task(Mock(name='fun'), True)
+        with self.assertRaises(TypeError):
+            self.app.task(Mock(name='fun'), True, 1, 2)
+
     def test_with_config_source(self):
     def test_with_config_source(self):
         with self.Celery(config_source=ObjectConfig) as app:
         with self.Celery(config_source=ObjectConfig) as app:
             self.assertEqual(app.conf.FOO, 1)
             self.assertEqual(app.conf.FOO, 1)
@@ -235,6 +244,18 @@ class test_App(AppCase):
             self.assertEqual(prom.fun, self.app._autodiscover_tasks)
             self.assertEqual(prom.fun, self.app._autodiscover_tasks)
             self.assertEqual(prom.args[0](), [1, 2, 3])
             self.assertEqual(prom.args[0](), [1, 2, 3])
 
 
+    def test_autodiscover_tasks__no_packages(self):
+        fixup1 = Mock(name='fixup')
+        fixup2 = Mock(name='fixup')
+        self.app._autodiscover_tasks_from_names = Mock(name='auto')
+        self.app._fixups = [fixup1, fixup2]
+        fixup1.autodiscover_tasks.return_value = ['A', 'B', 'C']
+        fixup2.autodiscover_tasks.return_value = ['D', 'E', 'F']
+        self.app.autodiscover_tasks(force=True)
+        self.app._autodiscover_tasks_from_names.assert_called_with(
+            ['A', 'B', 'C', 'D', 'E', 'F'], related_name='tasks',
+        )
+
     @with_environ('CELERY_BROKER_URL', '')
     @with_environ('CELERY_BROKER_URL', '')
     def test_with_broker(self):
     def test_with_broker(self):
         with self.Celery(broker='foo://baribaz') as app:
         with self.Celery(broker='foo://baribaz') as app:
@@ -739,6 +760,86 @@ class test_App(AppCase):
         self.assertIsNone(self.app._pool)
         self.assertIsNone(self.app._pool)
         self.app._after_fork(self.app)
         self.app._after_fork(self.app)
 
 
+    def test_global_after_fork(self):
+        app = Mock(name='app')
+        prev, _state._apps = _state._apps, [app]
+        try:
+            obj = Mock(name='obj')
+            _appbase._global_after_fork(obj)
+            app._after_fork.assert_called_with(obj)
+        finally:
+            _state._apps = prev
+
+    @patch('multiprocessing.util', create=True)
+    def test_global_after_fork__raises(self, util):
+        app = Mock(name='app')
+        prev, _state._apps = _state._apps, [app]
+        try:
+            obj = Mock(name='obj')
+            exc = app._after_fork.side_effect = KeyError()
+            _appbase._global_after_fork(obj)
+            util._logger.info.assert_called_with(
+                'after forker raised exception: %r', exc, exc_info=1)
+            util._logger = None
+            _appbase._global_after_fork(obj)
+        finally:
+            _state._apps = prev
+
+    def test_ensure_after_fork__no_multiprocessing(self):
+        prev, _appbase.register_after_fork = (
+            _appbase.register_after_fork, None)
+        try:
+            _appbase._after_fork_registered = False
+            _appbase._ensure_after_fork()
+            self.assertTrue(_appbase._after_fork_registered)
+        finally:
+            _appbase.register_after_fork = prev
+
+    def test_canvas(self):
+        self.assertTrue(self.app.canvas.Signature)
+
+    def test_signature(self):
+        sig = self.app.signature('foo', (1, 2))
+        self.assertIs(sig.app, self.app)
+
+    def test_timezone__none_set(self):
+        self.app.conf.timezone = None
+        tz = self.app.timezone
+        self.assertEqual(tz, timezone.get_timezone('UTC'))
+
+    def test_compat_on_configure(self):
+        on_configure = Mock(name='on_configure')
+
+        class CompatApp(Celery):
+
+            def on_configure(self, *args, **kwargs):
+                on_configure(*args, **kwargs)
+
+        with CompatApp(set_as_current=False) as app:
+            app.loader = Mock()
+            app.loader.conf = {}
+            app._load_config()
+            on_configure.assert_called_with()
+
+    def test_add_periodic_task(self):
+
+        @self.app.task
+        def add(x, y):
+            pass
+        assert not self.app.configured
+        self.app.add_periodic_task(
+            10, self.app.signature('add', (2, 2)),
+            name='add1', expires=3,
+        )
+        self.assertTrue(self.app._pending_periodic_tasks)
+        assert not self.app.configured
+
+        sig2 = add.s(4, 4)
+        self.assertTrue(self.app.configured)
+        self.app.add_periodic_task(20, sig2, name='add2', expires=4)
+        self.assertIn('add1', self.app.conf.beat_schedule)
+        self.assertIn('add2', self.app.conf.beat_schedule)
+
     def test_pool_no_multiprocessing(self):
     def test_pool_no_multiprocessing(self):
         with mask_modules('multiprocessing.util'):
         with mask_modules('multiprocessing.util'):
             pool = self.app.pool
             pool = self.app.pool
@@ -747,6 +848,18 @@ class test_App(AppCase):
     def test_bugreport(self):
     def test_bugreport(self):
         self.assertTrue(self.app.bugreport())
         self.assertTrue(self.app.bugreport())
 
 
+    def test_send_task__connection_provided(self):
+        connection = Mock(name='connection')
+        router = Mock(name='router')
+        router.route.return_value = {}
+        self.app.amqp = Mock(name='amqp')
+        self.app.amqp.Producer.attach_mock(ContextMock(), 'return_value')
+        self.app.send_task('foo', (1, 2), connection=connection, router=router)
+        self.app.amqp.Producer.assert_called_with(connection)
+        self.app.amqp.send_task_message.assert_called_with(
+            self.app.amqp.Producer(), 'foo',
+            self.app.amqp.create_task_message())
+
     def test_send_task_sent_event(self):
     def test_send_task_sent_event(self):
 
 
         class Dispatcher(object):
         class Dispatcher(object):
@@ -799,6 +912,11 @@ class test_App(AppCase):
         x.send(Mock(), Mock())
         x.send(Mock(), Mock())
         self.assertFalse(task.app.mail_admins.called)
         self.assertFalse(task.app.mail_admins.called)
 
 
+    def test_select_queues(self):
+        self.app.amqp = Mock(name='amqp')
+        self.app.select_queues({'foo', 'bar'})
+        self.app.amqp.queues.select.assert_called_with({'foo', 'bar'})
+
 
 
 class test_defaults(AppCase):
 class test_defaults(AppCase):
 
 

+ 16 - 0
celery/tests/app/test_loaders.py

@@ -184,6 +184,22 @@ class test_DefaultLoader(AppCase):
             if prevconfig:
             if prevconfig:
                 sys.modules[configname] = prevconfig
                 sys.modules[configname] = prevconfig
 
 
+    def test_read_configuration_ImportError(self):
+        sentinel = object()
+        prev, os.environ['CELERY_CONFIG_MODULE'] = (
+            os.environ.get('CELERY_CONFIG_MODULE', sentinel), 'daweqew.dweqw',
+        )
+        try:
+            l = default.Loader(app=self.app)
+            with self.assertRaises(ImportError):
+                l.read_configuration(fail_silently=False)
+            l.read_configuration(fail_silently=True)
+        finally:
+            if prev is not sentinel:
+                os.environ['CELERY_CONFIG_MODULE'] = prev
+            else:
+                os.environ.pop('CELERY_CONFIG_MODULE', None)
+
     def test_import_from_cwd(self):
     def test_import_from_cwd(self):
         l = default.Loader(app=self.app)
         l = default.Loader(app=self.app)
         old_path = list(sys.path)
         old_path = list(sys.path)

+ 1 - 0
celery/tests/app/test_log.py

@@ -199,6 +199,7 @@ class test_default_logger(AppCase):
     def test_configure_logger(self):
     def test_configure_logger(self):
         logger = self.app.log.get_default_logger()
         logger = self.app.log.get_default_logger()
         self.app.log._configure_logger(logger, sys.stderr, None, '', False)
         self.app.log._configure_logger(logger, sys.stderr, None, '', False)
+        self.app.log._configure_logger(None, sys.stderr, None, '', False)
         logger.handlers[:] = []
         logger.handlers[:] = []
 
 
     def test_setup_logging_subsystem_colorize(self):
     def test_setup_logging_subsystem_colorize(self):

+ 7 - 1
celery/tests/app/test_routes.py

@@ -1,6 +1,6 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-from kombu import Exchange
+from kombu import Exchange, Queue
 from kombu.utils.functional import maybe_evaluate
 from kombu.utils.functional import maybe_evaluate
 
 
 from celery.app import routes
 from celery.app import routes
@@ -121,6 +121,12 @@ class test_lookup_route(RouteCase):
         dest = x.expand_destination('foo')
         dest = x.expand_destination('foo')
         self.assertEqual(dest['queue'].name, 'foo')
         self.assertEqual(dest['queue'].name, 'foo')
 
 
+    def test_expand_destination__Queue(self):
+        queue = Queue('foo')
+        x = Router(self.app, {}, self.app.amqp.queues)
+        dest = x.expand_destination({'queue': queue})
+        self.assertIs(dest['queue'], queue)
+
     def test_lookup_paths_traversed(self):
     def test_lookup_paths_traversed(self):
         set_queues(
         set_queues(
             self.app, foo=self.a_queue, bar=self.b_queue,
             self.app, foo=self.a_queue, bar=self.b_queue,

+ 72 - 2
celery/tests/app/test_schedules.py

@@ -7,8 +7,10 @@ from datetime import datetime, timedelta
 from pickle import dumps, loads
 from pickle import dumps, loads
 
 
 from celery.five import items
 from celery.five import items
-from celery.schedules import ParseException, crontab, crontab_parser
-from celery.tests.case import AppCase, SkipTest
+from celery.schedules import (
+    ParseException, crontab, crontab_parser, schedule, solar,
+)
+from celery.tests.case import AppCase, Mock, SkipTest
 
 
 
 
 @contextmanager
 @contextmanager
@@ -21,6 +23,73 @@ def patch_crontab_nowfun(cls, retval):
         cls.nowfun = prev_nowfun
         cls.nowfun = prev_nowfun
 
 
 
 
+class test_solar(AppCase):
+
+    def setup(self):
+        try:
+            import ephem  # noqa
+        except ImportError:
+            raise SkipTest('ephem module not installed')
+        self.s = solar('sunrise', 60, 30, app=self.app)
+
+    def test_reduce(self):
+        fun, args = self.s.__reduce__()
+        self.assertEqual(fun(*args), self.s)
+
+    def test_eq(self):
+        self.assertEqual(self.s, solar('sunrise', 60, 30, app=self.app))
+        self.assertNotEqual(self.s, solar('sunset', 60, 30, app=self.app))
+        self.assertNotEqual(self.s, schedule(10))
+
+    def test_repr(self):
+        self.assertTrue(repr(self.s))
+
+    def test_is_due(self):
+        self.s.remaining_estimate = Mock(name='rem')
+        self.s.remaining_estimate.return_value = timedelta(seconds=0)
+        self.assertTrue(self.s.is_due(datetime.utcnow()).is_due)
+
+    def test_is_due__not_due(self):
+        self.s.remaining_estimate = Mock(name='rem')
+        self.s.remaining_estimate.return_value = timedelta(hours=10)
+        self.assertFalse(self.s.is_due(datetime.utcnow()).is_due)
+
+    def test_remaining_estimate(self):
+        self.s.cal = Mock(name='cal')
+        self.s.cal.next_rising().datetime.return_value = datetime.utcnow()
+        self.s.remaining_estimate(datetime.utcnow())
+
+    def test_coordinates(self):
+        with self.assertRaises(ValueError):
+            solar('sunrise', -120, 60)
+        with self.assertRaises(ValueError):
+            solar('sunrise', 120, 60)
+        with self.assertRaises(ValueError):
+            solar('sunrise', 60, -200)
+        with self.assertRaises(ValueError):
+            solar('sunrise', 60, 200)
+
+    def test_invalid_event(self):
+        with self.assertRaises(ValueError):
+            solar('asdqwewqew', 60, 60)
+
+
+class test_schedule(AppCase):
+
+    def test_ne(self):
+        s1 = schedule(10, app=self.app)
+        s2 = schedule(12, app=self.app)
+        s3 = schedule(10, app=self.app)
+        self.assertEqual(s1, s3)
+        self.assertNotEqual(s1, s2)
+
+    def test_pickle(self):
+        s1 = schedule(10, app=self.app)
+        fun, args = s1.__reduce__()
+        s2 = fun(*args)
+        self.assertEqual(s1, s2)
+
+
 class test_crontab_parser(AppCase):
 class test_crontab_parser(AppCase):
 
 
     def crontab(self, *args, **kwargs):
     def crontab(self, *args, **kwargs):
@@ -182,6 +251,7 @@ class test_crontab_parser(AppCase):
         )
         )
         self.assertFalse(object() == self.crontab(minute='1'))
         self.assertFalse(object() == self.crontab(minute='1'))
         self.assertFalse(self.crontab(minute='1') == object())
         self.assertFalse(self.crontab(minute='1') == object())
+        self.assertNotEqual(crontab(month_of_year='1'), schedule(10))
 
 
 
 
 class test_crontab_remaining_estimate(AppCase):
 class test_crontab_remaining_estimate(AppCase):

+ 70 - 1
celery/tests/backends/test_database.py

@@ -10,8 +10,10 @@ from celery.utils import uuid
 
 
 from celery.tests.case import (
 from celery.tests.case import (
     AppCase,
     AppCase,
+    Mock,
     SkipTest,
     SkipTest,
     depends_on_current_app,
     depends_on_current_app,
+    patch,
     skip_if_pypy,
     skip_if_pypy,
     skip_if_jython,
     skip_if_jython,
 )
 )
@@ -21,7 +23,11 @@ try:
 except ImportError:
 except ImportError:
     DatabaseBackend = Task = TaskSet = retry = None  # noqa
     DatabaseBackend = Task = TaskSet = retry = None  # noqa
 else:
 else:
-    from celery.backends.database import DatabaseBackend, retry
+    from celery.backends.database import (
+        DatabaseBackend, retry, session_cleanup,
+    )
+    from celery.backends.database import session
+    from celery.backends.database.session import SessionManager
     from celery.backends.database.models import Task, TaskSet
     from celery.backends.database.models import Task, TaskSet
 
 
 
 
@@ -31,6 +37,23 @@ class SomeClass(object):
         self.data = data
         self.data = data
 
 
 
 
+class test_session_cleanup(AppCase):
+
+    def test_context(self):
+        session = Mock(name='session')
+        with session_cleanup(session):
+            pass
+        session.close.assert_called_with()
+
+    def test_context_raises(self):
+        session = Mock(name='session')
+        with self.assertRaises(KeyError):
+            with session_cleanup(session):
+                raise KeyError()
+        session.rollback.assert_called_with()
+        session.close.assert_called_with()
+
+
 class test_DatabaseBackend(AppCase):
 class test_DatabaseBackend(AppCase):
 
 
     @skip_if_pypy
     @skip_if_pypy
@@ -188,3 +211,49 @@ class test_DatabaseBackend(AppCase):
 
 
     def test_TaskSet__repr__(self):
     def test_TaskSet__repr__(self):
         self.assertIn('foo', repr(TaskSet('foo', None)))
         self.assertIn('foo', repr(TaskSet('foo', None)))
+
+
+class test_SessionManager(AppCase):
+
+    def test_after_fork(self):
+        s = SessionManager()
+        self.assertFalse(s.forked)
+        s._after_fork()
+        self.assertTrue(s.forked)
+
+    @patch('celery.backends.database.session.create_engine')
+    def test_get_engine_forked(self, create_engine):
+        s = SessionManager()
+        s._after_fork()
+        engine = s.get_engine('dburi', foo=1)
+        create_engine.assert_called_with('dburi', foo=1)
+        self.assertIs(engine, create_engine())
+        engine2 = s.get_engine('dburi', foo=1)
+        self.assertIs(engine2, engine)
+
+    @patch('celery.backends.database.session.sessionmaker')
+    def test_create_session_forked(self, sessionmaker):
+        s = SessionManager()
+        s.get_engine = Mock(name='get_engine')
+        s._after_fork()
+        engine, session = s.create_session('dburi', short_lived_sessions=True)
+        sessionmaker.assert_called_with(bind=s.get_engine())
+        self.assertIs(session, sessionmaker())
+        sessionmaker.return_value = Mock(name='new')
+        engine, session2 = s.create_session('dburi', short_lived_sessions=True)
+        sessionmaker.assert_called_with(bind=s.get_engine())
+        self.assertIsNot(session2, session)
+        sessionmaker.return_value = Mock(name='new2')
+        engine, session3 = s.create_session(
+            'dburi', short_lived_sessions=False)
+        sessionmaker.assert_called_with(bind=s.get_engine())
+        self.assertIs(session3, session2)
+
+    def test_coverage_madness(self):
+        prev, session.register_after_fork = (
+            session.register_after_fork, None,
+        )
+        try:
+            SessionManager()
+        finally:
+            session.register_after_fork = prev

+ 3 - 0
celery/tests/backends/test_rpc.py

@@ -43,6 +43,9 @@ class test_RPCBackend(AppCase):
         with self.assertRaises(RuntimeError):
         with self.assertRaises(RuntimeError):
             self.b.destination_for('task_id', None)
             self.b.destination_for('task_id', None)
 
 
+    def test_rkey(self):
+        self.assertEqual(self.b.rkey('id1'), 'id1')
+
     def test_binding(self):
     def test_binding(self):
         queue = self.b.binding
         queue = self.b.binding
         self.assertEqual(queue.name, self.b.oid)
         self.assertEqual(queue.name, self.b.oid)

+ 7 - 6
celery/tests/case.py

@@ -34,7 +34,7 @@ except ImportError:
 from nose import SkipTest
 from nose import SkipTest
 from kombu import Queue
 from kombu import Queue
 from kombu.log import NullHandler
 from kombu.log import NullHandler
-from kombu.utils import nested, symbol_by_name
+from kombu.utils import symbol_by_name
 
 
 from celery import Celery
 from celery import Celery
 from celery.app import current_app
 from celery.app import current_app
@@ -54,7 +54,7 @@ __all__ = [
     'skip_if_environ', 'todo', 'skip', 'skip_if',
     'skip_if_environ', 'todo', 'skip', 'skip_if',
     'skip_unless', 'mask_modules', 'override_stdouts', 'mock_module',
     'skip_unless', 'mask_modules', 'override_stdouts', 'mock_module',
     'replace_module_value', 'sys_platform', 'reset_modules',
     'replace_module_value', 'sys_platform', 'reset_modules',
-    'patch_modules', 'mock_context', 'mock_open', 'patch_many',
+    'patch_modules', 'mock_context', 'mock_open',
     'assert_signal_called', 'skip_if_pypy',
     'assert_signal_called', 'skip_if_pypy',
     'skip_if_jython', 'task_message_from_sig', 'restore_logging',
     'skip_if_jython', 'task_message_from_sig', 'restore_logging',
 ]
 ]
@@ -315,6 +315,11 @@ class Case(unittest.TestCase):
         self.addCleanup(manager.stop)
         self.addCleanup(manager.stop)
         return patched
         return patched
 
 
+    def mock_modules(self, *modules):
+        manager = mock_module(*modules)
+        manager.__enter__()
+        self.addCleanup(partial(manager.__exit__, None, None, None))
+
     def assertWarns(self, expected_warning):
     def assertWarns(self, expected_warning):
         return _AssertWarnsContext(expected_warning, self, None)
         return _AssertWarnsContext(expected_warning, self, None)
 
 
@@ -815,10 +820,6 @@ def mock_open(typ=WhateverIO, side_effect=None):
             yield val
             yield val
 
 
 
 
-def patch_many(*targets):
-    return nested(*[patch(target) for target in targets])
-
-
 @contextmanager
 @contextmanager
 def assert_signal_called(signal, **expected):
 def assert_signal_called(signal, **expected):
     handler = Mock()
     handler = Mock()

+ 46 - 1
celery/tests/concurrency/test_concurrency.py

@@ -5,7 +5,8 @@ import os
 from itertools import count
 from itertools import count
 
 
 from celery.concurrency.base import apply_target, BasePool
 from celery.concurrency.base import apply_target, BasePool
-from celery.tests.case import AppCase, Mock
+from celery.exceptions import WorkerShutdown, WorkerTerminate
+from celery.tests.case import AppCase, Mock, patch
 
 
 
 
 class test_BasePool(AppCase):
 class test_BasePool(AppCase):
@@ -47,6 +48,47 @@ class test_BasePool(AppCase):
                              {'target': (3, (8, 16)),
                              {'target': (3, (8, 16)),
                               'callback': (4, (42,))})
                               'callback': (4, (42,))})
 
 
+    def test_apply_target__propagate(self):
+        target = Mock(name='target')
+        target.side_effect = KeyError()
+        with self.assertRaises(KeyError):
+            apply_target(target, propagate=(KeyError,))
+
+    def test_apply_target__raises(self):
+        target = Mock(name='target')
+        target.side_effect = KeyError()
+        with self.assertRaises(KeyError):
+            apply_target(target)
+
+    def test_apply_target__raises_WorkerShutdown(self):
+        target = Mock(name='target')
+        target.side_effect = WorkerShutdown()
+        with self.assertRaises(WorkerShutdown):
+            apply_target(target)
+
+    def test_apply_target__raises_WorkerTerminate(self):
+        target = Mock(name='target')
+        target.side_effect = WorkerTerminate()
+        with self.assertRaises(WorkerTerminate):
+            apply_target(target)
+
+    def test_apply_target__raises_BaseException(self):
+        target = Mock(name='target')
+        callback = Mock(name='callback')
+        target.side_effect = BaseException()
+        apply_target(target, callback=callback)
+        self.assertTrue(callback.called)
+
+    @patch('celery.concurrency.base.reraise')
+    def test_apply_target__raises_BaseException_raises_else(self, reraise):
+        target = Mock(name='target')
+        callback = Mock(name='callback')
+        reraise.side_effect = KeyError()
+        target.side_effect = BaseException()
+        with self.assertRaises(KeyError):
+            apply_target(target, callback=callback)
+        self.assertFalse(callback.called)
+
     def test_does_not_debug(self):
     def test_does_not_debug(self):
         x = BasePool(10)
         x = BasePool(10)
         x._does_debug = False
         x._does_debug = False
@@ -67,6 +109,9 @@ class test_BasePool(AppCase):
     def test_interface_info(self):
     def test_interface_info(self):
         self.assertDictEqual(BasePool(10).info, {})
         self.assertDictEqual(BasePool(10).info, {})
 
 
+    def test_interface_flush(self):
+        self.assertIsNone(BasePool(10).flush())
+
     def test_active(self):
     def test_active(self):
         p = BasePool(10)
         p = BasePool(10)
         self.assertFalse(p.active)
         self.assertFalse(p.active)

+ 68 - 43
celery/tests/concurrency/test_eventlet.py

@@ -3,29 +3,20 @@ from __future__ import absolute_import
 import os
 import os
 import sys
 import sys
 
 
-from celery.app.defaults import is_pypy
 from celery.concurrency.eventlet import (
 from celery.concurrency.eventlet import (
     apply_target,
     apply_target,
     Timer,
     Timer,
     TaskPool,
     TaskPool,
 )
 )
 
 
-from celery.tests.case import (
-    AppCase, Mock, SkipTest, mock_module, patch, patch_many, skip_if_pypy,
-)
+from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
 
 
 
 
 class EventletCase(AppCase):
 class EventletCase(AppCase):
 
 
     @skip_if_pypy
     @skip_if_pypy
     def setup(self):
     def setup(self):
-        if is_pypy:
-            raise SkipTest('mock_modules not working on PyPy1.9')
-        try:
-            self.eventlet = __import__('eventlet')
-        except ImportError:
-            raise SkipTest(
-                'eventlet not installed, skipping related tests.')
+        self.mock_modules(*eventlet_modules)
 
 
     @skip_if_pypy
     @skip_if_pypy
     def teardown(self):
     def teardown(self):
@@ -68,46 +59,80 @@ eventlet_modules = (
 
 
 class test_Timer(EventletCase):
 class test_Timer(EventletCase):
 
 
+    def setup(self):
+        EventletCase.setup(self)
+        self.spawn_after = self.patch('eventlet.greenthread.spawn_after')
+        self.GreenletExit = self.patch('greenlet.GreenletExit')
+
     def test_sched(self):
     def test_sched(self):
-        with mock_module(*eventlet_modules):
-            with patch_many('eventlet.greenthread.spawn_after',
-                            'greenlet.GreenletExit') as (spawn_after,
-                                                         GreenletExit):
-                x = Timer()
-                x.GreenletExit = KeyError
-                entry = Mock()
-                g = x._enter(1, 0, entry)
-                self.assertTrue(x.queue)
-
-                x._entry_exit(g, entry)
-                g.wait.side_effect = KeyError()
-                x._entry_exit(g, entry)
-                entry.cancel.assert_called_with()
-                self.assertFalse(x._queue)
-
-                x._queue.add(g)
-                x.clear()
-                x._queue.add(g)
-                g.cancel.side_effect = KeyError()
-                x.clear()
+        x = Timer()
+        x.GreenletExit = KeyError
+        entry = Mock()
+        g = x._enter(1, 0, entry)
+        self.assertTrue(x.queue)
+
+        x._entry_exit(g, entry)
+        g.wait.side_effect = KeyError()
+        x._entry_exit(g, entry)
+        entry.cancel.assert_called_with()
+        self.assertFalse(x._queue)
+
+        x._queue.add(g)
+        x.clear()
+        x._queue.add(g)
+        g.cancel.side_effect = KeyError()
+        x.clear()
+
+    def test_cancel(self):
+        x = Timer()
+        tref = Mock(name='tref')
+        x.cancel(tref)
+        tref.cancel.assert_called_with()
+        x.GreenletExit = KeyError
+        tref.cancel.side_effect = KeyError()
+        x.cancel(tref)
 
 
 
 
 class test_TaskPool(EventletCase):
 class test_TaskPool(EventletCase):
 
 
+    def setup(self):
+        EventletCase.setup(self)
+        self.GreenPool = self.patch('eventlet.greenpool.GreenPool')
+        self.greenthread = self.patch('eventlet.greenthread')
+
     def test_pool(self):
     def test_pool(self):
-        with mock_module(*eventlet_modules):
-            with patch_many('eventlet.greenpool.GreenPool',
-                            'eventlet.greenthread') as (GreenPool,
-                                                        greenthread):
-                x = TaskPool()
-                x.on_start()
-                x.on_stop()
-                x.on_apply(Mock())
-                x._pool = None
-                x.on_stop()
-                self.assertTrue(x.getpid())
+        x = TaskPool()
+        x.on_start()
+        x.on_stop()
+        x.on_apply(Mock())
+        x._pool = None
+        x.on_stop()
+        self.assertTrue(x.getpid())
 
 
     @patch('celery.concurrency.eventlet.base')
     @patch('celery.concurrency.eventlet.base')
     def test_apply_target(self, base):
     def test_apply_target(self, base):
         apply_target(Mock(), getpid=Mock())
         apply_target(Mock(), getpid=Mock())
         self.assertTrue(base.apply_target.called)
         self.assertTrue(base.apply_target.called)
+
+    def test_grow(self):
+        x = TaskPool(10)
+        x._pool = Mock(name='_pool')
+        x.grow(2)
+        self.assertEqual(x.limit, 12)
+        x._pool.resize.assert_called_with(12)
+
+    def test_shrink(self):
+        x = TaskPool(10)
+        x._pool = Mock(name='_pool')
+        x.shrink(2)
+        self.assertEqual(x.limit, 8)
+        x._pool.resize.assert_called_with(8)
+
+    def test_get_info(self):
+        x = TaskPool(10)
+        x._pool = Mock(name='_pool')
+        self.assertDictEqual(x._get_info(), {
+            'max-concurrency': 10,
+            'free-threads': x._pool.free(),
+            'running-threads': x._pool.running(),
+        })

+ 62 - 66
celery/tests/concurrency/test_gevent.py

@@ -6,9 +6,7 @@ from celery.concurrency.gevent import (
     apply_timeout,
     apply_timeout,
 )
 )
 
 
-from celery.tests.case import (
-    AppCase, Mock, SkipTest, mock_module, patch, patch_many, skip_if_pypy,
-)
+from celery.tests.case import AppCase, Mock, patch, skip_if_pypy
 
 
 gevent_modules = (
 gevent_modules = (
     'gevent',
     'gevent',
@@ -23,80 +21,78 @@ class GeventCase(AppCase):
 
 
     @skip_if_pypy
     @skip_if_pypy
     def setup(self):
     def setup(self):
-        try:
-            self.gevent = __import__('gevent')
-        except ImportError:
-            raise SkipTest(
-                'gevent not installed, skipping related tests.')
+        self.mock_modules(*gevent_modules)
 
 
 
 
 class test_gevent_patch(GeventCase):
 class test_gevent_patch(GeventCase):
 
 
     def test_is_patched(self):
     def test_is_patched(self):
-        with mock_module(*gevent_modules):
-            with patch('gevent.monkey.patch_all', create=True) as patch_all:
-                import gevent
-                gevent.version_info = (1, 0, 0)
-                from celery import maybe_patch_concurrency
-                maybe_patch_concurrency(['x', '-P', 'gevent'])
-                self.assertTrue(patch_all.called)
+        with patch('gevent.monkey.patch_all', create=True) as patch_all:
+            import gevent
+            gevent.version_info = (1, 0, 0)
+            from celery import maybe_patch_concurrency
+            maybe_patch_concurrency(['x', '-P', 'gevent'])
+            self.assertTrue(patch_all.called)
+
 
 
+class test_Timer(GeventCase):
 
 
-class test_Timer(AppCase):
+    def setup(self):
+        GeventCase.setup(self)
+        self.greenlet = self.patch('gevent.greenlet')
+        self.GreenletExit = self.patch('gevent.greenlet.GreenletExit')
 
 
     def test_sched(self):
     def test_sched(self):
-        with mock_module(*gevent_modules):
-            with patch_many('gevent.greenlet',
-                            'gevent.greenlet.GreenletExit') as (greenlet,
-                                                                GreenletExit):
-                greenlet.Greenlet = object
-                x = Timer()
-                greenlet.Greenlet = Mock()
-                x._Greenlet.spawn_later = Mock()
-                x._GreenletExit = KeyError
-                entry = Mock()
-                g = x._enter(1, 0, entry)
-                self.assertTrue(x.queue)
-
-                x._entry_exit(g)
-                g.kill.assert_called_with()
-                self.assertFalse(x._queue)
-
-                x._queue.add(g)
-                x.clear()
-                x._queue.add(g)
-                g.kill.side_effect = KeyError()
-                x.clear()
-
-                g = x._Greenlet()
-                g.cancel()
-
-
-class test_TaskPool(AppCase):
+        self.greenlet.Greenlet = object
+        x = Timer()
+        self.greenlet.Greenlet = Mock()
+        x._Greenlet.spawn_later = Mock()
+        x._GreenletExit = KeyError
+        entry = Mock()
+        g = x._enter(1, 0, entry)
+        self.assertTrue(x.queue)
+
+        x._entry_exit(g)
+        g.kill.assert_called_with()
+        self.assertFalse(x._queue)
+
+        x._queue.add(g)
+        x.clear()
+        x._queue.add(g)
+        g.kill.side_effect = KeyError()
+        x.clear()
+
+        g = x._Greenlet()
+        g.cancel()
+
+
+class test_TaskPool(GeventCase):
+
+    def setup(self):
+        GeventCase.setup(self)
+        self.spawn_raw = self.patch('gevent.spawn_raw')
+        self.Pool = self.patch('gevent.pool.Pool')
 
 
     def test_pool(self):
     def test_pool(self):
-        with mock_module(*gevent_modules):
-            with patch_many('gevent.spawn_raw', 'gevent.pool.Pool') as (
-                    spawn_raw, Pool):
-                x = TaskPool()
-                x.on_start()
-                x.on_stop()
-                x.on_apply(Mock())
-                x._pool = None
-                x.on_stop()
-
-                x._pool = Mock()
-                x._pool._semaphore.counter = 1
-                x._pool.size = 1
-                x.grow()
-                self.assertEqual(x._pool.size, 2)
-                self.assertEqual(x._pool._semaphore.counter, 2)
-                x.shrink()
-                self.assertEqual(x._pool.size, 1)
-                self.assertEqual(x._pool._semaphore.counter, 1)
-
-                x._pool = [4, 5, 6]
-                self.assertEqual(x.num_processes, 3)
+        x = TaskPool()
+        x.on_start()
+        x.on_stop()
+        x.on_apply(Mock())
+        x._pool = None
+        x.on_stop()
+
+        x._pool = Mock()
+        x._pool._semaphore.counter = 1
+        x._pool.size = 1
+        x.grow()
+        self.assertEqual(x._pool.size, 2)
+        self.assertEqual(x._pool._semaphore.counter, 2)
+        x.shrink()
+        self.assertEqual(x._pool.size, 1)
+        self.assertEqual(x._pool._semaphore.counter, 1)
+
+        x._pool = [4, 5, 6]
+        self.assertEqual(x.num_processes, 3)
 
 
 
 
 class test_apply_timeout(AppCase):
 class test_apply_timeout(AppCase):

+ 98 - 16
celery/tests/concurrency/test_prefork.py

@@ -1,14 +1,16 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 import errno
 import errno
+import os
 import socket
 import socket
-import time
 
 
 from itertools import cycle
 from itertools import cycle
 
 
+from celery.app.defaults import DEFAULTS
+from celery.datastructures import AttributeDict
 from celery.five import items, range
 from celery.five import items, range
 from celery.utils.functional import noop
 from celery.utils.functional import noop
-from celery.tests.case import AppCase, Mock, SkipTest, patch
+from celery.tests.case import AppCase, Mock, SkipTest, patch, restore_logging
 try:
 try:
     from celery.concurrency import prefork as mp
     from celery.concurrency import prefork as mp
     from celery.concurrency import asynpool
     from celery.concurrency import asynpool
@@ -54,6 +56,67 @@ class MockResult(object):
         return self.value
         return self.value
 
 
 
 
+class test_process_initializer(AppCase):
+
+    @patch('celery.platforms.signals')
+    @patch('celery.platforms.set_mp_process_title')
+    def test_process_initializer(self, set_mp_process_title, _signals):
+        with restore_logging():
+            from celery import signals
+            from celery._state import _tls
+            from celery.concurrency.prefork import (
+                process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
+            )
+
+            def on_worker_process_init(**kwargs):
+                on_worker_process_init.called = True
+            on_worker_process_init.called = False
+            signals.worker_process_init.connect(on_worker_process_init)
+
+            def Loader(*args, **kwargs):
+                loader = Mock(*args, **kwargs)
+                loader.conf = {}
+                loader.override_backends = {}
+                return loader
+
+            with self.Celery(loader=Loader) as app:
+                app.conf = AttributeDict(DEFAULTS)
+                process_initializer(app, 'awesome.worker.com')
+                _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
+                _signals.reset.assert_any_call(*WORKER_SIGRESET)
+                self.assertTrue(app.loader.init_worker.call_count)
+                self.assertTrue(on_worker_process_init.called)
+                self.assertIs(_tls.current_app, app)
+                set_mp_process_title.assert_called_with(
+                    'celeryd', hostname='awesome.worker.com',
+                )
+
+                with patch('celery.app.trace.setup_worker_optimizations') as S:
+                    os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
+                    try:
+                        process_initializer(app, 'luke.worker.com')
+                        S.assert_called_with(app, 'luke.worker.com')
+                    finally:
+                        os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
+
+                os.environ['CELERY_LOG_FILE'] = 'worker%I.log'
+                app.log.setup = Mock(name='log_setup')
+                try:
+                    process_initializer(app, 'luke.worker.com')
+                finally:
+                    os.environ.pop('CELERY_LOG_FILE', None)
+
+
+class test_process_destructor(AppCase):
+
+    @patch('celery.concurrency.prefork.signals')
+    def test_process_destructor(self, signals):
+        mp.process_destructor(13, -3)
+        signals.worker_process_shutdown.send.assert_called_with(
+            sender=None, pid=13, exitcode=-3,
+        )
+
+
 class MockPool(object):
 class MockPool(object):
     started = False
     started = False
     closed = False
     closed = False
@@ -284,6 +347,39 @@ class test_TaskPool(PoolCase):
         pool.terminate()
         pool.terminate()
         self.assertTrue(_pool.terminated)
         self.assertTrue(_pool.terminated)
 
 
+    def test_restart(self):
+        pool = TaskPool(10)
+        pool._pool = Mock(name='pool')
+        pool.restart()
+        pool._pool.restart.assert_called_with()
+        pool._pool.apply_async.assert_called_with(mp.noop)
+
+    def test_did_start_ok(self):
+        pool = TaskPool(10)
+        pool._pool = Mock(name='pool')
+        self.assertIs(pool.did_start_ok(), pool._pool.did_start_ok())
+
+    def test_register_with_event_loop(self):
+        pool = TaskPool(10)
+        pool._pool = Mock(name='pool')
+        loop = Mock(name='loop')
+        pool.register_with_event_loop(loop)
+        pool._pool.register_with_event_loop.assert_called_with(loop)
+
+    def test_on_close(self):
+        pool = TaskPool(10)
+        pool._pool = Mock(name='pool')
+        pool._pool._state = mp.RUN
+        pool.on_close()
+        pool._pool.close.assert_called_with()
+
+    def test_on_close__pool_not_running(self):
+        pool = TaskPool(10)
+        pool._pool = Mock(name='pool')
+        pool._pool._state = mp.CLOSE
+        pool.on_close()
+        self.assertFalse(pool._pool.close.called)
+
     def test_apply_async(self):
     def test_apply_async(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
         pool.start()
         pool.start()
@@ -320,17 +416,3 @@ class test_TaskPool(PoolCase):
         pool = TaskPool(7)
         pool = TaskPool(7)
         pool.start()
         pool.start()
         self.assertEqual(pool.num_processes, 7)
         self.assertEqual(pool.num_processes, 7)
-
-    def test_restart(self):
-        raise SkipTest('functional test')
-
-        def get_pids(pool):
-            return {p.pid for p in pool._pool._pool}
-
-        tp = self.TaskPool(5)
-        time.sleep(0.5)
-        tp.start()
-        pids = get_pids(tp)
-        tp.restart()
-        time.sleep(0.5)
-        self.assertEqual(pids, get_pids(tp))

+ 11 - 10
celery/tests/fixups/test_django.py

@@ -12,7 +12,7 @@ from celery.fixups.django import (
 )
 )
 
 
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, Mock, patch, patch_many, patch_modules, mask_modules,
+    AppCase, Mock, patch, patch_modules, mask_modules,
 )
 )
 
 
 
 
@@ -63,15 +63,16 @@ class test_DjangoFixup(FixupCase):
 
 
     def test_install(self):
     def test_install(self):
         self.app.loader = Mock()
         self.app.loader = Mock()
+        self.cw = self.patch('os.getcwd')
+        self.p = self.patch('sys.path')
+        self.sigs = self.patch('celery.fixups.django.signals')
         with self.fixup_context(self.app) as (f, _, _):
         with self.fixup_context(self.app) as (f, _, _):
-            with patch_many('os.getcwd', 'sys.path',
-                            'celery.fixups.django.signals') as (cw, p, sigs):
-                cw.return_value = '/opt/vandelay'
-                f.install()
-                sigs.worker_init.connect.assert_called_with(f.on_worker_init)
-                self.assertEqual(self.app.loader.now, f.now)
-                self.assertEqual(self.app.loader.mail_admins, f.mail_admins)
-                p.append.assert_called_with('/opt/vandelay')
+            self.cw.return_value = '/opt/vandelay'
+            f.install()
+            self.sigs.worker_init.connect.assert_called_with(f.on_worker_init)
+            self.assertEqual(self.app.loader.now, f.now)
+            self.assertEqual(self.app.loader.mail_admins, f.mail_admins)
+            self.p.append.assert_called_with('/opt/vandelay')
 
 
     def test_now(self):
     def test_now(self):
         with self.fixup_context(self.app) as (f, _, _):
         with self.fixup_context(self.app) as (f, _, _):
@@ -114,7 +115,7 @@ class test_DjangoWorkerFixup(FixupCase):
         self.app.conf = {'CELERY_DB_REUSE_MAX': None}
         self.app.conf = {'CELERY_DB_REUSE_MAX': None}
         self.app.loader = Mock()
         self.app.loader = Mock()
         with self.fixup_context(self.app) as (f, _, _):
         with self.fixup_context(self.app) as (f, _, _):
-            with patch_many('celery.fixups.django.signals') as (sigs,):
+            with patch('celery.fixups.django.signals') as sigs:
                 f.install()
                 f.install()
                 sigs.beat_embedded_init.connect.assert_called_with(
                 sigs.beat_embedded_init.connect.assert_called_with(
                     f.close_database,
                     f.close_database,

+ 5 - 0
celery/tests/security/test_certificate.py

@@ -26,6 +26,11 @@ class test_Certificate(SecurityCase):
         raise SkipTest('cert expired')
         raise SkipTest('cert expired')
         self.assertFalse(Certificate(CERT1).has_expired())
         self.assertFalse(Certificate(CERT1).has_expired())
 
 
+    def test_has_expired_mock(self):
+        x = Certificate(CERT1)
+        x._cert = Mock(name='cert')
+        self.assertIs(x.has_expired(), x._cert.has_expired())
+
 
 
 class test_CertStore(SecurityCase):
 class test_CertStore(SecurityCase):
 
 

+ 10 - 0
celery/tests/security/test_security.py

@@ -20,6 +20,7 @@ from kombu.serialization import disable_insecure_serializers
 
 
 from celery.exceptions import ImproperlyConfigured, SecurityError
 from celery.exceptions import ImproperlyConfigured, SecurityError
 from celery.five import builtins
 from celery.five import builtins
+from celery.security import disable_untrusted_serializers, setup_security
 from celery.security.utils import reraise_errors
 from celery.security.utils import reraise_errors
 from kombu.serialization import registry
 from kombu.serialization import registry
 
 
@@ -53,6 +54,11 @@ class test_security(SecurityCase):
         finally:
         finally:
             disable_insecure_serializers(allowed=['json'])
             disable_insecure_serializers(allowed=['json'])
 
 
+    @patch('celery.security._disable_insecure_serializers')
+    def test_disable_untrusted_serializers(self, disable):
+        disable_untrusted_serializers(['foo'])
+        disable.assert_called_with(allowed=['foo'])
+
     def test_setup_security(self):
     def test_setup_security(self):
         disabled = registry._disabled_content_types
         disabled = registry._disabled_content_types
         self.assertEqual(0, len(disabled))
         self.assertEqual(0, len(disabled))
@@ -62,6 +68,10 @@ class test_security(SecurityCase):
         self.assertIn('application/x-python-serialize', disabled)
         self.assertIn('application/x-python-serialize', disabled)
         disabled.clear()
         disabled.clear()
 
 
+    @patch('celery.current_app')
+    def test_setup_security__default_app(self, current_app):
+        setup_security()
+
     @patch('celery.security.register_auth')
     @patch('celery.security.register_auth')
     @patch('celery.security._disable_insecure_serializers')
     @patch('celery.security._disable_insecure_serializers')
     def test_setup_registry_complete(self, dis, reg, key='KEY', cert='CERT'):
     def test_setup_registry_complete(self, dis, reg, key='KEY', cert='CERT'):

+ 61 - 2
celery/tests/tasks/test_tasks.py

@@ -6,13 +6,17 @@ from kombu import Queue
 
 
 from celery import Task
 from celery import Task
 
 
-from celery.exceptions import Retry
+from celery import group
+from celery.app.task import _reprtask
+from celery.exceptions import Ignore, Retry
 from celery.five import items, range, string_t
 from celery.five import items, range, string_t
 from celery.result import EagerResult
 from celery.result import EagerResult
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.timeutils import parse_iso8601
 from celery.utils.timeutils import parse_iso8601
 
 
-from celery.tests.case import AppCase, depends_on_current_app, patch
+from celery.tests.case import (
+    AppCase, ContextMock, Mock, depends_on_current_app, patch,
+)
 
 
 
 
 def return_True(*args, **kwargs):
 def return_True(*args, **kwargs):
@@ -269,6 +273,20 @@ class test_tasks(TasksCase):
             pass
             pass
         self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])
         self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])
 
 
+    @patch('celery.app.task.current_app')
+    @depends_on_current_app
+    def test_bind__no_app(self, current_app):
+        class XTask(Task):
+            _app = None
+        XTask._app = None
+        XTask.__bound__ = False
+        XTask.bind = Mock(name='bind')
+        self.assertIs(XTask.app, current_app)
+        XTask.bind.assert_called_with(current_app)
+
+    def test_reprtask__no_fmt(self):
+        self.assertTrue(_reprtask(self.mytask))
+
     def test_AsyncResult(self):
     def test_AsyncResult(self):
         task_id = uuid()
         task_id = uuid()
         result = self.retry_task.AsyncResult(task_id)
         result = self.retry_task.AsyncResult(task_id)
@@ -375,6 +393,47 @@ class test_tasks(TasksCase):
             self.mytask.backend.mark_as_done(presult.id, result=None)
             self.mytask.backend.mark_as_done(presult.id, result=None)
             self.assertTrue(presult.successful())
             self.assertTrue(presult.successful())
 
 
+    def test_send_event(self):
+        mytask = self.mytask._get_current_object()
+        mytask.app.events = Mock(name='events')
+        mytask.app.events.attach_mock(ContextMock(), 'default_dispatcher')
+        mytask.request.id = 'fb'
+        mytask.send_event('task-foo', id=3122)
+        mytask.app.events.default_dispatcher().send.assert_called_with(
+            'task-foo', uuid='fb', id=3122,
+        )
+
+    def test_replace(self):
+        sig1 = Mock(name='sig1')
+        with self.assertRaises(Ignore):
+            self.mytask.replace(sig1)
+
+    def test_replace__group(self):
+        c = group([self.mytask.s()], app=self.app)
+        c.freeze = Mock(name='freeze')
+        c.delay = Mock(name='delay')
+        self.mytask.request.id = 'id'
+        self.mytask.request.group = 'group'
+        self.mytask.request.root_id = 'root_id',
+        with self.assertRaises(Ignore):
+            self.mytask.replace(c)
+
+    def test_send_error_email_enabled(self):
+        mytask = self.increment_counter._get_current_object()
+        mytask.send_error_emails = True
+        mytask.disable_error_emails = False
+        mytask.ErrorMail = Mock(name='ErrorMail')
+        context = Mock(name='context')
+        exc = Mock(name='context')
+        mytask.send_error_email(context, exc, foo=1)
+        mytask.ErrorMail.assert_called_with(mytask, foo=1)
+        mytask.ErrorMail().send.assert_called_with(context, exc)
+
+    def test_add_trail__no_trail(self):
+        mytask = self.increment_counter._get_current_object()
+        mytask.trail = False
+        mytask.add_trail('foo')
+
     def test_repr_v2_compat(self):
     def test_repr_v2_compat(self):
         self.mytask.__v2_compat__ = True
         self.mytask.__v2_compat__ = True
         self.assertIn('v2 compatible', repr(self.mytask))
         self.assertIn('v2 compatible', repr(self.mytask))

+ 113 - 2
celery/tests/tasks/test_trace.py

@@ -1,12 +1,20 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-from celery import uuid
+from kombu.exceptions import EncodeError
+
+from celery import group, uuid
 from celery import signals
 from celery import signals
 from celery import states
 from celery import states
-from celery.exceptions import Ignore, Retry
+from celery.exceptions import Ignore, Retry, Reject
 from celery.app.trace import (
 from celery.app.trace import (
     TraceInfo,
     TraceInfo,
     build_tracer,
     build_tracer,
+    get_log_policy,
+    log_policy_reject,
+    log_policy_ignore,
+    log_policy_internal,
+    log_policy_expected,
+    log_policy_unexpected,
     trace_task,
     trace_task,
     setup_worker_optimizations,
     setup_worker_optimizations,
     reset_worker_optimizations,
     reset_worker_optimizations,
@@ -60,6 +68,33 @@ class test_trace(TraceCase):
         self.trace(add_with_success, (2, 2), {})
         self.trace(add_with_success, (2, 2), {})
         self.assertTrue(add_with_success.on_success.called)
         self.assertTrue(add_with_success.on_success.called)
 
 
+    def test_get_log_policy(self):
+        einfo = Mock(name='einfo')
+        einfo.internal = False
+        self.assertIs(
+            get_log_policy(self.add, einfo, Reject()),
+            log_policy_reject,
+        )
+        self.assertIs(
+            get_log_policy(self.add, einfo, Ignore()),
+            log_policy_ignore,
+        )
+        self.add.throws = (TypeError,)
+        self.assertIs(
+            get_log_policy(self.add, einfo, KeyError()),
+            log_policy_unexpected,
+        )
+        self.assertIs(
+            get_log_policy(self.add, einfo, TypeError()),
+            log_policy_expected,
+        )
+        einfo2 = Mock(name='einfo2')
+        einfo2.internal = True
+        self.assertIs(
+            get_log_policy(self.add, einfo2, KeyError()),
+            log_policy_internal,
+        )
+
     def test_trace_after_return(self):
     def test_trace_after_return(self):
 
 
         @self.app.task(shared=False, after_return=Mock())
         @self.app.task(shared=False, after_return=Mock())
@@ -134,6 +169,74 @@ class test_trace(TraceCase):
         retval, info = self.trace(ignored, (), {})
         retval, info = self.trace(ignored, (), {})
         self.assertEqual(info.state, states.IGNORED)
         self.assertEqual(info.state, states.IGNORED)
 
 
+    def test_when_Reject(self):
+
+        @self.app.task(shared=False)
+        def rejecting():
+            raise Reject()
+
+        retval, info = self.trace(rejecting, (), {})
+        self.assertEqual(info.state, states.REJECTED)
+
+    @patch('celery.canvas.maybe_signature')
+    def test_callbacks__scalar(self, maybe_signature):
+        sig = Mock(name='sig')
+        request = {'callbacks': [sig], 'root_id': 'root'}
+        maybe_signature.return_value = sig
+        retval, _ = self.trace(self.add, (2, 2), {}, request=request)
+        sig.apply_async.assert_called_with(
+            (4,), parent_id='id-1', root_id='root',
+        )
+
+    @patch('celery.canvas.maybe_signature')
+    def test_callbacks__EncodeError(self, maybe_signature):
+        sig = Mock(name='sig')
+        request = {'callbacks': [sig], 'root_id': 'root'}
+        maybe_signature.return_value = sig
+        sig.apply_async.side_effect = EncodeError()
+        retval, einfo = self.trace(self.add, (2, 2), {}, request=request)
+        self.assertEqual(einfo.state, states.FAILURE)
+
+    @patch('celery.canvas.maybe_signature')
+    @patch('celery.app.trace.group.apply_async')
+    def test_callbacks__sigs(self, group_, maybe_signature):
+        sig1 = Mock(name='sig')
+        sig2 = Mock(name='sig2')
+        sig3 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
+        sig3.apply_async = Mock(name='gapply')
+        request = {'callbacks': [sig1, sig3, sig2], 'root_id': 'root'}
+
+        def passt(s, *args, **kwargs):
+            return s
+        maybe_signature.side_effect = passt
+        retval, _ = self.trace(self.add, (2, 2), {}, request=request)
+        group_.assert_called_with(
+            (4,), parent_id='id-1', root_id='root',
+        )
+        sig3.apply_async.assert_called_with(
+            (4,), parent_id='id-1', root_id='root',
+        )
+
+    @patch('celery.canvas.maybe_signature')
+    @patch('celery.app.trace.group.apply_async')
+    def test_callbacks__only_groups(self, group_, maybe_signature):
+        sig1 = group([Mock(name='g1'), Mock(name='g2')], app=self.app)
+        sig2 = group([Mock(name='g3'), Mock(name='g4')], app=self.app)
+        sig1.apply_async = Mock(name='gapply')
+        sig2.apply_async = Mock(name='gapply')
+        request = {'callbacks': [sig1, sig2], 'root_id': 'root'}
+
+        def passt(s, *args, **kwargs):
+            return s
+        maybe_signature.side_effect = passt
+        retval, _ = self.trace(self.add, (2, 2), {}, request=request)
+        sig1.apply_async.assert_called_with(
+            (4,), parent_id='id-1', root_id='root',
+        )
+        sig2.apply_async.assert_called_with(
+            (4,), parent_id='id-1', root_id='root',
+        )
+
     def test_trace_SystemExit(self):
     def test_trace_SystemExit(self):
         with self.assertRaises(SystemExit):
         with self.assertRaises(SystemExit):
             self.trace(self.raises, (SystemExit(),), {})
             self.trace(self.raises, (SystemExit(),), {})
@@ -184,6 +287,14 @@ class test_TraceInfo(TraceCase):
             store_errors=self.add_cast.store_errors_even_if_ignored,
             store_errors=self.add_cast.store_errors_even_if_ignored,
         )
         )
 
 
+    @patch('celery.app.trace.ExceptionInfo')
+    def test_handle_reject(self, ExceptionInfo):
+        x = self.TI(states.FAILURE)
+        x._log_error = Mock(name='log_error')
+        req = Mock(name='req')
+        x.handle_reject(self.add, req)
+        x._log_error.assert_called_with(self.add, req, ExceptionInfo())
+
 
 
 class test_stackprotection(AppCase):
 class test_stackprotection(AppCase):
 
 

+ 98 - 0
celery/tests/utils/test_debug.py

@@ -0,0 +1,98 @@
+from __future__ import absolute_import, unicode_literals
+
+from celery.utils import debug
+
+from celery.tests.case import Case, Mock, patch
+
+
+class test_on_blocking(Case):
+
+    @patch('inspect.getframeinfo')
+    def test_on_blocking(self, getframeinfo):
+        frame = Mock(name='frame')
+        with self.assertRaises(RuntimeError):
+            debug._on_blocking(1, frame)
+            getframeinfo.assert_called_with(frame)
+
+
+class test_blockdetection(Case):
+
+    @patch('celery.utils.debug.signals')
+    def test_context(self, signals):
+        with debug.blockdetection(10):
+            signals.arm_alarm.assert_called_with(10)
+            signals.__setitem__.assert_called_with('ALRM', debug._on_blocking)
+        signals.__setitem__.assert_called_with('ALRM', signals['ALRM'])
+        signals.reset_alarm.assert_called_with()
+
+
+class test_sample_mem(Case):
+
+    @patch('celery.utils.debug.mem_rss')
+    def test_sample_mem(self, mem_rss):
+        prev, debug._mem_sample = debug._mem_sample, []
+        try:
+            debug.sample_mem()
+            self.assertIs(debug._mem_sample[0], mem_rss())
+        finally:
+            debug._mem_sample = prev
+
+
+class test_sample(Case):
+
+    def test_sample(self):
+        x = list(range(100))
+        self.assertEqual(
+            list(debug.sample(x, 10)),
+            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
+        )
+        x = list(range(91))
+        self.assertEqual(
+            list(debug.sample(x, 10)),
+            [0, 9, 18, 27, 36, 45, 54, 63, 72, 81],
+        )
+
+
+class test_hfloat(Case):
+
+    def test_hfloat(self):
+        self.assertEqual(str(debug.hfloat(10, 5)), "10")
+        self.assertEqual(str(debug.hfloat(10.45645234234, 5)), "10.456")
+
+
+class test_humanbytes(Case):
+
+    def test_humanbytes(self):
+        self.assertEqual(debug.humanbytes(2 ** 20), "1MB")
+        self.assertEqual(debug.humanbytes(4 * 2 ** 20), "4MB")
+        self.assertEqual(debug.humanbytes(2 ** 16), "64kB")
+        self.assertEqual(debug.humanbytes(2 ** 16), "64kB")
+        self.assertEqual(debug.humanbytes(2 ** 8), "256b")
+
+
+class test_mem_rss(Case):
+
+    @patch('celery.utils.debug.ps')
+    @patch('celery.utils.debug.humanbytes')
+    def test_mem_rss(self, humanbytes, ps):
+        ret = debug.mem_rss()
+        ps.assert_called_with()
+        ps().get_memory_info.assert_called_with()
+        humanbytes.assert_called_with(ps().get_memory_info().rss)
+        self.assertIs(ret, humanbytes())
+        ps.return_value = None
+        self.assertIsNone(debug.mem_rss())
+
+
+class test_ps(Case):
+
+    @patch('celery.utils.debug.Process')
+    @patch('os.getpid')
+    def test_ps(self, getpid, Process):
+        prev, debug._process = debug._process, None
+        try:
+            debug.ps()
+            Process.assert_called_with(getpid())
+            self.assertIs(debug._process, Process())
+        finally:
+            debug._process = prev

+ 31 - 1
celery/tests/utils/test_mail.py

@@ -1,6 +1,6 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-from celery.utils.mail import Message, Mailer, SSLError
+from celery.utils.mail import Message, Mailer, SSLError, ErrorMail
 
 
 from celery.tests.case import Case, Mock, patch
 from celery.tests.case import Case, Mock, patch
 
 
@@ -51,3 +51,33 @@ class test_Mailer(Case):
         client.quit.side_effect = SSLError()
         client.quit.side_effect = SSLError()
         mailer._send(msg)
         mailer._send(msg)
         client.close.assert_called_with()
         client.close.assert_called_with()
+
+
+class test_ErrorMail(Case):
+
+    def setUp(self):
+        self.task = Mock(name='task')
+        self.mailer = ErrorMail(
+            self.task, subject='foo{foo} ', body='bar{bar} ',
+        )
+
+    def test_should_send(self):
+        self.assertTrue(self.mailer.should_send(Mock(), Mock()))
+
+    def test_format_subject(self):
+        self.assertEqual(
+            self.mailer.format_subject({'foo': 'FOO'}),
+            'fooFOO',
+        )
+
+    def test_format_body(self):
+        self.assertEqual(
+            self.mailer.format_body({'bar': 'BAR'}),
+            'barBAR',
+        )
+
+    def test_send(self):
+        self.mailer.send({'foo': 'FOO', 'bar': 'BAR'}, KeyError())
+        self.task.app.mail_admins.assert_called_with(
+            'fooFOO', 'barBAR', fail_silently=True,
+        )

+ 5 - 0
celery/tests/utils/test_text.py

@@ -7,6 +7,7 @@ from celery.utils.text import (
     indent,
     indent,
     pretty,
     pretty,
     truncate,
     truncate,
+    truncate_bytes,
 )
 )
 
 
 from celery.tests.case import AppCase, Case
 from celery.tests.case import AppCase, Case
@@ -68,6 +69,10 @@ class test_utils(Case):
         self.assertEqual(truncate('ABCDEFGHI', 3), 'ABC...')
         self.assertEqual(truncate('ABCDEFGHI', 3), 'ABC...')
         self.assertEqual(truncate('ABCDEFGHI', 10), 'ABCDEFGHI')
         self.assertEqual(truncate('ABCDEFGHI', 10), 'ABCDEFGHI')
 
 
+    def test_truncate_bytes(self):
+        self.assertEqual(truncate_bytes(b'ABCDEFGHI', 3), b'ABC...')
+        self.assertEqual(truncate_bytes(b'ABCDEFGHI', 10), b'ABCDEFGHI')
+
     def test_abbr(self):
     def test_abbr(self):
         self.assertEqual(abbr(None, 3), '???')
         self.assertEqual(abbr(None, 3), '???')
         self.assertEqual(abbr('ABCDEFGHI', 6), 'ABC...')
         self.assertEqual(abbr('ABCDEFGHI', 6), 'ABC...')

+ 75 - 1
celery/tests/utils/test_utils.py

@@ -8,6 +8,8 @@ from kombu import Queue
 
 
 from celery.utils import (
 from celery.utils import (
     chunks,
     chunks,
+    deprecated_property,
+    isatty,
     is_iterable,
     is_iterable,
     cached_property,
     cached_property,
     warn_deprecated,
     warn_deprecated,
@@ -22,6 +24,15 @@ def double(x):
     return x * 2
     return x * 2
 
 
 
 
+class test_isatty(Case):
+
+    def test_tty(self):
+        fh = Mock(name='fh')
+        self.assertIs(isatty(fh), fh.isatty())
+        fh.isatty.side_effect = AttributeError()
+        self.assertFalse(isatty(fh))
+
+
 class test_worker_direct(Case):
 class test_worker_direct(Case):
 
 
     def test_returns_if_queue(self):
     def test_returns_if_queue(self):
@@ -29,6 +40,61 @@ class test_worker_direct(Case):
         self.assertIs(worker_direct(q), q)
         self.assertIs(worker_direct(q), q)
 
 
 
 
+class test_deprecated_property(Case):
+
+    @patch('celery.utils.warn_deprecated')
+    def test_deprecated(self, warn_deprecated):
+
+        class X(object):
+            _foo = None
+
+            @deprecated_property(deprecation='1.2')
+            def foo(self):
+                return self._foo
+
+            @foo.setter
+            def foo(self, value):
+                self._foo = value
+
+            @foo.deleter
+            def foo(self):
+                self._foo = None
+        self.assertTrue(X.foo)
+        self.assertTrue(X.foo.__set__(None, 1))
+        self.assertTrue(X.foo.__delete__(None))
+        x = X()
+        x.foo = 10
+        warn_deprecated.assert_called_with(
+            stacklevel=3, deprecation='1.2', alternative=None,
+            description='foo', removal=None,
+        )
+        warn_deprecated.reset_mock()
+        self.assertEqual(x.foo, 10)
+        warn_deprecated.assert_called_with(
+            stacklevel=3, deprecation='1.2', alternative=None,
+            description='foo', removal=None,
+        )
+        warn_deprecated.reset_mock()
+        del(x.foo)
+        warn_deprecated.assert_called_with(
+            stacklevel=3, deprecation='1.2', alternative=None,
+            description='foo', removal=None,
+        )
+        self.assertIsNone(x._foo)
+
+    def test_deprecated_no_setter_or_deleter(self):
+        class X(object):
+            @deprecated_property(deprecation='1.2')
+            def foo(self):
+                pass
+        self.assertTrue(X.foo)
+        x = X()
+        with self.assertRaises(AttributeError):
+            x.foo = 10
+        with self.assertRaises(AttributeError):
+            del(x.foo)
+
+
 class test_gen_task_name(Case):
 class test_gen_task_name(Case):
 
 
     def test_no_module(self):
     def test_no_module(self):
@@ -54,8 +120,16 @@ class test_jsonify(Case):
         self.assertTrue(jsonify(10.3))
         self.assertTrue(jsonify(10.3))
         self.assertTrue(jsonify('hello'))
         self.assertTrue(jsonify('hello'))
 
 
+        unknown_type_filter = Mock()
+        obj = object()
+        self.assertIs(
+            jsonify(obj, unknown_type_filter=unknown_type_filter),
+            unknown_type_filter.return_value,
+        )
+        unknown_type_filter.assert_called_with(obj)
+
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
-            jsonify(object())
+            jsonify(obj)
 
 
 
 
 class test_chunks(Case):
 class test_chunks(Case):

+ 1 - 1
celery/tests/worker/test_autoscale.py

@@ -134,7 +134,7 @@ class test_Autoscaler(AppCase):
         x.scale_up(3)
         x.scale_up(3)
         x._last_action = monotonic() - 10000
         x._last_action = monotonic() - 10000
         x.pool.shrink_raises_exception = True
         x.pool.shrink_raises_exception = True
-        x.scale_down(1)
+        x._shrink(1)
 
 
     @patch('celery.worker.autoscale.debug')
     @patch('celery.worker.autoscale.debug')
     def test_shrink_raises_ValueError(self, debug):
     def test_shrink_raises_ValueError(self, debug):

+ 49 - 0
celery/tests/worker/test_consumer.py

@@ -41,6 +41,9 @@ class test_Consumer(AppCase):
         consumer.conninfo = consumer.connection
         consumer.conninfo = consumer.connection
         return consumer
         return consumer
 
 
+    def test_repr(self):
+        self.assertTrue(repr(self.get_consumer()))
+
     def test_taskbuckets_defaultdict(self):
     def test_taskbuckets_defaultdict(self):
         c = self.get_consumer()
         c = self.get_consumer()
         self.assertIsNone(c.task_buckets['fooxasdwx.wewe'])
         self.assertIsNone(c.task_buckets['fooxasdwx.wewe'])
@@ -68,6 +71,44 @@ class test_Consumer(AppCase):
             self.get_consumer()
             self.get_consumer()
             self.assertIsNone(self.app.conf.broker_connection_timeout)
             self.assertIsNone(self.app.conf.broker_connection_timeout)
 
 
+    def test_limit_moved_to_pool(self):
+        with patch('celery.worker.consumer.task_reserved') as reserved:
+            c = self.get_consumer()
+            c.on_task_request = Mock(name='on_task_request')
+            request = Mock(name='request')
+            c._limit_move_to_pool(request)
+            reserved.assert_called_with(request)
+            c.on_task_request.assert_called_with(request)
+
+    def test_update_prefetch_count(self):
+        c = self.get_consumer()
+        c._update_qos_eventually = Mock(name='update_qos')
+        c.initial_prefetch_count = None
+        c.pool.num_processes = None
+        c.prefetch_multiplier = 10
+        self.assertIsNone(c._update_prefetch_count(1))
+        c.initial_prefetch_count = 10
+        c.pool.num_processes = 10
+        c._update_prefetch_count(8)
+        c._update_qos_eventually.assert_called_with(8)
+        self.assertEqual(c.initial_prefetch_count, 10 * 10)
+
+    def test_flush_events(self):
+        c = self.get_consumer()
+        c.event_dispatcher = None
+        c._flush_events()
+        c.event_dispatcher = Mock(name='evd')
+        c._flush_events()
+        c.event_dispatcher.flush.assert_called_with()
+
+    def test_on_send_event_buffered(self):
+        c = self.get_consumer()
+        c.hub = None
+        c.on_send_event_buffered()
+        c.hub = Mock(name='hub')
+        c.on_send_event_buffered()
+        c.hub._ready.add.assert_called_with(c._flush_events)
+
     def test_limit_task(self):
     def test_limit_task(self):
         c = self.get_consumer()
         c = self.get_consumer()
 
 
@@ -460,6 +501,14 @@ class test_Gossip(AppCase):
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             state.workers['foo']
             state.workers['foo']
 
 
+    def test_on_message__task(self):
+        c = self.Consumer()
+        g = Gossip(c)
+        self.assertTrue(g.enabled)
+        message = Mock(name='message')
+        message.delivery_info = {'routing_key': 'task.failed'}
+        g.on_message(Mock(name='prepare'), message)
+
     def test_on_message(self):
     def test_on_message(self):
         c = self.Consumer()
         c = self.Consumer()
         g = Gossip(c)
         g = Gossip(c)

+ 73 - 8
celery/tests/worker/test_control.py

@@ -18,7 +18,6 @@ from celery.worker import control
 from celery.worker import state as worker_state
 from celery.worker import state as worker_state
 from celery.worker.request import Request
 from celery.worker.request import Request
 from celery.worker.state import revoked
 from celery.worker.state import revoked
-from celery.worker.control import Panel
 from celery.worker.pidbox import Pidbox, gPidbox
 from celery.worker.pidbox import Pidbox, gPidbox
 
 
 from celery.tests.case import AppCase, Mock, TaskMessage, call, patch
 from celery.tests.case import AppCase, Mock, TaskMessage, call, patch
@@ -132,7 +131,7 @@ class test_ControlPanel(AppCase):
     def create_panel(self, **kwargs):
     def create_panel(self, **kwargs):
         return self.app.control.mailbox.Node(hostname=hostname,
         return self.app.control.mailbox.Node(hostname=hostname,
                                              state=self.create_state(**kwargs),
                                              state=self.create_state(**kwargs),
-                                             handlers=Panel.data)
+                                             handlers=control.Panel.data)
 
 
     def test_enable_events(self):
     def test_enable_events(self):
         consumer = Consumer(self.app)
         consumer = Consumer(self.app)
@@ -168,21 +167,36 @@ class test_ControlPanel(AppCase):
         consumer = Consumer(self.app)
         consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
         panel.state.app.clock.value = 313
         panel.state.app.clock.value = 313
+        panel.state.hostname = 'elaine@vandelay.com'
         worker_state.revoked.add('revoked1')
         worker_state.revoked.add('revoked1')
         try:
         try:
-            x = panel.handle('hello', {'from_node': 'george@vandelay.com'})
-            self.assertIn('revoked1', x['revoked'])
+            self.assertIsNone(panel.handle('hello', {
+                'from_node': 'elaine@vandelay.com',
+            }))
+            x = panel.handle('hello', {
+                'from_node': 'george@vandelay.com',
+            })
             self.assertEqual(x['clock'], 314)  # incremented
             self.assertEqual(x['clock'], 314)  # incremented
+            x = panel.handle('hello', {
+                'from_node': 'george@vandelay.com',
+                'revoked': {'1234', '4567', '891'}
+            })
+            self.assertIn('revoked1', x['revoked'])
+            self.assertIn('1234', x['revoked'])
+            self.assertIn('4567', x['revoked'])
+            self.assertIn('891', x['revoked'])
+            self.assertEqual(x['clock'], 315)  # incremented
         finally:
         finally:
             worker_state.revoked.discard('revoked1')
             worker_state.revoked.discard('revoked1')
 
 
     def test_conf(self):
     def test_conf(self):
-        return
         consumer = Consumer(self.app)
         consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
-        self.app.conf.SOME_KEY6 = 'hello world'
+        panel.app = self.app
+        panel.app.finalize()
+        self.app.conf.some_key6 = 'hello world'
         x = panel.handle('dump_conf')
         x = panel.handle('dump_conf')
-        self.assertIn('SOME_KEY6', x)
+        self.assertIn('some_key6', x)
 
 
     def test_election(self):
     def test_election(self):
         consumer = Consumer(self.app)
         consumer = Consumer(self.app)
@@ -193,6 +207,14 @@ class test_ControlPanel(AppCase):
         )
         )
         consumer.gossip.election.assert_called_with('id', 'topic', 'action')
         consumer.gossip.election.assert_called_with('id', 'topic', 'action')
 
 
+    def test_election__no_gossip(self):
+        consumer = Mock(name='consumer')
+        consumer.gossip = None
+        panel = self.create_panel(consumer=consumer)
+        panel.handle(
+            'election', {'id': 'id', 'topic': 'topic', 'action': 'action'},
+        )
+
     def test_heartbeat(self):
     def test_heartbeat(self):
         consumer = Consumer(self.app)
         consumer = Consumer(self.app)
         panel = self.create_panel(consumer=consumer)
         panel = self.create_panel(consumer=consumer)
@@ -236,11 +258,27 @@ class test_ControlPanel(AppCase):
         self.assertListEqual(list(sorted(q['name'] for q in r)),
         self.assertListEqual(list(sorted(q['name'] for q in r)),
                              ['bar', 'foo'])
                              ['bar', 'foo'])
 
 
+    def test_active_queues__empty(self):
+        consumer = Mock(name='consumer')
+        panel = self.create_panel(consumer=consumer)
+        consumer.task_consumer = None
+        self.assertFalse(panel.handle('active_queues'))
+
     def test_dump_tasks(self):
     def test_dump_tasks(self):
         info = '\n'.join(self.panel.handle('dump_tasks'))
         info = '\n'.join(self.panel.handle('dump_tasks'))
         self.assertIn('mytask', info)
         self.assertIn('mytask', info)
         self.assertIn('rate_limit=200', info)
         self.assertIn('rate_limit=200', info)
 
 
+    def test_dump_tasks2(self):
+        prev, control.DEFAULT_TASK_INFO_ITEMS = (
+            control.DEFAULT_TASK_INFO_ITEMS, [])
+        try:
+            info = '\n'.join(self.panel.handle('dump_tasks'))
+            self.assertIn('mytask', info)
+            self.assertNotIn('rate_limit=200', info)
+        finally:
+            control.DEFAULT_TASK_INFO_ITEMS = prev
+
     def test_stats(self):
     def test_stats(self):
         prev_count, worker_state.total_count = worker_state.total_count, 100
         prev_count, worker_state.total_count = worker_state.total_count, 100
         try:
         try:
@@ -493,7 +531,7 @@ class test_ControlPanel(AppCase):
 
 
         panel = _Node(hostname=hostname,
         panel = _Node(hostname=hostname,
                       state=self.create_state(consumer=Consumer(self.app)),
                       state=self.create_state(consumer=Consumer(self.app)),
-                      handlers=Panel.data,
+                      handlers=control.Panel.data,
                       mailbox=self.app.control.mailbox)
                       mailbox=self.app.control.mailbox)
         r = panel.dispatch('ping', reply_to={'exchange': 'x',
         r = panel.dispatch('ping', reply_to={'exchange': 'x',
                                              'routing_key': 'x'})
                                              'routing_key': 'x'})
@@ -584,3 +622,30 @@ class test_ControlPanel(AppCase):
             self.assertTrue(consumer.controller.pool.restart.called)
             self.assertTrue(consumer.controller.pool.restart.called)
             self.assertTrue(_reload.called)
             self.assertTrue(_reload.called)
             self.assertFalse(_import.called)
             self.assertFalse(_import.called)
+
+    def test_query_task(self):
+        consumer = Consumer(self.app)
+        consumer.controller = _WC(app=self.app)
+        consumer.controller.consumer = consumer
+        panel = self.create_panel(consumer=consumer)
+        panel.app = self.app
+        req1 = Request(
+            TaskMessage(self.mytask.name, args=(2, 2)),
+            app=self.app,
+        )
+        worker_state.reserved_requests.add(req1)
+        try:
+            self.assertFalse(panel.handle('query_task', {'ids': {'1daa'}}))
+            ret = panel.handle('query_task', {'ids': {req1.id}})
+            self.assertIn(req1.id, ret)
+            self.assertEqual(ret[req1.id][0], 'reserved')
+            worker_state.active_requests.add(req1)
+            try:
+                ret = panel.handle('query_task', {'ids': {req1.id}})
+                self.assertEqual(ret[req1.id][0], 'active')
+            finally:
+                worker_state.active_requests.clear()
+            ret = panel.handle('query_task', {'ids': {req1.id}})
+            self.assertEqual(ret[req1.id][0], 'reserved')
+        finally:
+            worker_state.reserved_requests.clear()

+ 34 - 1
celery/tests/worker/test_loops.py

@@ -1,11 +1,14 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
+import errno
 import socket
 import socket
 
 
 from kombu.async import Hub, READ, WRITE, ERR
 from kombu.async import Hub, READ, WRITE, ERR
 
 
 from celery.bootsteps import CLOSE, RUN
 from celery.bootsteps import CLOSE, RUN
-from celery.exceptions import InvalidTaskError, WorkerShutdown, WorkerTerminate
+from celery.exceptions import (
+    InvalidTaskError, WorkerLostError, WorkerShutdown, WorkerTerminate,
+)
 from celery.five import Empty
 from celery.five import Empty
 from celery.platforms import EX_FAILURE
 from celery.platforms import EX_FAILURE
 from celery.worker import state
 from celery.worker import state
@@ -129,6 +132,13 @@ class test_asynloop(AppCase):
             _quick_drain, [p.fun for p in x.hub._ready],
             _quick_drain, [p.fun for p in x.hub._ready],
         )
         )
 
 
+    def test_pool_did_not_start_at_startup(self):
+        x = X(self.app)
+        x.obj.restart_count = 0
+        x.obj.pool.did_start_ok.return_value = False
+        with self.assertRaises(WorkerLostError):
+            asynloop(*x.args)
+
     def test_setup_heartbeat(self):
     def test_setup_heartbeat(self):
         x = X(self.app, heartbeat=10)
         x = X(self.app, heartbeat=10)
         x.hub.call_repeatedly = Mock(name='x.hub.call_repeatedly()')
         x.hub.call_repeatedly = Mock(name='x.hub.call_repeatedly()')
@@ -423,3 +433,26 @@ class test_synloop(AppCase):
         x = X(self.app)
         x = X(self.app)
         x.close_then_error(x.connection.drain_events)
         x.close_then_error(x.connection.drain_events)
         self.assertIsNone(synloop(*x.args))
         self.assertIsNone(synloop(*x.args))
+
+
+class test_quick_drain(AppCase):
+
+    def setup(self):
+        self.connection = Mock(name='connection')
+
+    def test_drain(self):
+        _quick_drain(self.connection, timeout=33.3)
+        self.connection.drain_events.assert_called_with(timeout=33.3)
+
+    def test_drain_error(self):
+        exc = KeyError()
+        exc.errno = 313
+        self.connection.drain_events.side_effect = exc
+        with self.assertRaises(KeyError):
+            _quick_drain(self.connection, timeout=33.3)
+
+    def test_drain_error_EAGAIN(self):
+        exc = KeyError()
+        exc.errno = errno.EAGAIN
+        self.connection.drain_events.side_effect = exc
+        _quick_drain(self.connection, timeout=33.3)

+ 1 - 46
celery/tests/worker/test_worker.py

@@ -12,10 +12,8 @@ from kombu import Connection
 from kombu.common import QoS, ignore_errors
 from kombu.common import QoS, ignore_errors
 from kombu.transport.base import Message
 from kombu.transport.base import Message
 
 
-from celery.app.defaults import DEFAULTS
 from celery.bootsteps import RUN, CLOSE, StartStopStep
 from celery.bootsteps import RUN, CLOSE, StartStopStep
 from celery.concurrency.base import BasePool
 from celery.concurrency.base import BasePool
-from celery.datastructures import AttributeDict
 from celery.exceptions import (
 from celery.exceptions import (
     WorkerShutdown, WorkerTerminate, TaskRevokedError, InvalidTaskError,
     WorkerShutdown, WorkerTerminate, TaskRevokedError, InvalidTaskError,
 )
 )
@@ -30,9 +28,7 @@ from celery.utils import worker_direct
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 from celery.utils.timer2 import Timer
 
 
-from celery.tests.case import (
-    AppCase, Mock, SkipTest, TaskMessage, patch, restore_logging,
-)
+from celery.tests.case import AppCase, Mock, SkipTest, TaskMessage, patch
 
 
 
 
 def MockStep(step=None):
 def MockStep(step=None):
@@ -875,47 +871,6 @@ class test_WorkController(AppCase):
         worker.stop()
         worker.stop()
         self.assertTrue(worker.pidlock.release.called)
         self.assertTrue(worker.pidlock.release.called)
 
 
-    @patch('celery.platforms.signals')
-    @patch('celery.platforms.set_mp_process_title')
-    def test_process_initializer(self, set_mp_process_title, _signals):
-        with restore_logging():
-            from celery import signals
-            from celery._state import _tls
-            from celery.concurrency.prefork import (
-                process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
-            )
-
-            def on_worker_process_init(**kwargs):
-                on_worker_process_init.called = True
-            on_worker_process_init.called = False
-            signals.worker_process_init.connect(on_worker_process_init)
-
-            def Loader(*args, **kwargs):
-                loader = Mock(*args, **kwargs)
-                loader.conf = {}
-                loader.override_backends = {}
-                return loader
-
-            with self.Celery(loader=Loader) as app:
-                app.conf = AttributeDict(DEFAULTS)
-                process_initializer(app, 'awesome.worker.com')
-                _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
-                _signals.reset.assert_any_call(*WORKER_SIGRESET)
-                self.assertTrue(app.loader.init_worker.call_count)
-                self.assertTrue(on_worker_process_init.called)
-                self.assertIs(_tls.current_app, app)
-                set_mp_process_title.assert_called_with(
-                    'celeryd', hostname='awesome.worker.com',
-                )
-
-                with patch('celery.app.trace.setup_worker_optimizations') as S:
-                    os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
-                    try:
-                        process_initializer(app, 'luke.worker.com')
-                        S.assert_called_with(app, 'luke.worker.com')
-                    finally:
-                        os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
-
     def test_attrs(self):
     def test_attrs(self):
         worker = self.worker
         worker = self.worker
         self.assertIsNotNone(worker.timer)
         self.assertIsNotNone(worker.timer)

+ 2 - 2
celery/utils/abstract.py

@@ -32,7 +32,7 @@ class _AbstractClass(object):
         ) or NotImplemented
         ) or NotImplemented
 
 
 
 
-class CallableTask(_AbstractClass, Callable):
+class CallableTask(_AbstractClass, Callable):  # pragma: no cover
     __required_attributes__ = frozenset({
     __required_attributes__ = frozenset({
         'delay', 'apply_async', 'apply',
         'delay', 'apply_async', 'apply',
     })
     })
@@ -54,7 +54,7 @@ class CallableTask(_AbstractClass, Callable):
         return cls._subclasshook_using(CallableTask, C)
         return cls._subclasshook_using(CallableTask, C)
 
 
 
 
-class CallableSignature(CallableTask):
+class CallableSignature(CallableTask):  # pragma: no cover
     __required_attributes__ = frozenset({
     __required_attributes__ = frozenset({
         'clone', 'freeze', 'set', 'link', 'link_error', '__or__',
         'clone', 'freeze', 'set', 'link', 'link_error', '__or__',
     })
     })

+ 4 - 4
celery/utils/debug.py

@@ -31,7 +31,7 @@ UNITS = (
     (2 ** 30.0, 'GB'),
     (2 ** 30.0, 'GB'),
     (2 ** 20.0, 'MB'),
     (2 ** 20.0, 'MB'),
     (2 ** 10.0, 'kB'),
     (2 ** 10.0, 'kB'),
-    (0.0, '{0!d}b'),
+    (0.0, 'b'),
 )
 )
 
 
 _process = None
 _process = None
@@ -78,7 +78,7 @@ def sample_mem():
     return current_rss
     return current_rss
 
 
 
 
-def _memdump(samples=10):
+def _memdump(samples=10):  # pragma: no cover
     S = _mem_sample
     S = _mem_sample
     prev = list(S) if len(S) <= samples else sample(S, samples)
     prev = list(S) if len(S) <= samples else sample(S, samples)
     _mem_sample[:] = []
     _mem_sample[:] = []
@@ -88,7 +88,7 @@ def _memdump(samples=10):
     return prev, after_collect
     return prev, after_collect
 
 
 
 
-def memdump(samples=10, file=None):
+def memdump(samples=10, file=None):  # pragma: no cover
     """Dump memory statistics.
     """Dump memory statistics.
 
 
     Will print a sample of all RSS memory samples added by
     Will print a sample of all RSS memory samples added by
@@ -151,7 +151,7 @@ def mem_rss():
         return humanbytes(p.get_memory_info().rss)
         return humanbytes(p.get_memory_info().rss)
 
 
 
 
-def ps():
+def ps():  # pragma: no cover
     """Return the global :class:`psutil.Process` instance,
     """Return the global :class:`psutil.Process` instance,
     or :const:`None` if :mod:`psutil` is not installed."""
     or :const:`None` if :mod:`psutil` is not installed."""
     global _process
     global _process

+ 2 - 10
celery/worker/control.py

@@ -52,21 +52,14 @@ def _find_requests_by_id(ids, requests):
 @Panel.register
 @Panel.register
 def query_task(state, ids, **kwargs):
 def query_task(state, ids, **kwargs):
     ids = maybe_list(ids)
     ids = maybe_list(ids)
-
-    def reqinfo(state, req):
-        return state, req.info()
-
-    reqs = {
+    return dict({
         req.id: ('reserved', req.info())
         req.id: ('reserved', req.info())
         for req in _find_requests_by_id(ids, worker_state.reserved_requests)
         for req in _find_requests_by_id(ids, worker_state.reserved_requests)
-    }
-    reqs.update({
+    }, **{
         req.id: ('active', req.info())
         req.id: ('active', req.info())
         for req in _find_requests_by_id(ids, worker_state.active_requests)
         for req in _find_requests_by_id(ids, worker_state.active_requests)
     })
     })
 
 
-    return reqs
-
 
 
 @Panel.register
 @Panel.register
 def revoke(state, task_id, terminate=False, signal=None, **kwargs):
 def revoke(state, task_id, terminate=False, signal=None, **kwargs):
@@ -368,7 +361,6 @@ def active_queues(state):
 
 
 def _wanted_config_key(key):
 def _wanted_config_key(key):
     return (isinstance(key, string_t) and
     return (isinstance(key, string_t) and
-            key.isupper() and
             not key.startswith('__'))
             not key.startswith('__'))