Browse Source

Reorganizes the tests a little

Ask Solem 13 years ago
parent
commit
419249e9e1

+ 321 - 0
celery/tests/test_app/__init__.py

@@ -0,0 +1,321 @@
+from __future__ import with_statement
+
+import os
+import sys
+
+from mock import Mock
+
+from celery import Celery
+from celery import app as _app
+from celery.app import defaults
+from celery.app.base import BaseApp, pyimplementation
+from celery.loaders.base import BaseLoader
+from celery.utils.serialization import pickle
+
+from celery.tests import config
+from celery.tests.utils import (unittest, mask_modules, platform_pyimp,
+                                sys_platform, pypy_version)
+from celery.utils.mail import ErrorMail
+from kombu.utils import gen_unique_id
+
+THIS_IS_A_KEY = "this is a value"
+
+
+class Object(object):
+
+    def __init__(self, **kwargs):
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+
+def _get_test_config():
+    return dict((key, getattr(config, key))
+                    for key in dir(config)
+                        if key.isupper() and not key.startswith("_"))
+
+test_config = _get_test_config()
+
+
+class test_App(unittest.TestCase):
+
+    def setUp(self):
+        self.app = Celery(set_as_current=False)
+        self.app.conf.update(test_config)
+
+    def test_task(self):
+        app = Celery("foozibari", set_as_current=False)
+
+        def fun():
+            pass
+
+        fun.__module__ = "__main__"
+        task = app.task(fun)
+        self.assertEqual(task.name, app.main + ".fun")
+
+    def test_repr(self):
+        self.assertTrue(repr(self.app))
+
+    def test_TaskSet(self):
+        ts = self.app.TaskSet()
+        self.assertListEqual(ts.tasks, [])
+        self.assertIs(ts.app, self.app)
+
+    def test_pickle_app(self):
+        changes = dict(THE_FOO_BAR="bars",
+                       THE_MII_MAR="jars")
+        self.app.conf.update(changes)
+        saved = pickle.dumps(self.app)
+        self.assertLess(len(saved), 2048)
+        restored = pickle.loads(saved)
+        self.assertDictContainsSubset(changes, restored.conf)
+
+    def test_worker_main(self):
+        from celery.bin import celeryd
+
+        class WorkerCommand(celeryd.WorkerCommand):
+
+            def execute_from_commandline(self, argv):
+                return argv
+
+        prev, celeryd.WorkerCommand = celeryd.WorkerCommand, WorkerCommand
+        try:
+            ret = self.app.worker_main(argv=["--version"])
+            self.assertListEqual(ret, ["--version"])
+        finally:
+            celeryd.WorkerCommand = prev
+
+    def test_config_from_envvar(self):
+        os.environ["CELERYTEST_CONFIG_OBJECT"] = "celery.tests.test_app"
+        self.app.config_from_envvar("CELERYTEST_CONFIG_OBJECT")
+        self.assertEqual(self.app.conf.THIS_IS_A_KEY, "this is a value")
+
+    def test_config_from_object(self):
+
+        class Object(object):
+            LEAVE_FOR_WORK = True
+            MOMENT_TO_STOP = True
+            CALL_ME_BACK = 123456789
+            WANT_ME_TO = False
+            UNDERSTAND_ME = True
+
+        self.app.config_from_object(Object())
+
+        self.assertTrue(self.app.conf.LEAVE_FOR_WORK)
+        self.assertTrue(self.app.conf.MOMENT_TO_STOP)
+        self.assertEqual(self.app.conf.CALL_ME_BACK, 123456789)
+        self.assertFalse(self.app.conf.WANT_ME_TO)
+        self.assertTrue(self.app.conf.UNDERSTAND_ME)
+
+    def test_config_from_cmdline(self):
+        cmdline = [".always_eager=no",
+                   ".result_backend=/dev/null",
+                   '.task_error_whitelist=(list)["a", "b", "c"]',
+                   "celeryd.prefetch_multiplier=368",
+                   ".foobarstring=(string)300",
+                   ".foobarint=(int)300",
+                   '.result_engine_options=(dict){"foo": "bar"}']
+        self.app.config_from_cmdline(cmdline, namespace="celery")
+        self.assertFalse(self.app.conf.CELERY_ALWAYS_EAGER)
+        self.assertEqual(self.app.conf.CELERY_RESULT_BACKEND, "/dev/null")
+        self.assertEqual(self.app.conf.CELERYD_PREFETCH_MULTIPLIER, 368)
+        self.assertListEqual(self.app.conf.CELERY_TASK_ERROR_WHITELIST,
+                             ["a", "b", "c"])
+        self.assertEqual(self.app.conf.CELERY_FOOBARSTRING, "300")
+        self.assertEqual(self.app.conf.CELERY_FOOBARINT, 300)
+        self.assertDictEqual(self.app.conf.CELERY_RESULT_ENGINE_OPTIONS,
+                             {"foo": "bar"})
+
+    def test_compat_setting_CELERY_BACKEND(self):
+
+        self.app.config_from_object(Object(CELERY_BACKEND="set_by_us"))
+        self.assertEqual(self.app.conf.CELERY_RESULT_BACKEND, "set_by_us")
+
+    def test_setting_BROKER_TRANSPORT_OPTIONS(self):
+
+        _args = {'foo': 'bar', 'spam': 'baz'}
+
+        self.app.config_from_object(Object())
+        self.assertEqual(self.app.conf.BROKER_TRANSPORT_OPTIONS, {})
+
+        self.app.config_from_object(Object(BROKER_TRANSPORT_OPTIONS=_args))
+        self.assertEqual(self.app.conf.BROKER_TRANSPORT_OPTIONS, _args)
+
+    def test_Windows_log_color_disabled(self):
+        self.app.IS_WINDOWS = True
+        self.assertFalse(self.app.log.supports_color())
+
+    def test_compat_setting_CARROT_BACKEND(self):
+        self.app.config_from_object(Object(CARROT_BACKEND="set_by_us"))
+        self.assertEqual(self.app.conf.BROKER_TRANSPORT, "set_by_us")
+
+    def test_mail_admins(self):
+
+        class Loader(BaseLoader):
+
+            def mail_admins(*args, **kwargs):
+                return args, kwargs
+
+        self.app.loader = Loader()
+        self.app.conf.ADMINS = None
+        self.assertFalse(self.app.mail_admins("Subject", "Body"))
+        self.app.conf.ADMINS = [("George Costanza", "george@vandelay.com")]
+        self.assertTrue(self.app.mail_admins("Subject", "Body"))
+
+    def test_amqp_get_broker_info(self):
+        self.assertDictContainsSubset({"hostname": "localhost",
+                                       "userid": "guest",
+                                       "password": "guest",
+                                       "virtual_host": "/"},
+                                      self.app.broker_connection(
+                                          transport="amqplib").info())
+        self.app.conf.BROKER_PORT = 1978
+        self.app.conf.BROKER_VHOST = "foo"
+        self.assertDictContainsSubset({"port": 1978,
+                                       "virtual_host": "foo"},
+                                      self.app.broker_connection(
+                                          transport="amqplib").info())
+        conn = self.app.broker_connection(virtual_host="/value")
+        self.assertDictContainsSubset({"virtual_host": "/value"},
+                                      conn.info())
+
+    def test_BROKER_BACKEND_alias(self):
+        self.assertEqual(self.app.conf.BROKER_BACKEND,
+                         self.app.conf.BROKER_TRANSPORT)
+
+    def test_with_default_connection(self):
+
+        @self.app.with_default_connection
+        def handler(connection=None, foo=None):
+            return connection, foo
+
+        connection, foo = handler(foo=42)
+        self.assertEqual(foo, 42)
+        self.assertTrue(connection)
+
+    def test_after_fork(self):
+        p = self.app._pool = Mock()
+        self.app._after_fork(self.app)
+        p.force_close_all.assert_called_with()
+        self.assertIsNone(self.app._pool)
+        self.app._after_fork(self.app)
+
+    def test_pool_no_multiprocessing(self):
+        with mask_modules("multiprocessing.util"):
+            pool = self.app.pool
+            self.assertIs(pool, self.app._pool)
+
+    def test_bugreport(self):
+        self.assertTrue(self.app.bugreport())
+
+    def test_send_task_sent_event(self):
+        from celery.app import amqp
+
+        class Dispatcher(object):
+            sent = []
+
+            def send(self, type, **fields):
+                self.sent.append((type, fields))
+
+        conn = self.app.broker_connection()
+        chan = conn.channel()
+        try:
+            for e in ("foo_exchange", "moo_exchange", "bar_exchange"):
+                chan.exchange_declare(e, "direct", durable=True)
+                chan.queue_declare(e, durable=True)
+                chan.queue_bind(e, e, e)
+        finally:
+            chan.close()
+        assert conn.transport_cls == "memory"
+
+        pub = self.app.amqp.TaskPublisher(conn, exchange="foo_exchange")
+        self.assertIn("foo_exchange", amqp._exchanges_declared)
+
+        dispatcher = Dispatcher()
+        self.assertTrue(pub.delay_task("footask", (), {},
+                                       exchange="moo_exchange",
+                                       routing_key="moo_exchange",
+                                       event_dispatcher=dispatcher))
+        self.assertTrue(dispatcher.sent)
+        self.assertEqual(dispatcher.sent[0][0], "task-sent")
+        self.assertTrue(pub.delay_task("footask", (), {},
+                                       event_dispatcher=dispatcher,
+                                       exchange="bar_exchange",
+                                       routing_key="bar_exchange"))
+        self.assertIn("bar_exchange", amqp._exchanges_declared)
+
+    def test_error_mail_sender(self):
+        x = ErrorMail.subject % {"name": "task_name",
+                                 "id": gen_unique_id(),
+                                 "exc": "FOOBARBAZ",
+                                 "hostname": "lana"}
+        self.assertTrue(x)
+
+
+class test_BaseApp(unittest.TestCase):
+
+    def test_on_init(self):
+        BaseApp()
+
+
+class test_defaults(unittest.TestCase):
+
+    def test_str_to_bool(self):
+        for s in ("false", "no", "0"):
+            self.assertFalse(defaults.str_to_bool(s))
+        for s in ("true", "yes", "1"):
+            self.assertTrue(defaults.str_to_bool(s))
+        self.assertRaises(TypeError, defaults.str_to_bool, "unsure")
+
+
+class test_debugging_utils(unittest.TestCase):
+
+    def test_enable_disable_trace(self):
+        try:
+            _app.enable_trace()
+            self.assertEqual(_app.app_or_default, _app._app_or_default_trace)
+            _app.disable_trace()
+            self.assertEqual(_app.app_or_default, _app._app_or_default)
+        finally:
+            _app.disable_trace()
+
+
+class test_compilation(unittest.TestCase):
+    _clean = ("celery.app.base", )
+
+    def setUp(self):
+        self._prev = dict((k, sys.modules.pop(k, None)) for k in self._clean)
+
+    def tearDown(self):
+        sys.modules.update(self._prev)
+
+    def test_kombu_version_check(self):
+        import kombu
+        kombu.VERSION = (0, 9, 9)
+        with self.assertRaises(ImportError):
+            __import__("celery.app.base")
+
+
+class test_pyimplementation(unittest.TestCase):
+
+    def test_platform_python_implementation(self):
+        with platform_pyimp(lambda: "Xython"):
+            self.assertEqual(pyimplementation(), "Xython")
+
+    def test_platform_jython(self):
+        with platform_pyimp():
+            with sys_platform("java 1.6.51"):
+                self.assertIn("Jython", pyimplementation())
+
+    def test_platform_pypy(self):
+        with platform_pyimp():
+            with sys_platform("darwin"):
+                with pypy_version((1, 4, 3)):
+                    self.assertIn("PyPy", pyimplementation())
+                with pypy_version((1, 4, 3, "a4")):
+                    self.assertIn("PyPy", pyimplementation())
+
+    def test_platform_fallback(self):
+        with platform_pyimp():
+            with sys_platform("darwin"):
+                with pypy_version():
+                    self.assertEqual("CPython", pyimplementation())

+ 0 - 322
celery/tests/test_app/test_app.py

@@ -1,322 +0,0 @@
-from __future__ import with_statement
-
-import os
-import sys
-
-from mock import Mock
-
-from celery import Celery
-from celery import app as _app
-from celery.app import defaults
-from celery.app.base import BaseApp, pyimplementation
-from celery.loaders.base import BaseLoader
-from celery.utils.serialization import pickle
-
-from celery.tests import config
-from celery.tests.utils import (unittest, mask_modules, platform_pyimp,
-                                sys_platform, pypy_version)
-from celery.utils.mail import ErrorMail
-from kombu.utils import gen_unique_id
-
-THIS_IS_A_KEY = "this is a value"
-
-
-class Object(object):
-
-    def __init__(self, **kwargs):
-        for key, value in kwargs.items():
-            setattr(self, key, value)
-
-
-def _get_test_config():
-    return dict((key, getattr(config, key))
-                    for key in dir(config)
-                        if key.isupper() and not key.startswith("_"))
-
-test_config = _get_test_config()
-
-
-class test_App(unittest.TestCase):
-
-    def setUp(self):
-        self.app = Celery(set_as_current=False)
-        self.app.conf.update(test_config)
-
-    def test_task(self):
-        app = Celery("foozibari", set_as_current=False)
-
-        def fun():
-            pass
-
-        fun.__module__ = "__main__"
-        task = app.task(fun)
-        self.assertEqual(task.name, app.main + ".fun")
-
-    def test_repr(self):
-        self.assertTrue(repr(self.app))
-
-    def test_TaskSet(self):
-        ts = self.app.TaskSet()
-        self.assertListEqual(ts.tasks, [])
-        self.assertIs(ts.app, self.app)
-
-    def test_pickle_app(self):
-        changes = dict(THE_FOO_BAR="bars",
-                       THE_MII_MAR="jars")
-        self.app.conf.update(changes)
-        saved = pickle.dumps(self.app)
-        self.assertLess(len(saved), 2048)
-        restored = pickle.loads(saved)
-        self.assertDictContainsSubset(changes, restored.conf)
-
-    def test_worker_main(self):
-        from celery.bin import celeryd
-
-        class WorkerCommand(celeryd.WorkerCommand):
-
-            def execute_from_commandline(self, argv):
-                return argv
-
-        prev, celeryd.WorkerCommand = celeryd.WorkerCommand, WorkerCommand
-        try:
-            ret = self.app.worker_main(argv=["--version"])
-            self.assertListEqual(ret, ["--version"])
-        finally:
-            celeryd.WorkerCommand = prev
-
-    def test_config_from_envvar(self):
-        os.environ["CELERYTEST_CONFIG_OBJECT"] = \
-                "celery.tests.test_app.test_app"
-        self.app.config_from_envvar("CELERYTEST_CONFIG_OBJECT")
-        self.assertEqual(self.app.conf.THIS_IS_A_KEY, "this is a value")
-
-    def test_config_from_object(self):
-
-        class Object(object):
-            LEAVE_FOR_WORK = True
-            MOMENT_TO_STOP = True
-            CALL_ME_BACK = 123456789
-            WANT_ME_TO = False
-            UNDERSTAND_ME = True
-
-        self.app.config_from_object(Object())
-
-        self.assertTrue(self.app.conf.LEAVE_FOR_WORK)
-        self.assertTrue(self.app.conf.MOMENT_TO_STOP)
-        self.assertEqual(self.app.conf.CALL_ME_BACK, 123456789)
-        self.assertFalse(self.app.conf.WANT_ME_TO)
-        self.assertTrue(self.app.conf.UNDERSTAND_ME)
-
-    def test_config_from_cmdline(self):
-        cmdline = [".always_eager=no",
-                   ".result_backend=/dev/null",
-                   '.task_error_whitelist=(list)["a", "b", "c"]',
-                   "celeryd.prefetch_multiplier=368",
-                   ".foobarstring=(string)300",
-                   ".foobarint=(int)300",
-                   '.result_engine_options=(dict){"foo": "bar"}']
-        self.app.config_from_cmdline(cmdline, namespace="celery")
-        self.assertFalse(self.app.conf.CELERY_ALWAYS_EAGER)
-        self.assertEqual(self.app.conf.CELERY_RESULT_BACKEND, "/dev/null")
-        self.assertEqual(self.app.conf.CELERYD_PREFETCH_MULTIPLIER, 368)
-        self.assertListEqual(self.app.conf.CELERY_TASK_ERROR_WHITELIST,
-                             ["a", "b", "c"])
-        self.assertEqual(self.app.conf.CELERY_FOOBARSTRING, "300")
-        self.assertEqual(self.app.conf.CELERY_FOOBARINT, 300)
-        self.assertDictEqual(self.app.conf.CELERY_RESULT_ENGINE_OPTIONS,
-                             {"foo": "bar"})
-
-    def test_compat_setting_CELERY_BACKEND(self):
-
-        self.app.config_from_object(Object(CELERY_BACKEND="set_by_us"))
-        self.assertEqual(self.app.conf.CELERY_RESULT_BACKEND, "set_by_us")
-
-    def test_setting_BROKER_TRANSPORT_OPTIONS(self):
-
-        _args = {'foo': 'bar', 'spam': 'baz'}
-
-        self.app.config_from_object(Object())
-        self.assertEqual(self.app.conf.BROKER_TRANSPORT_OPTIONS, {})
-
-        self.app.config_from_object(Object(BROKER_TRANSPORT_OPTIONS=_args))
-        self.assertEqual(self.app.conf.BROKER_TRANSPORT_OPTIONS, _args)
-
-    def test_Windows_log_color_disabled(self):
-        self.app.IS_WINDOWS = True
-        self.assertFalse(self.app.log.supports_color())
-
-    def test_compat_setting_CARROT_BACKEND(self):
-        self.app.config_from_object(Object(CARROT_BACKEND="set_by_us"))
-        self.assertEqual(self.app.conf.BROKER_TRANSPORT, "set_by_us")
-
-    def test_mail_admins(self):
-
-        class Loader(BaseLoader):
-
-            def mail_admins(*args, **kwargs):
-                return args, kwargs
-
-        self.app.loader = Loader()
-        self.app.conf.ADMINS = None
-        self.assertFalse(self.app.mail_admins("Subject", "Body"))
-        self.app.conf.ADMINS = [("George Costanza", "george@vandelay.com")]
-        self.assertTrue(self.app.mail_admins("Subject", "Body"))
-
-    def test_amqp_get_broker_info(self):
-        self.assertDictContainsSubset({"hostname": "localhost",
-                                       "userid": "guest",
-                                       "password": "guest",
-                                       "virtual_host": "/"},
-                                      self.app.broker_connection(
-                                          transport="amqplib").info())
-        self.app.conf.BROKER_PORT = 1978
-        self.app.conf.BROKER_VHOST = "foo"
-        self.assertDictContainsSubset({"port": 1978,
-                                       "virtual_host": "foo"},
-                                      self.app.broker_connection(
-                                          transport="amqplib").info())
-        conn = self.app.broker_connection(virtual_host="/value")
-        self.assertDictContainsSubset({"virtual_host": "/value"},
-                                      conn.info())
-
-    def test_BROKER_BACKEND_alias(self):
-        self.assertEqual(self.app.conf.BROKER_BACKEND,
-                         self.app.conf.BROKER_TRANSPORT)
-
-    def test_with_default_connection(self):
-
-        @self.app.with_default_connection
-        def handler(connection=None, foo=None):
-            return connection, foo
-
-        connection, foo = handler(foo=42)
-        self.assertEqual(foo, 42)
-        self.assertTrue(connection)
-
-    def test_after_fork(self):
-        p = self.app._pool = Mock()
-        self.app._after_fork(self.app)
-        p.force_close_all.assert_called_with()
-        self.assertIsNone(self.app._pool)
-        self.app._after_fork(self.app)
-
-    def test_pool_no_multiprocessing(self):
-        with mask_modules("multiprocessing.util"):
-            pool = self.app.pool
-            self.assertIs(pool, self.app._pool)
-
-    def test_bugreport(self):
-        self.assertTrue(self.app.bugreport())
-
-    def test_send_task_sent_event(self):
-        from celery.app import amqp
-
-        class Dispatcher(object):
-            sent = []
-
-            def send(self, type, **fields):
-                self.sent.append((type, fields))
-
-        conn = self.app.broker_connection()
-        chan = conn.channel()
-        try:
-            for e in ("foo_exchange", "moo_exchange", "bar_exchange"):
-                chan.exchange_declare(e, "direct", durable=True)
-                chan.queue_declare(e, durable=True)
-                chan.queue_bind(e, e, e)
-        finally:
-            chan.close()
-        assert conn.transport_cls == "memory"
-
-        pub = self.app.amqp.TaskPublisher(conn, exchange="foo_exchange")
-        self.assertIn("foo_exchange", amqp._exchanges_declared)
-
-        dispatcher = Dispatcher()
-        self.assertTrue(pub.delay_task("footask", (), {},
-                                       exchange="moo_exchange",
-                                       routing_key="moo_exchange",
-                                       event_dispatcher=dispatcher))
-        self.assertTrue(dispatcher.sent)
-        self.assertEqual(dispatcher.sent[0][0], "task-sent")
-        self.assertTrue(pub.delay_task("footask", (), {},
-                                       event_dispatcher=dispatcher,
-                                       exchange="bar_exchange",
-                                       routing_key="bar_exchange"))
-        self.assertIn("bar_exchange", amqp._exchanges_declared)
-
-    def test_error_mail_sender(self):
-        x = ErrorMail.subject % {"name": "task_name",
-                                 "id": gen_unique_id(),
-                                 "exc": "FOOBARBAZ",
-                                 "hostname": "lana"}
-        self.assertTrue(x)
-
-
-class test_BaseApp(unittest.TestCase):
-
-    def test_on_init(self):
-        BaseApp()
-
-
-class test_defaults(unittest.TestCase):
-
-    def test_str_to_bool(self):
-        for s in ("false", "no", "0"):
-            self.assertFalse(defaults.str_to_bool(s))
-        for s in ("true", "yes", "1"):
-            self.assertTrue(defaults.str_to_bool(s))
-        self.assertRaises(TypeError, defaults.str_to_bool, "unsure")
-
-
-class test_debugging_utils(unittest.TestCase):
-
-    def test_enable_disable_trace(self):
-        try:
-            _app.enable_trace()
-            self.assertEqual(_app.app_or_default, _app._app_or_default_trace)
-            _app.disable_trace()
-            self.assertEqual(_app.app_or_default, _app._app_or_default)
-        finally:
-            _app.disable_trace()
-
-
-class test_compilation(unittest.TestCase):
-    _clean = ("celery.app.base", )
-
-    def setUp(self):
-        self._prev = dict((k, sys.modules.pop(k, None)) for k in self._clean)
-
-    def tearDown(self):
-        sys.modules.update(self._prev)
-
-    def test_kombu_version_check(self):
-        import kombu
-        kombu.VERSION = (0, 9, 9)
-        with self.assertRaises(ImportError):
-            __import__("celery.app.base")
-
-
-class test_pyimplementation(unittest.TestCase):
-
-    def test_platform_python_implementation(self):
-        with platform_pyimp(lambda: "Xython"):
-            self.assertEqual(pyimplementation(), "Xython")
-
-    def test_platform_jython(self):
-        with platform_pyimp():
-            with sys_platform("java 1.6.51"):
-                self.assertIn("Jython", pyimplementation())
-
-    def test_platform_pypy(self):
-        with platform_pyimp():
-            with sys_platform("darwin"):
-                with pypy_version((1, 4, 3)):
-                    self.assertIn("PyPy", pyimplementation())
-                with pypy_version((1, 4, 3, "a4")):
-                    self.assertIn("PyPy", pyimplementation())
-
-    def test_platform_fallback(self):
-        with platform_pyimp():
-            with sys_platform("darwin"):
-                with pypy_version():
-                    self.assertEqual("CPython", pyimplementation())

+ 0 - 0
celery/tests/test_compat/test_log.py → celery/tests/test_app/test_log.py


+ 78 - 0
celery/tests/test_bin/__init__.py

@@ -0,0 +1,78 @@
+import os
+
+from celery.bin.base import Command
+
+from celery.tests.utils import AppCase
+
+
+class Object(object):
+    pass
+
+
+class MyApp(object):
+    pass
+
+APP = MyApp()  # <-- Used by test_with_custom_app
+
+
+class MockCommand(Command):
+
+    def parse_options(self, prog_name, arguments):
+        options = Object()
+        options.foo = "bar"
+        options.prog_name = prog_name
+        return options, (10, 20, 30)
+
+    def run(self, *args, **kwargs):
+        return args, kwargs
+
+
+class test_Command(AppCase):
+
+    def test_get_options(self):
+        cmd = Command()
+        cmd.option_list = (1, 2, 3)
+        self.assertTupleEqual(cmd.get_options(), (1, 2, 3))
+
+    def test_run_interface(self):
+        self.assertRaises(NotImplementedError, Command().run)
+
+    def test_execute_from_commandline(self):
+        cmd = MockCommand()
+        args1, kwargs1 = cmd.execute_from_commandline()     # sys.argv
+        self.assertTupleEqual(args1, (10, 20, 30))
+        self.assertDictContainsSubset({"foo": "bar"}, kwargs1)
+        self.assertTrue(kwargs1.get("prog_name"))
+        args2, kwargs2 = cmd.execute_from_commandline(["foo"])   # pass list
+        self.assertTupleEqual(args2, (10, 20, 30))
+        self.assertDictContainsSubset({"foo": "bar", "prog_name": "foo"},
+                                      kwargs2)
+
+    def test_with_custom_config_module(self):
+        prev = os.environ.pop("CELERY_CONFIG_MODULE", None)
+        try:
+            cmd = MockCommand()
+            cmd.setup_app_from_commandline(["--config=foo.bar.baz"])
+            self.assertEqual(os.environ.get("CELERY_CONFIG_MODULE"),
+                             "foo.bar.baz")
+        finally:
+            if prev:
+                os.environ["CELERY_CONFIG_MODULE"] = prev
+
+    def test_with_custom_app(self):
+        cmd = MockCommand()
+        app = ".".join([__name__, "APP"])
+        cmd.setup_app_from_commandline(["--app=%s" % (app, ),
+                                        "--loglevel=INFO"])
+        self.assertIs(cmd.app, APP)
+
+    def test_with_cmdline_config(self):
+        cmd = MockCommand()
+        cmd.enable_config_from_cmdline = True
+        cmd.namespace = "celeryd"
+        rest = cmd.setup_app_from_commandline(argv=[
+            "--loglevel=INFO", "--", "broker.host=broker.example.com",
+            ".prefetch_multiplier=100"])
+        self.assertEqual(cmd.app.conf.BROKER_HOST, "broker.example.com")
+        self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
+        self.assertListEqual(rest, ["--loglevel=INFO"])

+ 0 - 78
celery/tests/test_bin/test_base.py

@@ -1,78 +0,0 @@
-import os
-
-from celery.bin.base import Command
-
-from celery.tests.utils import AppCase
-
-
-class Object(object):
-    pass
-
-
-class MyApp(object):
-    pass
-
-APP = MyApp()  # <-- Used by test_with_custom_app
-
-
-class MockCommand(Command):
-
-    def parse_options(self, prog_name, arguments):
-        options = Object()
-        options.foo = "bar"
-        options.prog_name = prog_name
-        return options, (10, 20, 30)
-
-    def run(self, *args, **kwargs):
-        return args, kwargs
-
-
-class test_Command(AppCase):
-
-    def test_get_options(self):
-        cmd = Command()
-        cmd.option_list = (1, 2, 3)
-        self.assertTupleEqual(cmd.get_options(), (1, 2, 3))
-
-    def test_run_interface(self):
-        self.assertRaises(NotImplementedError, Command().run)
-
-    def test_execute_from_commandline(self):
-        cmd = MockCommand()
-        args1, kwargs1 = cmd.execute_from_commandline()     # sys.argv
-        self.assertTupleEqual(args1, (10, 20, 30))
-        self.assertDictContainsSubset({"foo": "bar"}, kwargs1)
-        self.assertTrue(kwargs1.get("prog_name"))
-        args2, kwargs2 = cmd.execute_from_commandline(["foo"])   # pass list
-        self.assertTupleEqual(args2, (10, 20, 30))
-        self.assertDictContainsSubset({"foo": "bar", "prog_name": "foo"},
-                                      kwargs2)
-
-    def test_with_custom_config_module(self):
-        prev = os.environ.pop("CELERY_CONFIG_MODULE", None)
-        try:
-            cmd = MockCommand()
-            cmd.setup_app_from_commandline(["--config=foo.bar.baz"])
-            self.assertEqual(os.environ.get("CELERY_CONFIG_MODULE"),
-                             "foo.bar.baz")
-        finally:
-            if prev:
-                os.environ["CELERY_CONFIG_MODULE"] = prev
-
-    def test_with_custom_app(self):
-        cmd = MockCommand()
-        app = ".".join([__name__, "APP"])
-        cmd.setup_app_from_commandline(["--app=%s" % (app, ),
-                                        "--loglevel=INFO"])
-        self.assertIs(cmd.app, APP)
-
-    def test_with_cmdline_config(self):
-        cmd = MockCommand()
-        cmd.enable_config_from_cmdline = True
-        cmd.namespace = "celeryd"
-        rest = cmd.setup_app_from_commandline(argv=[
-            "--loglevel=INFO", "--", "broker.host=broker.example.com",
-            ".prefetch_multiplier=100"])
-        self.assertEqual(cmd.app.conf.BROKER_HOST, "broker.example.com")
-        self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
-        self.assertListEqual(rest, ["--loglevel=INFO"])

+ 63 - 0
celery/tests/test_concurrency/__init__.py

@@ -0,0 +1,63 @@
+import os
+
+from itertools import count
+
+from celery.concurrency.base import apply_target, BasePool
+from celery.tests.utils import unittest
+
+
+class test_BasePool(unittest.TestCase):
+
+    def test_apply_target(self):
+
+        scratch = {}
+        counter = count(0).next
+
+        def gen_callback(name, retval=None):
+
+            def callback(*args):
+                scratch[name] = (counter(), args)
+                return retval
+
+            return callback
+
+        apply_target(gen_callback("target", 42),
+                     args=(8, 16),
+                     callback=gen_callback("callback"),
+                     accept_callback=gen_callback("accept_callback"))
+
+        self.assertDictContainsSubset({
+                              "target": (1, (8, 16)),
+                              "callback": (2, (42, ))}, scratch)
+        pa1 = scratch["accept_callback"]
+        self.assertEqual(0, pa1[0])
+        self.assertEqual(pa1[1][0], os.getpid())
+        self.assertTrue(pa1[1][1])
+
+        # No accept callback
+        scratch.clear()
+        apply_target(gen_callback("target", 42),
+                     args=(8, 16),
+                     callback=gen_callback("callback"),
+                     accept_callback=None)
+        self.assertDictEqual(scratch,
+                              {"target": (3, (8, 16)),
+                               "callback": (4, (42, ))})
+
+    def test_interface_on_start(self):
+        BasePool(10).on_start()
+
+    def test_interface_on_stop(self):
+        BasePool(10).on_stop()
+
+    def test_interface_on_apply(self):
+        BasePool(10).on_apply()
+
+    def test_interface_info(self):
+        self.assertDictEqual(BasePool(10).info, {})
+
+    def test_active(self):
+        p = BasePool(10)
+        self.assertFalse(p.active)
+        p._state = p.RUN
+        self.assertTrue(p.active)

+ 0 - 63
celery/tests/test_concurrency/test_concurrency_base.py

@@ -1,63 +0,0 @@
-import os
-
-from itertools import count
-
-from celery.concurrency.base import apply_target, BasePool
-from celery.tests.utils import unittest
-
-
-class test_BasePool(unittest.TestCase):
-
-    def test_apply_target(self):
-
-        scratch = {}
-        counter = count(0).next
-
-        def gen_callback(name, retval=None):
-
-            def callback(*args):
-                scratch[name] = (counter(), args)
-                return retval
-
-            return callback
-
-        apply_target(gen_callback("target", 42),
-                     args=(8, 16),
-                     callback=gen_callback("callback"),
-                     accept_callback=gen_callback("accept_callback"))
-
-        self.assertDictContainsSubset({
-                              "target": (1, (8, 16)),
-                              "callback": (2, (42, ))}, scratch)
-        pa1 = scratch["accept_callback"]
-        self.assertEqual(0, pa1[0])
-        self.assertEqual(pa1[1][0], os.getpid())
-        self.assertTrue(pa1[1][1])
-
-        # No accept callback
-        scratch.clear()
-        apply_target(gen_callback("target", 42),
-                     args=(8, 16),
-                     callback=gen_callback("callback"),
-                     accept_callback=None)
-        self.assertDictEqual(scratch,
-                              {"target": (3, (8, 16)),
-                               "callback": (4, (42, ))})
-
-    def test_interface_on_start(self):
-        BasePool(10).on_start()
-
-    def test_interface_on_stop(self):
-        BasePool(10).on_stop()
-
-    def test_interface_on_apply(self):
-        BasePool(10).on_apply()
-
-    def test_interface_info(self):
-        self.assertDictEqual(BasePool(10).info, {})
-
-    def test_active(self):
-        p = BasePool(10)
-        self.assertFalse(p.active)
-        p._state = p.RUN
-        self.assertTrue(p.active)

+ 186 - 0
celery/tests/test_events/__init__.py

@@ -0,0 +1,186 @@
+import socket
+
+from celery import events
+from celery.app import app_or_default
+from celery.tests.utils import unittest
+
+
+class MockProducer(object):
+    raise_on_publish = False
+
+    def __init__(self, *args, **kwargs):
+        self.sent = []
+
+    def publish(self, msg, *args, **kwargs):
+        if self.raise_on_publish:
+            raise KeyError()
+        self.sent.append(msg)
+
+    def close(self):
+        pass
+
+    def has_event(self, kind):
+        for event in self.sent:
+            if event["type"] == kind:
+                return event
+        return False
+
+
+class TestEvent(unittest.TestCase):
+
+    def test_constructor(self):
+        event = events.Event("world war II")
+        self.assertEqual(event["type"], "world war II")
+        self.assertTrue(event["timestamp"])
+
+
+class TestEventDispatcher(unittest.TestCase):
+
+    def setUp(self):
+        self.app = app_or_default()
+
+    def test_send(self):
+        producer = MockProducer()
+        eventer = self.app.events.Dispatcher(object(), enabled=False)
+        eventer.publisher = producer
+        eventer.enabled = True
+        eventer.send("World War II", ended=True)
+        self.assertTrue(producer.has_event("World War II"))
+        eventer.enabled = False
+        eventer.send("World War III")
+        self.assertFalse(producer.has_event("World War III"))
+
+        evs = ("Event 1", "Event 2", "Event 3")
+        eventer.enabled = True
+        eventer.publisher.raise_on_publish = True
+        eventer.buffer_while_offline = False
+        self.assertRaises(KeyError, eventer.send, "Event X")
+        eventer.buffer_while_offline = True
+        for ev in evs:
+            eventer.send(ev)
+        eventer.publisher.raise_on_publish = False
+        eventer.flush()
+        for ev in evs:
+            self.assertTrue(producer.has_event(ev))
+
+    def test_enabled_disable(self):
+        connection = self.app.broker_connection()
+        channel = connection.channel()
+        try:
+            dispatcher = self.app.events.Dispatcher(connection,
+                                                    enabled=True)
+            dispatcher2 = self.app.events.Dispatcher(connection,
+                                                     enabled=True,
+                                                      channel=channel)
+            self.assertTrue(dispatcher.enabled)
+            self.assertTrue(dispatcher.publisher.channel)
+            self.assertEqual(dispatcher.publisher.serializer,
+                            self.app.conf.CELERY_EVENT_SERIALIZER)
+
+            created_channel = dispatcher.publisher.channel
+            dispatcher.disable()
+            dispatcher.disable()  # Disable with no active publisher
+            dispatcher2.disable()
+            self.assertFalse(dispatcher.enabled)
+            self.assertIsNone(dispatcher.publisher)
+            self.assertTrue(created_channel.closed)
+            self.assertFalse(dispatcher2.channel.closed,
+                             "does not close manually provided channel")
+
+            dispatcher.enable()
+            self.assertTrue(dispatcher.enabled)
+            self.assertTrue(dispatcher.publisher)
+        finally:
+            channel.close()
+            connection.close()
+
+
+class TestEventReceiver(unittest.TestCase):
+
+    def setUp(self):
+        self.app = app_or_default()
+
+    def test_process(self):
+
+        message = {"type": "world-war"}
+
+        got_event = [False]
+
+        def my_handler(event):
+            got_event[0] = True
+
+        r = events.EventReceiver(object(),
+                                 handlers={"world-war": my_handler},
+                                 node_id="celery.tests",
+                                 )
+        r._receive(message, object())
+        self.assertTrue(got_event[0])
+
+    def test_catch_all_event(self):
+
+        message = {"type": "world-war"}
+
+        got_event = [False]
+
+        def my_handler(event):
+            got_event[0] = True
+
+        r = events.EventReceiver(object(), node_id="celery.tests")
+        events.EventReceiver.handlers["*"] = my_handler
+        try:
+            r._receive(message, object())
+            self.assertTrue(got_event[0])
+        finally:
+            events.EventReceiver.handlers = {}
+
+    def test_itercapture(self):
+        connection = self.app.broker_connection()
+        try:
+            r = self.app.events.Receiver(connection, node_id="celery.tests")
+            it = r.itercapture(timeout=0.0001, wakeup=False)
+            consumer = it.next()
+            self.assertTrue(consumer.queues)
+            self.assertEqual(consumer.callbacks[0], r._receive)
+
+            self.assertRaises(socket.timeout, it.next)
+
+            self.assertRaises(socket.timeout,
+                              r.capture, timeout=0.00001)
+        finally:
+            connection.close()
+
+    def test_itercapture_limit(self):
+        connection = self.app.broker_connection()
+        channel = connection.channel()
+        try:
+            events_received = [0]
+
+            def handler(event):
+                events_received[0] += 1
+
+            producer = self.app.events.Dispatcher(connection,
+                                                  enabled=True,
+                                                  channel=channel)
+            r = self.app.events.Receiver(connection,
+                                         handlers={"*": handler},
+                                         node_id="celery.tests")
+            evs = ["ev1", "ev2", "ev3", "ev4", "ev5"]
+            for ev in evs:
+                producer.send(ev)
+            it = r.itercapture(limit=4, wakeup=True)
+            it.next()  # skip consumer (see itercapture)
+            list(it)
+            self.assertEqual(events_received[0], 4)
+        finally:
+            channel.close()
+            connection.close()
+
+
+class test_misc(unittest.TestCase):
+
+    def setUp(self):
+        self.app = app_or_default()
+
+    def test_State(self):
+        state = self.app.events.State()
+        self.assertDictEqual(dict(state.workers), {})

+ 0 - 186
celery/tests/test_events/test_events.py

@@ -1,186 +0,0 @@
-import socket
-
-from celery import events
-from celery.app import app_or_default
-from celery.tests.utils import unittest
-
-
-class MockProducer(object):
-    raise_on_publish = False
-
-    def __init__(self, *args, **kwargs):
-        self.sent = []
-
-    def publish(self, msg, *args, **kwargs):
-        if self.raise_on_publish:
-            raise KeyError()
-        self.sent.append(msg)
-
-    def close(self):
-        pass
-
-    def has_event(self, kind):
-        for event in self.sent:
-            if event["type"] == kind:
-                return event
-        return False
-
-
-class TestEvent(unittest.TestCase):
-
-    def test_constructor(self):
-        event = events.Event("world war II")
-        self.assertEqual(event["type"], "world war II")
-        self.assertTrue(event["timestamp"])
-
-
-class TestEventDispatcher(unittest.TestCase):
-
-    def setUp(self):
-        self.app = app_or_default()
-
-    def test_send(self):
-        producer = MockProducer()
-        eventer = self.app.events.Dispatcher(object(), enabled=False)
-        eventer.publisher = producer
-        eventer.enabled = True
-        eventer.send("World War II", ended=True)
-        self.assertTrue(producer.has_event("World War II"))
-        eventer.enabled = False
-        eventer.send("World War III")
-        self.assertFalse(producer.has_event("World War III"))
-
-        evs = ("Event 1", "Event 2", "Event 3")
-        eventer.enabled = True
-        eventer.publisher.raise_on_publish = True
-        eventer.buffer_while_offline = False
-        self.assertRaises(KeyError, eventer.send, "Event X")
-        eventer.buffer_while_offline = True
-        for ev in evs:
-            eventer.send(ev)
-        eventer.publisher.raise_on_publish = False
-        eventer.flush()
-        for ev in evs:
-            self.assertTrue(producer.has_event(ev))
-
-    def test_enabled_disable(self):
-        connection = self.app.broker_connection()
-        channel = connection.channel()
-        try:
-            dispatcher = self.app.events.Dispatcher(connection,
-                                                    enabled=True)
-            dispatcher2 = self.app.events.Dispatcher(connection,
-                                                     enabled=True,
-                                                      channel=channel)
-            self.assertTrue(dispatcher.enabled)
-            self.assertTrue(dispatcher.publisher.channel)
-            self.assertEqual(dispatcher.publisher.serializer,
-                            self.app.conf.CELERY_EVENT_SERIALIZER)
-
-            created_channel = dispatcher.publisher.channel
-            dispatcher.disable()
-            dispatcher.disable()  # Disable with no active publisher
-            dispatcher2.disable()
-            self.assertFalse(dispatcher.enabled)
-            self.assertIsNone(dispatcher.publisher)
-            self.assertTrue(created_channel.closed)
-            self.assertFalse(dispatcher2.channel.closed,
-                             "does not close manually provided channel")
-
-            dispatcher.enable()
-            self.assertTrue(dispatcher.enabled)
-            self.assertTrue(dispatcher.publisher)
-        finally:
-            channel.close()
-            connection.close()
-
-
-class TestEventReceiver(unittest.TestCase):
-
-    def setUp(self):
-        self.app = app_or_default()
-
-    def test_process(self):
-
-        message = {"type": "world-war"}
-
-        got_event = [False]
-
-        def my_handler(event):
-            got_event[0] = True
-
-        r = events.EventReceiver(object(),
-                                 handlers={"world-war": my_handler},
-                                 node_id="celery.tests",
-                                 )
-        r._receive(message, object())
-        self.assertTrue(got_event[0])
-
-    def test_catch_all_event(self):
-
-        message = {"type": "world-war"}
-
-        got_event = [False]
-
-        def my_handler(event):
-            got_event[0] = True
-
-        r = events.EventReceiver(object(), node_id="celery.tests")
-        events.EventReceiver.handlers["*"] = my_handler
-        try:
-            r._receive(message, object())
-            self.assertTrue(got_event[0])
-        finally:
-            events.EventReceiver.handlers = {}
-
-    def test_itercapture(self):
-        connection = self.app.broker_connection()
-        try:
-            r = self.app.events.Receiver(connection, node_id="celery.tests")
-            it = r.itercapture(timeout=0.0001, wakeup=False)
-            consumer = it.next()
-            self.assertTrue(consumer.queues)
-            self.assertEqual(consumer.callbacks[0], r._receive)
-
-            self.assertRaises(socket.timeout, it.next)
-
-            self.assertRaises(socket.timeout,
-                              r.capture, timeout=0.00001)
-        finally:
-            connection.close()
-
-    def test_itercapture_limit(self):
-        connection = self.app.broker_connection()
-        channel = connection.channel()
-        try:
-            events_received = [0]
-
-            def handler(event):
-                events_received[0] += 1
-
-            producer = self.app.events.Dispatcher(connection,
-                                                  enabled=True,
-                                                  channel=channel)
-            r = self.app.events.Receiver(connection,
-                                         handlers={"*": handler},
-                                         node_id="celery.tests")
-            evs = ["ev1", "ev2", "ev3", "ev4", "ev5"]
-            for ev in evs:
-                producer.send(ev)
-            it = r.itercapture(limit=4, wakeup=True)
-            it.next()  # skip consumer (see itercapture)
-            list(it)
-            self.assertEqual(events_received[0], 4)
-        finally:
-            channel.close()
-            connection.close()
-
-
-class test_misc(unittest.TestCase):
-
-    def setUp(self):
-        self.app = app_or_default()
-
-    def test_State(self):
-        state = self.app.events.State()
-        self.assertDictEqual(dict(state.workers), {})

+ 857 - 0
celery/tests/test_task/__init__.py

@@ -0,0 +1,857 @@
+from datetime import datetime, timedelta
+from functools import wraps
+
+from mock import Mock
+from pyparsing import ParseException
+
+from celery import task
+from celery.app import app_or_default
+from celery.task import task as task_dec
+from celery.exceptions import RetryTaskError
+from celery.execute import send_task
+from celery.result import EagerResult
+from celery.schedules import crontab, crontab_parser
+from celery.utils import uuid
+from celery.utils.timeutils import parse_iso8601
+
+from celery.tests.utils import with_eager_tasks, unittest, StringIO
+
+
+def return_True(*args, **kwargs):
+    # Task run functions can't be closures/lambdas, as they're pickled.
+    return True
+
+
+return_True_task = task_dec()(return_True)
+
+
+def raise_exception(self, **kwargs):
+    raise Exception("%s error" % self.__class__)
+
+
+class MockApplyTask(task.Task):
+
+    def run(self, x, y):
+        return x * y
+
+    @classmethod
+    def apply_async(self, *args, **kwargs):
+        pass
+
+
+class IncrementCounterTask(task.Task):
+    name = "c.unittest.increment_counter_task"
+    count = 0
+
+    def run(self, increment_by=1, **kwargs):
+        increment_by = increment_by or 1
+        self.__class__.count += increment_by
+        return self.__class__.count
+
+
+class RaisingTask(task.Task):
+    name = "c.unittest.raising_task"
+
+    def run(self, **kwargs):
+        raise KeyError("foo")
+
+
+class RetryTask(task.Task):
+    max_retries = 3
+    iterations = 0
+
+    def run(self, arg1, arg2, kwarg=1, max_retries=None, care=True):
+        self.__class__.iterations += 1
+        rmax = self.max_retries if max_retries is None else max_retries
+
+        retries = self.request.retries
+        if care and retries >= rmax:
+            return arg1
+        else:
+            return self.retry(countdown=0, max_retries=max_retries)
+
+
+class RetryTaskNoArgs(task.Task):
+    max_retries = 3
+    iterations = 0
+
+    def run(self, **kwargs):
+        self.__class__.iterations += 1
+
+        retries = kwargs["task_retries"]
+        if retries >= 3:
+            return 42
+        else:
+            return self.retry(kwargs=kwargs, countdown=0)
+
+
+class RetryTaskMockApply(task.Task):
+    max_retries = 3
+    iterations = 0
+    applied = 0
+
+    def run(self, arg1, arg2, kwarg=1, **kwargs):
+        self.__class__.iterations += 1
+
+        retries = kwargs["task_retries"]
+        if retries >= 3:
+            return arg1
+        else:
+            kwargs.update({"kwarg": kwarg})
+            return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
+
+    @classmethod
+    def apply_async(self, *args, **kwargs):
+        self.applied = 1
+
+
+class MyCustomException(Exception):
+    """Random custom exception."""
+
+
+class RetryTaskCustomExc(task.Task):
+    max_retries = 3
+    iterations = 0
+
+    def run(self, arg1, arg2, kwarg=1, **kwargs):
+        self.__class__.iterations += 1
+
+        retries = kwargs["task_retries"]
+        if retries >= 3:
+            return arg1 + kwarg
+        else:
+            try:
+                raise MyCustomException("Elaine Marie Benes")
+            except MyCustomException, exc:
+                kwargs.update({"kwarg": kwarg})
+                return self.retry(args=[arg1, arg2], kwargs=kwargs,
+                                  countdown=0, exc=exc)
+
+
+class TestTaskRetries(unittest.TestCase):
+
+    def test_retry(self):
+        RetryTask.max_retries = 3
+        RetryTask.iterations = 0
+        result = RetryTask.apply([0xFF, 0xFFFF])
+        self.assertEqual(result.get(), 0xFF)
+        self.assertEqual(RetryTask.iterations, 4)
+
+        RetryTask.max_retries = 3
+        RetryTask.iterations = 0
+        result = RetryTask.apply([0xFF, 0xFFFF], {"max_retries": 10})
+        self.assertEqual(result.get(), 0xFF)
+        self.assertEqual(RetryTask.iterations, 11)
+
+    def test_retry_no_args(self):
+        RetryTaskNoArgs.max_retries = 3
+        RetryTaskNoArgs.iterations = 0
+        result = RetryTaskNoArgs.apply()
+        self.assertEqual(result.get(), 42)
+        self.assertEqual(RetryTaskNoArgs.iterations, 4)
+
+    def test_retry_kwargs_can_be_empty(self):
+        self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
+                            args=[4, 4], kwargs=None)
+
+    def test_retry_not_eager(self):
+        RetryTaskMockApply.request.called_directly = False
+        exc = Exception("baz")
+        try:
+            RetryTaskMockApply.retry(args=[4, 4], kwargs={"task_retries": 0},
+                                     exc=exc, throw=False)
+            self.assertTrue(RetryTaskMockApply.applied)
+        finally:
+            RetryTaskMockApply.applied = 0
+
+        try:
+            self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
+                    args=[4, 4], kwargs={"task_retries": 0},
+                    exc=exc, throw=True)
+            self.assertTrue(RetryTaskMockApply.applied)
+        finally:
+            RetryTaskMockApply.applied = 0
+
+    def test_retry_with_kwargs(self):
+        RetryTaskCustomExc.max_retries = 3
+        RetryTaskCustomExc.iterations = 0
+        result = RetryTaskCustomExc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
+        self.assertEqual(result.get(), 0xFF + 0xF)
+        self.assertEqual(RetryTaskCustomExc.iterations, 4)
+
+    def test_retry_with_custom_exception(self):
+        RetryTaskCustomExc.max_retries = 2
+        RetryTaskCustomExc.iterations = 0
+        result = RetryTaskCustomExc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
+        self.assertRaises(MyCustomException,
+                          result.get)
+        self.assertEqual(RetryTaskCustomExc.iterations, 3)
+
+    def test_max_retries_exceeded(self):
+        RetryTask.max_retries = 2
+        RetryTask.iterations = 0
+        result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
+        self.assertRaises(RetryTask.MaxRetriesExceededError,
+                          result.get)
+        self.assertEqual(RetryTask.iterations, 3)
+
+        RetryTask.max_retries = 1
+        RetryTask.iterations = 0
+        result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
+        self.assertRaises(RetryTask.MaxRetriesExceededError,
+                          result.get)
+        self.assertEqual(RetryTask.iterations, 2)
+
+
+class TestCeleryTasks(unittest.TestCase):
+
+    def test_unpickle_task(self):
+        import pickle
+
+        @task_dec
+        def xxx():
+            pass
+
+        self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx)
+
+    def createTaskCls(self, cls_name, task_name=None):
+        attrs = {"__module__": self.__module__}
+        if task_name:
+            attrs["name"] = task_name
+
+        cls = type(cls_name, (task.Task, ), attrs)
+        cls.run = return_True
+        return cls
+
+    def test_AsyncResult(self):
+        task_id = uuid()
+        result = RetryTask.AsyncResult(task_id)
+        self.assertEqual(result.backend, RetryTask.backend)
+        self.assertEqual(result.task_id, task_id)
+
+    @with_eager_tasks
+    def test_ping(self):
+        self.assertEqual(task.ping(), 'pong')
+
+    def assertNextTaskDataEqual(self, consumer, presult, task_name,
+            test_eta=False, test_expires=False, **kwargs):
+        next_task = consumer.fetch()
+        task_data = next_task.decode()
+        self.assertEqual(task_data["id"], presult.task_id)
+        self.assertEqual(task_data["task"], task_name)
+        task_kwargs = task_data.get("kwargs", {})
+        if test_eta:
+            self.assertIsInstance(task_data.get("eta"), basestring)
+            to_datetime = parse_iso8601(task_data.get("eta"))
+            self.assertIsInstance(to_datetime, datetime)
+        if test_expires:
+            self.assertIsInstance(task_data.get("expires"), basestring)
+            to_datetime = parse_iso8601(task_data.get("expires"))
+            self.assertIsInstance(to_datetime, datetime)
+        for arg_name, arg_value in kwargs.items():
+            self.assertEqual(task_kwargs.get(arg_name), arg_value)
+
+    def test_incomplete_task_cls(self):
+
+        class IncompleteTask(task.Task):
+            name = "c.unittest.t.itask"
+
+        self.assertRaises(NotImplementedError, IncompleteTask().run)
+
+    def test_task_kwargs_must_be_dictionary(self):
+        self.assertRaises(ValueError, IncrementCounterTask.apply_async,
+                          [], "str")
+
+    def test_task_args_must_be_list(self):
+        self.assertRaises(ValueError, IncrementCounterTask.apply_async,
+                          "str", {})
+
+    def test_regular_task(self):
+        T1 = self.createTaskCls("T1", "c.unittest.t.t1")
+        self.assertIsInstance(T1(), T1)
+        self.assertTrue(T1().run())
+        self.assertTrue(callable(T1()),
+                "Task class is callable()")
+        self.assertTrue(T1()(),
+                "Task class runs run() when called")
+
+        # task name generated out of class module + name.
+        T2 = self.createTaskCls("T2")
+        self.assertTrue(T2().name.endswith("test_task.T2"))
+
+        t1 = T1()
+        consumer = t1.get_consumer()
+        self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
+        consumer.discard_all()
+        self.assertIsNone(consumer.fetch())
+
+        # Without arguments.
+        presult = t1.delay()
+        self.assertNextTaskDataEqual(consumer, presult, t1.name)
+
+        # With arguments.
+        presult2 = t1.apply_async(kwargs=dict(name="George Costanza"))
+        self.assertNextTaskDataEqual(consumer, presult2, t1.name,
+                name="George Costanza")
+
+        # send_task
+        sresult = send_task(t1.name, kwargs=dict(name="Elaine M. Benes"))
+        self.assertNextTaskDataEqual(consumer, sresult, t1.name,
+                name="Elaine M. Benes")
+
+        # With eta.
+        presult2 = t1.apply_async(kwargs=dict(name="George Costanza"),
+                                  eta=datetime.now() + timedelta(days=1),
+                                  expires=datetime.now() + timedelta(days=2))
+        self.assertNextTaskDataEqual(consumer, presult2, t1.name,
+                name="George Costanza", test_eta=True, test_expires=True)
+
+        # With countdown.
+        presult2 = t1.apply_async(kwargs=dict(name="George Costanza"),
+                                  countdown=10, expires=12)
+        self.assertNextTaskDataEqual(consumer, presult2, t1.name,
+                name="George Costanza", test_eta=True, test_expires=True)
+
+        # Discarding all tasks.
+        consumer.discard_all()
+        t1.apply_async()
+        self.assertEqual(consumer.discard_all(), 1)
+        self.assertIsNone(consumer.fetch())
+
+        self.assertFalse(presult.successful())
+        t1.backend.mark_as_done(presult.task_id, result=None)
+        self.assertTrue(presult.successful())
+
+        publisher = t1.get_publisher()
+        self.assertTrue(publisher.exchange)
+
+    def test_context_get(self):
+        request = self.createTaskCls("T1", "c.unittest.t.c.g").request
+        request.foo = 32
+        self.assertEqual(request.get("foo"), 32)
+        self.assertEqual(request.get("bar", 36), 36)
+
+    def test_task_class_repr(self):
+        task = self.createTaskCls("T1", "c.unittest.t.repr")
+        self.assertIn("class Task of", repr(task.app.Task))
+
+    def test_after_return(self):
+        task = self.createTaskCls("T1", "c.unittest.t.after_return")()
+        task.backend = Mock()
+        task.request.chord = 123
+        task.after_return("SUCCESS", 1.0, "foobar", (), {}, None)
+        task.backend.on_chord_part_return.assert_called_with(task)
+
+    def test_send_task_sent_event(self):
+        T1 = self.createTaskCls("T1", "c.unittest.t.t1")
+        conn = T1.app.broker_connection()
+        chan = conn.channel()
+        T1.app.conf.CELERY_SEND_TASK_SENT_EVENT = True
+        dispatcher = [None]
+
+        class Pub(object):
+            channel = chan
+
+            def delay_task(self, *args, **kwargs):
+                dispatcher[0] = kwargs.get("event_dispatcher")
+
+        try:
+            T1.apply_async(publisher=Pub())
+        finally:
+            T1.app.conf.CELERY_SEND_TASK_SENT_EVENT = False
+            chan.close()
+            conn.close()
+
+        self.assertTrue(dispatcher[0])
+
+    def test_get_publisher(self):
+        connection = app_or_default().broker_connection()
+        p = IncrementCounterTask.get_publisher(connection, auto_declare=False,
+                                               exchange="foo")
+        self.assertEqual(p.exchange.name, "foo")
+        p = IncrementCounterTask.get_publisher(connection, auto_declare=False,
+                                               exchange_type="fanout")
+        self.assertEqual(p.exchange.type, "fanout")
+
+    def test_update_state(self):
+
+        @task_dec
+        def yyy():
+            pass
+
+        tid = uuid()
+        yyy.update_state(tid, "FROBULATING", {"fooz": "baaz"})
+        self.assertEqual(yyy.AsyncResult(tid).status, "FROBULATING")
+        self.assertDictEqual(yyy.AsyncResult(tid).result, {"fooz": "baaz"})
+
+        yyy.request.id = tid
+        yyy.update_state(state="FROBUZATING", meta={"fooz": "baaz"})
+        self.assertEqual(yyy.AsyncResult(tid).status, "FROBUZATING")
+        self.assertDictEqual(yyy.AsyncResult(tid).result, {"fooz": "baaz"})
+
+    def test_repr(self):
+
+        @task_dec
+        def task_test_repr():
+            pass
+
+        self.assertIn("task_test_repr", repr(task_test_repr))
+
+    def test_has___name__(self):
+
+        @task_dec
+        def yyy2():
+            pass
+
+        self.assertTrue(yyy2.__name__)
+
+    def test_get_logger(self):
+        T1 = self.createTaskCls("T1", "c.unittest.t.t1")
+        t1 = T1()
+        logfh = StringIO()
+        logger = t1.get_logger(logfile=logfh, loglevel=0)
+        self.assertTrue(logger)
+
+        T1.request.loglevel = 3
+        logger = t1.get_logger(logfile=logfh, loglevel=None)
+        self.assertTrue(logger)
+
+
+class TestTaskSet(unittest.TestCase):
+
+    @with_eager_tasks
+    def test_function_taskset(self):
+        subtasks = [return_True_task.subtask([i]) for i in range(1, 6)]
+        ts = task.TaskSet(subtasks)
+        res = ts.apply_async()
+        self.assertListEqual(res.join(), [True, True, True, True, True])
+
+    def test_counter_taskset(self):
+        IncrementCounterTask.count = 0
+        ts = task.TaskSet(tasks=[
+            IncrementCounterTask.subtask((), {}),
+            IncrementCounterTask.subtask((), {"increment_by": 2}),
+            IncrementCounterTask.subtask((), {"increment_by": 3}),
+            IncrementCounterTask.subtask((), {"increment_by": 4}),
+            IncrementCounterTask.subtask((), {"increment_by": 5}),
+            IncrementCounterTask.subtask((), {"increment_by": 6}),
+            IncrementCounterTask.subtask((), {"increment_by": 7}),
+            IncrementCounterTask.subtask((), {"increment_by": 8}),
+            IncrementCounterTask.subtask((), {"increment_by": 9}),
+        ])
+        self.assertEqual(ts.total, 9)
+
+        consumer = IncrementCounterTask().get_consumer()
+        consumer.purge()
+        consumer.close()
+        taskset_res = ts.apply_async()
+        subtasks = taskset_res.subtasks
+        taskset_id = taskset_res.taskset_id
+        consumer = IncrementCounterTask().get_consumer()
+        for subtask in subtasks:
+            m = consumer.fetch().payload
+            self.assertDictContainsSubset({"taskset": taskset_id,
+                                           "task": IncrementCounterTask.name,
+                                           "id": subtask.task_id}, m)
+            IncrementCounterTask().run(
+                    increment_by=m.get("kwargs", {}).get("increment_by"))
+        self.assertEqual(IncrementCounterTask.count, sum(xrange(1, 10)))
+
+    def test_named_taskset(self):
+        prefix = "test_named_taskset-"
+        ts = task.TaskSet([return_True_task.subtask([1])])
+        res = ts.apply(taskset_id=prefix + uuid())
+        self.assertTrue(res.taskset_id.startswith(prefix))
+
+
+class TestTaskApply(unittest.TestCase):
+
+    def test_apply_throw(self):
+        self.assertRaises(KeyError, RaisingTask.apply, throw=True)
+
+    def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
+        RaisingTask.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
+        try:
+            self.assertRaises(KeyError, RaisingTask.apply)
+        finally:
+            RaisingTask.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = False
+
+    def test_apply(self):
+        IncrementCounterTask.count = 0
+
+        e = IncrementCounterTask.apply()
+        self.assertIsInstance(e, EagerResult)
+        self.assertEqual(e.get(), 1)
+
+        e = IncrementCounterTask.apply(args=[1])
+        self.assertEqual(e.get(), 2)
+
+        e = IncrementCounterTask.apply(kwargs={"increment_by": 4})
+        self.assertEqual(e.get(), 6)
+
+        self.assertTrue(e.successful())
+        self.assertTrue(e.ready())
+        self.assertTrue(repr(e).startswith("<EagerResult:"))
+
+        f = RaisingTask.apply()
+        self.assertTrue(f.ready())
+        self.assertFalse(f.successful())
+        self.assertTrue(f.traceback)
+        self.assertRaises(KeyError, f.get)
+
+
+class MyPeriodic(task.PeriodicTask):
+    run_every = timedelta(hours=1)
+
+
+class TestPeriodicTask(unittest.TestCase):
+
+    def test_must_have_run_every(self):
+        self.assertRaises(NotImplementedError, type, "Foo",
+            (task.PeriodicTask, ), {"__module__": __name__})
+
+    def test_remaining_estimate(self):
+        self.assertIsInstance(
+            MyPeriodic().remaining_estimate(datetime.now()),
+            timedelta)
+
+    def test_is_due_not_due(self):
+        due, remaining = MyPeriodic().is_due(datetime.now())
+        self.assertFalse(due)
+        # This assertion may fail if executed in the
+        # first minute of an hour, thus 59 instead of 60
+        self.assertGreater(remaining, 59)
+
+    def test_is_due(self):
+        p = MyPeriodic()
+        due, remaining = p.is_due(datetime.now() - p.run_every.run_every)
+        self.assertTrue(due)
+        self.assertEqual(remaining,
+                         p.timedelta_seconds(p.run_every.run_every))
+
+    def test_schedule_repr(self):
+        p = MyPeriodic()
+        self.assertTrue(repr(p.run_every))
+
+
+class EveryMinutePeriodic(task.PeriodicTask):
+    run_every = crontab()
+
+
+class QuarterlyPeriodic(task.PeriodicTask):
+    run_every = crontab(minute="*/15")
+
+
+class HourlyPeriodic(task.PeriodicTask):
+    run_every = crontab(minute=30)
+
+
+class DailyPeriodic(task.PeriodicTask):
+    run_every = crontab(hour=7, minute=30)
+
+
+class WeeklyPeriodic(task.PeriodicTask):
+    run_every = crontab(hour=7, minute=30, day_of_week="thursday")
+
+
+def patch_crontab_nowfun(cls, retval):
+
+    def create_patcher(fun):
+
+        @wraps(fun)
+        def __inner(*args, **kwargs):
+            prev_nowfun = cls.run_every.nowfun
+            cls.run_every.nowfun = lambda: retval
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                cls.run_every.nowfun = prev_nowfun
+
+        return __inner
+
+    return create_patcher
+
+
+class test_crontab_parser(unittest.TestCase):
+
+    def test_parse_star(self):
+        self.assertEquals(crontab_parser(24).parse('*'), set(range(24)))
+        self.assertEquals(crontab_parser(60).parse('*'), set(range(60)))
+        self.assertEquals(crontab_parser(7).parse('*'), set(range(7)))
+
+    def test_parse_range(self):
+        self.assertEquals(crontab_parser(60).parse('1-10'),
+                          set(range(1, 10 + 1)))
+        self.assertEquals(crontab_parser(24).parse('0-20'),
+                          set(range(0, 20 + 1)))
+        self.assertEquals(crontab_parser().parse('2-10'),
+                          set(range(2, 10 + 1)))
+
+    def test_parse_groups(self):
+        self.assertEquals(crontab_parser().parse('1,2,3,4'),
+                          set([1, 2, 3, 4]))
+        self.assertEquals(crontab_parser().parse('0,15,30,45'),
+                          set([0, 15, 30, 45]))
+
+    def test_parse_steps(self):
+        self.assertEquals(crontab_parser(8).parse('*/2'),
+                          set([0, 2, 4, 6]))
+        self.assertEquals(crontab_parser().parse('*/2'),
+                          set(i * 2 for i in xrange(30)))
+        self.assertEquals(crontab_parser().parse('*/3'),
+                          set(i * 3 for i in xrange(20)))
+
+    def test_parse_composite(self):
+        self.assertEquals(crontab_parser(8).parse('*/2'), set([0, 2, 4, 6]))
+        self.assertEquals(crontab_parser().parse('2-9/5'), set([5]))
+        self.assertEquals(crontab_parser().parse('2-10/5'), set([5, 10]))
+        self.assertEquals(crontab_parser().parse('2-11/5,3'), set([3, 5, 10]))
+        self.assertEquals(crontab_parser().parse('2-4/3,*/5,0-21/4'),
+                set([0, 3, 4, 5, 8, 10, 12, 15, 16,
+                    20, 25, 30, 35, 40, 45, 50, 55]))
+
+    def test_parse_errors_on_empty_string(self):
+        self.assertRaises(ParseException, crontab_parser(60).parse, '')
+
+    def test_parse_errors_on_empty_group(self):
+        self.assertRaises(ParseException, crontab_parser(60).parse, '1,,2')
+
+    def test_parse_errors_on_empty_steps(self):
+        self.assertRaises(ParseException, crontab_parser(60).parse, '*/')
+
+    def test_parse_errors_on_negative_number(self):
+        self.assertRaises(ParseException, crontab_parser(60).parse, '-20')
+
+    def test_expand_cronspec_eats_iterables(self):
+        self.assertEqual(crontab._expand_cronspec(iter([1, 2, 3]), 100),
+                         set([1, 2, 3]))
+
+    def test_expand_cronspec_invalid_type(self):
+        self.assertRaises(TypeError, crontab._expand_cronspec, object(), 100)
+
+    def test_repr(self):
+        self.assertIn("*", repr(crontab("*")))
+
+    def test_eq(self):
+        self.assertEqual(crontab(day_of_week="1, 2"),
+                         crontab(day_of_week="1-2"))
+        self.assertEqual(crontab(minute="1", hour="2", day_of_week="5"),
+                         crontab(minute="1", hour="2", day_of_week="5"))
+        self.assertNotEqual(crontab(minute="1"), crontab(minute="2"))
+        self.assertFalse(object() == crontab(minute="1"))
+        self.assertFalse(crontab(minute="1") == object())
+
+
+class test_crontab_remaining_estimate(unittest.TestCase):
+
+    def next_ocurrance(self, crontab, now):
+        crontab.nowfun = lambda: now
+        return now + crontab.remaining_estimate(now)
+
+    def test_next_minute(self):
+        next = self.next_ocurrance(crontab(),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEquals(next, datetime(2010, 9, 11, 14, 31))
+
+    def test_not_next_minute(self):
+        next = self.next_ocurrance(crontab(),
+                                   datetime(2010, 9, 11, 14, 59, 15))
+        self.assertEquals(next, datetime(2010, 9, 11, 15, 0))
+
+    def test_this_hour(self):
+        next = self.next_ocurrance(crontab(minute=[5, 42]),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEquals(next, datetime(2010, 9, 11, 14, 42))
+
+    def test_not_this_hour(self):
+        next = self.next_ocurrance(crontab(minute=[5, 10, 15]),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEquals(next, datetime(2010, 9, 11, 15, 5))
+
+    def test_today(self):
+        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12, 17]),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEquals(next, datetime(2010, 9, 11, 17, 5))
+
+    def test_not_today(self):
+        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12]),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEquals(next, datetime(2010, 9, 12, 12, 5))
+
+    def test_weekday(self):
+        next = self.next_ocurrance(crontab(minute=30,
+                                           hour=14,
+                                           day_of_week="sat"),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEquals(next, datetime(2010, 9, 18, 14, 30))
+
+    def test_not_weekday(self):
+        next = self.next_ocurrance(crontab(minute=[5, 42],
+                                           day_of_week="mon-fri"),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEquals(next, datetime(2010, 9, 13, 0, 5))
+
+
+class test_crontab_is_due(unittest.TestCase):
+
+    def setUp(self):
+        self.now = datetime.now()
+        self.next_minute = 60 - self.now.second - 1e-6 * self.now.microsecond
+
+    def test_default_crontab_spec(self):
+        c = crontab()
+        self.assertEquals(c.minute, set(range(60)))
+        self.assertEquals(c.hour, set(range(24)))
+        self.assertEquals(c.day_of_week, set(range(7)))
+
+    def test_simple_crontab_spec(self):
+        c = crontab(minute=30)
+        self.assertEquals(c.minute, set([30]))
+        self.assertEquals(c.hour, set(range(24)))
+        self.assertEquals(c.day_of_week, set(range(7)))
+
+    def test_crontab_spec_minute_formats(self):
+        c = crontab(minute=30)
+        self.assertEquals(c.minute, set([30]))
+        c = crontab(minute='30')
+        self.assertEquals(c.minute, set([30]))
+        c = crontab(minute=(30, 40, 50))
+        self.assertEquals(c.minute, set([30, 40, 50]))
+        c = crontab(minute=set([30, 40, 50]))
+        self.assertEquals(c.minute, set([30, 40, 50]))
+
+    def test_crontab_spec_invalid_minute(self):
+        self.assertRaises(ValueError, crontab, minute=60)
+        self.assertRaises(ValueError, crontab, minute='0-100')
+
+    def test_crontab_spec_hour_formats(self):
+        c = crontab(hour=6)
+        self.assertEquals(c.hour, set([6]))
+        c = crontab(hour='5')
+        self.assertEquals(c.hour, set([5]))
+        c = crontab(hour=(4, 8, 12))
+        self.assertEquals(c.hour, set([4, 8, 12]))
+
+    def test_crontab_spec_invalid_hour(self):
+        self.assertRaises(ValueError, crontab, hour=24)
+        self.assertRaises(ValueError, crontab, hour='0-30')
+
+    def test_crontab_spec_dow_formats(self):
+        c = crontab(day_of_week=5)
+        self.assertEquals(c.day_of_week, set([5]))
+        c = crontab(day_of_week='5')
+        self.assertEquals(c.day_of_week, set([5]))
+        c = crontab(day_of_week='fri')
+        self.assertEquals(c.day_of_week, set([5]))
+        c = crontab(day_of_week='tuesday,sunday,fri')
+        self.assertEquals(c.day_of_week, set([0, 2, 5]))
+        c = crontab(day_of_week='mon-fri')
+        self.assertEquals(c.day_of_week, set([1, 2, 3, 4, 5]))
+        c = crontab(day_of_week='*/2')
+        self.assertEquals(c.day_of_week, set([0, 2, 4, 6]))
+
+    def test_crontab_spec_invalid_dow(self):
+        self.assertRaises(ValueError, crontab, day_of_week='fooday-barday')
+        self.assertRaises(ValueError, crontab, day_of_week='1,4,foo')
+        self.assertRaises(ValueError, crontab, day_of_week='7')
+        self.assertRaises(ValueError, crontab, day_of_week='12')
+
+    def test_every_minute_execution_is_due(self):
+        last_ran = self.now - timedelta(seconds=61)
+        due, remaining = EveryMinutePeriodic().is_due(last_ran)
+        self.assertTrue(due)
+        self.assertAlmostEquals(remaining, self.next_minute, 1)
+
+    def test_every_minute_execution_is_not_due(self):
+        last_ran = self.now - timedelta(seconds=self.now.second)
+        due, remaining = EveryMinutePeriodic().is_due(last_ran)
+        self.assertFalse(due)
+        self.assertAlmostEquals(remaining, self.next_minute, 1)
+
+    # 29th of May 2010 is a saturday
+    @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 29, 10, 30))
+    def test_execution_is_due_on_saturday(self):
+        last_ran = self.now - timedelta(seconds=61)
+        due, remaining = EveryMinutePeriodic().is_due(last_ran)
+        self.assertTrue(due)
+        self.assertAlmostEquals(remaining, self.next_minute, 1)
+
+    # 30th of May 2010 is a sunday
+    @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 30, 10, 30))
+    def test_execution_is_due_on_sunday(self):
+        last_ran = self.now - timedelta(seconds=61)
+        due, remaining = EveryMinutePeriodic().is_due(last_ran)
+        self.assertTrue(due)
+        self.assertAlmostEquals(remaining, self.next_minute, 1)
+
+    # 31st of May 2010 is a monday
+    @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 31, 10, 30))
+    def test_execution_is_due_on_monday(self):
+        last_ran = self.now - timedelta(seconds=61)
+        due, remaining = EveryMinutePeriodic().is_due(last_ran)
+        self.assertTrue(due)
+        self.assertAlmostEquals(remaining, self.next_minute, 1)
+
+    @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 10, 10, 30))
+    def test_every_hour_execution_is_due(self):
+        due, remaining = HourlyPeriodic().is_due(datetime(2010, 5, 10, 6, 30))
+        self.assertTrue(due)
+        self.assertEquals(remaining, 60 * 60)
+
+    @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 10, 10, 29))
+    def test_every_hour_execution_is_not_due(self):
+        due, remaining = HourlyPeriodic().is_due(datetime(2010, 5, 10, 9, 30))
+        self.assertFalse(due)
+        self.assertEquals(remaining, 60)
+
+    @patch_crontab_nowfun(QuarterlyPeriodic, datetime(2010, 5, 10, 10, 15))
+    def test_first_quarter_execution_is_due(self):
+        due, remaining = QuarterlyPeriodic().is_due(
+                            datetime(2010, 5, 10, 6, 30))
+        self.assertTrue(due)
+        self.assertEquals(remaining, 15 * 60)
+
+    @patch_crontab_nowfun(QuarterlyPeriodic, datetime(2010, 5, 10, 10, 30))
+    def test_second_quarter_execution_is_due(self):
+        due, remaining = QuarterlyPeriodic().is_due(
+                            datetime(2010, 5, 10, 6, 30))
+        self.assertTrue(due)
+        self.assertEquals(remaining, 15 * 60)
+
+    @patch_crontab_nowfun(QuarterlyPeriodic, datetime(2010, 5, 10, 10, 14))
+    def test_first_quarter_execution_is_not_due(self):
+        due, remaining = QuarterlyPeriodic().is_due(
+                            datetime(2010, 5, 10, 10, 0))
+        self.assertFalse(due)
+        self.assertEquals(remaining, 60)
+
+    @patch_crontab_nowfun(QuarterlyPeriodic, datetime(2010, 5, 10, 10, 29))
+    def test_second_quarter_execution_is_not_due(self):
+        due, remaining = QuarterlyPeriodic().is_due(
+                            datetime(2010, 5, 10, 10, 15))
+        self.assertFalse(due)
+        self.assertEquals(remaining, 60)
+
+    @patch_crontab_nowfun(DailyPeriodic, datetime(2010, 5, 10, 7, 30))
+    def test_daily_execution_is_due(self):
+        due, remaining = DailyPeriodic().is_due(datetime(2010, 5, 9, 7, 30))
+        self.assertTrue(due)
+        self.assertEquals(remaining, 24 * 60 * 60)
+
+    @patch_crontab_nowfun(DailyPeriodic, datetime(2010, 5, 10, 10, 30))
+    def test_daily_execution_is_not_due(self):
+        due, remaining = DailyPeriodic().is_due(datetime(2010, 5, 10, 7, 30))
+        self.assertFalse(due)
+        self.assertEquals(remaining, 21 * 60 * 60)
+
+    @patch_crontab_nowfun(WeeklyPeriodic, datetime(2010, 5, 6, 7, 30))
+    def test_weekly_execution_is_due(self):
+        due, remaining = WeeklyPeriodic().is_due(datetime(2010, 4, 30, 7, 30))
+        self.assertTrue(due)
+        self.assertEquals(remaining, 7 * 24 * 60 * 60)
+
+    @patch_crontab_nowfun(WeeklyPeriodic, datetime(2010, 5, 7, 10, 30))
+    def test_weekly_execution_is_not_due(self):
+        due, remaining = WeeklyPeriodic().is_due(datetime(2010, 5, 6, 7, 30))
+        self.assertFalse(due)
+        self.assertEquals(remaining, 6 * 24 * 60 * 60 - 3 * 60 * 60)

+ 0 - 857
celery/tests/test_task/test_task.py

@@ -1,857 +0,0 @@
-from datetime import datetime, timedelta
-from functools import wraps
-
-from mock import Mock
-from pyparsing import ParseException
-
-from celery import task
-from celery.app import app_or_default
-from celery.task import task as task_dec
-from celery.exceptions import RetryTaskError
-from celery.execute import send_task
-from celery.result import EagerResult
-from celery.schedules import crontab, crontab_parser
-from celery.utils import uuid
-from celery.utils.timeutils import parse_iso8601
-
-from celery.tests.utils import with_eager_tasks, unittest, StringIO
-
-
-def return_True(*args, **kwargs):
-    # Task run functions can't be closures/lambdas, as they're pickled.
-    return True
-
-
-return_True_task = task_dec()(return_True)
-
-
-def raise_exception(self, **kwargs):
-    raise Exception("%s error" % self.__class__)
-
-
-class MockApplyTask(task.Task):
-
-    def run(self, x, y):
-        return x * y
-
-    @classmethod
-    def apply_async(self, *args, **kwargs):
-        pass
-
-
-class IncrementCounterTask(task.Task):
-    name = "c.unittest.increment_counter_task"
-    count = 0
-
-    def run(self, increment_by=1, **kwargs):
-        increment_by = increment_by or 1
-        self.__class__.count += increment_by
-        return self.__class__.count
-
-
-class RaisingTask(task.Task):
-    name = "c.unittest.raising_task"
-
-    def run(self, **kwargs):
-        raise KeyError("foo")
-
-
-class RetryTask(task.Task):
-    max_retries = 3
-    iterations = 0
-
-    def run(self, arg1, arg2, kwarg=1, max_retries=None, care=True):
-        self.__class__.iterations += 1
-        rmax = self.max_retries if max_retries is None else max_retries
-
-        retries = self.request.retries
-        if care and retries >= rmax:
-            return arg1
-        else:
-            return self.retry(countdown=0, max_retries=max_retries)
-
-
-class RetryTaskNoArgs(task.Task):
-    max_retries = 3
-    iterations = 0
-
-    def run(self, **kwargs):
-        self.__class__.iterations += 1
-
-        retries = kwargs["task_retries"]
-        if retries >= 3:
-            return 42
-        else:
-            return self.retry(kwargs=kwargs, countdown=0)
-
-
-class RetryTaskMockApply(task.Task):
-    max_retries = 3
-    iterations = 0
-    applied = 0
-
-    def run(self, arg1, arg2, kwarg=1, **kwargs):
-        self.__class__.iterations += 1
-
-        retries = kwargs["task_retries"]
-        if retries >= 3:
-            return arg1
-        else:
-            kwargs.update({"kwarg": kwarg})
-            return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
-
-    @classmethod
-    def apply_async(self, *args, **kwargs):
-        self.applied = 1
-
-
-class MyCustomException(Exception):
-    """Random custom exception."""
-
-
-class RetryTaskCustomExc(task.Task):
-    max_retries = 3
-    iterations = 0
-
-    def run(self, arg1, arg2, kwarg=1, **kwargs):
-        self.__class__.iterations += 1
-
-        retries = kwargs["task_retries"]
-        if retries >= 3:
-            return arg1 + kwarg
-        else:
-            try:
-                raise MyCustomException("Elaine Marie Benes")
-            except MyCustomException, exc:
-                kwargs.update({"kwarg": kwarg})
-                return self.retry(args=[arg1, arg2], kwargs=kwargs,
-                                  countdown=0, exc=exc)
-
-
-class TestTaskRetries(unittest.TestCase):
-
-    def test_retry(self):
-        RetryTask.max_retries = 3
-        RetryTask.iterations = 0
-        result = RetryTask.apply([0xFF, 0xFFFF])
-        self.assertEqual(result.get(), 0xFF)
-        self.assertEqual(RetryTask.iterations, 4)
-
-        RetryTask.max_retries = 3
-        RetryTask.iterations = 0
-        result = RetryTask.apply([0xFF, 0xFFFF], {"max_retries": 10})
-        self.assertEqual(result.get(), 0xFF)
-        self.assertEqual(RetryTask.iterations, 11)
-
-    def test_retry_no_args(self):
-        RetryTaskNoArgs.max_retries = 3
-        RetryTaskNoArgs.iterations = 0
-        result = RetryTaskNoArgs.apply()
-        self.assertEqual(result.get(), 42)
-        self.assertEqual(RetryTaskNoArgs.iterations, 4)
-
-    def test_retry_kwargs_can_be_empty(self):
-        self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
-                            args=[4, 4], kwargs=None)
-
-    def test_retry_not_eager(self):
-        RetryTaskMockApply.request.called_directly = False
-        exc = Exception("baz")
-        try:
-            RetryTaskMockApply.retry(args=[4, 4], kwargs={"task_retries": 0},
-                                     exc=exc, throw=False)
-            self.assertTrue(RetryTaskMockApply.applied)
-        finally:
-            RetryTaskMockApply.applied = 0
-
-        try:
-            self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
-                    args=[4, 4], kwargs={"task_retries": 0},
-                    exc=exc, throw=True)
-            self.assertTrue(RetryTaskMockApply.applied)
-        finally:
-            RetryTaskMockApply.applied = 0
-
-    def test_retry_with_kwargs(self):
-        RetryTaskCustomExc.max_retries = 3
-        RetryTaskCustomExc.iterations = 0
-        result = RetryTaskCustomExc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
-        self.assertEqual(result.get(), 0xFF + 0xF)
-        self.assertEqual(RetryTaskCustomExc.iterations, 4)
-
-    def test_retry_with_custom_exception(self):
-        RetryTaskCustomExc.max_retries = 2
-        RetryTaskCustomExc.iterations = 0
-        result = RetryTaskCustomExc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
-        self.assertRaises(MyCustomException,
-                          result.get)
-        self.assertEqual(RetryTaskCustomExc.iterations, 3)
-
-    def test_max_retries_exceeded(self):
-        RetryTask.max_retries = 2
-        RetryTask.iterations = 0
-        result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
-        self.assertRaises(RetryTask.MaxRetriesExceededError,
-                          result.get)
-        self.assertEqual(RetryTask.iterations, 3)
-
-        RetryTask.max_retries = 1
-        RetryTask.iterations = 0
-        result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
-        self.assertRaises(RetryTask.MaxRetriesExceededError,
-                          result.get)
-        self.assertEqual(RetryTask.iterations, 2)
-
-
-class TestCeleryTasks(unittest.TestCase):
-
-    def test_unpickle_task(self):
-        import pickle
-
-        @task_dec
-        def xxx():
-            pass
-
-        self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx)
-
-    def createTaskCls(self, cls_name, task_name=None):
-        attrs = {"__module__": self.__module__}
-        if task_name:
-            attrs["name"] = task_name
-
-        cls = type(cls_name, (task.Task, ), attrs)
-        cls.run = return_True
-        return cls
-
-    def test_AsyncResult(self):
-        task_id = uuid()
-        result = RetryTask.AsyncResult(task_id)
-        self.assertEqual(result.backend, RetryTask.backend)
-        self.assertEqual(result.task_id, task_id)
-
-    @with_eager_tasks
-    def test_ping(self):
-        self.assertEqual(task.ping(), 'pong')
-
-    def assertNextTaskDataEqual(self, consumer, presult, task_name,
-            test_eta=False, test_expires=False, **kwargs):
-        next_task = consumer.fetch()
-        task_data = next_task.decode()
-        self.assertEqual(task_data["id"], presult.task_id)
-        self.assertEqual(task_data["task"], task_name)
-        task_kwargs = task_data.get("kwargs", {})
-        if test_eta:
-            self.assertIsInstance(task_data.get("eta"), basestring)
-            to_datetime = parse_iso8601(task_data.get("eta"))
-            self.assertIsInstance(to_datetime, datetime)
-        if test_expires:
-            self.assertIsInstance(task_data.get("expires"), basestring)
-            to_datetime = parse_iso8601(task_data.get("expires"))
-            self.assertIsInstance(to_datetime, datetime)
-        for arg_name, arg_value in kwargs.items():
-            self.assertEqual(task_kwargs.get(arg_name), arg_value)
-
-    def test_incomplete_task_cls(self):
-
-        class IncompleteTask(task.Task):
-            name = "c.unittest.t.itask"
-
-        self.assertRaises(NotImplementedError, IncompleteTask().run)
-
-    def test_task_kwargs_must_be_dictionary(self):
-        self.assertRaises(ValueError, IncrementCounterTask.apply_async,
-                          [], "str")
-
-    def test_task_args_must_be_list(self):
-        self.assertRaises(ValueError, IncrementCounterTask.apply_async,
-                          "str", {})
-
-    def test_regular_task(self):
-        T1 = self.createTaskCls("T1", "c.unittest.t.t1")
-        self.assertIsInstance(T1(), T1)
-        self.assertTrue(T1().run())
-        self.assertTrue(callable(T1()),
-                "Task class is callable()")
-        self.assertTrue(T1()(),
-                "Task class runs run() when called")
-
-        # task name generated out of class module + name.
-        T2 = self.createTaskCls("T2")
-        self.assertTrue(T2().name.endswith("test_task.T2"))
-
-        t1 = T1()
-        consumer = t1.get_consumer()
-        self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
-        consumer.discard_all()
-        self.assertIsNone(consumer.fetch())
-
-        # Without arguments.
-        presult = t1.delay()
-        self.assertNextTaskDataEqual(consumer, presult, t1.name)
-
-        # With arguments.
-        presult2 = t1.apply_async(kwargs=dict(name="George Costanza"))
-        self.assertNextTaskDataEqual(consumer, presult2, t1.name,
-                name="George Costanza")
-
-        # send_task
-        sresult = send_task(t1.name, kwargs=dict(name="Elaine M. Benes"))
-        self.assertNextTaskDataEqual(consumer, sresult, t1.name,
-                name="Elaine M. Benes")
-
-        # With eta.
-        presult2 = t1.apply_async(kwargs=dict(name="George Costanza"),
-                                  eta=datetime.now() + timedelta(days=1),
-                                  expires=datetime.now() + timedelta(days=2))
-        self.assertNextTaskDataEqual(consumer, presult2, t1.name,
-                name="George Costanza", test_eta=True, test_expires=True)
-
-        # With countdown.
-        presult2 = t1.apply_async(kwargs=dict(name="George Costanza"),
-                                  countdown=10, expires=12)
-        self.assertNextTaskDataEqual(consumer, presult2, t1.name,
-                name="George Costanza", test_eta=True, test_expires=True)
-
-        # Discarding all tasks.
-        consumer.discard_all()
-        t1.apply_async()
-        self.assertEqual(consumer.discard_all(), 1)
-        self.assertIsNone(consumer.fetch())
-
-        self.assertFalse(presult.successful())
-        t1.backend.mark_as_done(presult.task_id, result=None)
-        self.assertTrue(presult.successful())
-
-        publisher = t1.get_publisher()
-        self.assertTrue(publisher.exchange)
-
-    def test_context_get(self):
-        request = self.createTaskCls("T1", "c.unittest.t.c.g").request
-        request.foo = 32
-        self.assertEqual(request.get("foo"), 32)
-        self.assertEqual(request.get("bar", 36), 36)
-
-    def test_task_class_repr(self):
-        task = self.createTaskCls("T1", "c.unittest.t.repr")
-        self.assertIn("class Task of", repr(task.app.Task))
-
-    def test_after_return(self):
-        task = self.createTaskCls("T1", "c.unittest.t.after_return")()
-        task.backend = Mock()
-        task.request.chord = 123
-        task.after_return("SUCCESS", 1.0, "foobar", (), {}, None)
-        task.backend.on_chord_part_return.assert_called_with(task)
-
-    def test_send_task_sent_event(self):
-        T1 = self.createTaskCls("T1", "c.unittest.t.t1")
-        conn = T1.app.broker_connection()
-        chan = conn.channel()
-        T1.app.conf.CELERY_SEND_TASK_SENT_EVENT = True
-        dispatcher = [None]
-
-        class Pub(object):
-            channel = chan
-
-            def delay_task(self, *args, **kwargs):
-                dispatcher[0] = kwargs.get("event_dispatcher")
-
-        try:
-            T1.apply_async(publisher=Pub())
-        finally:
-            T1.app.conf.CELERY_SEND_TASK_SENT_EVENT = False
-            chan.close()
-            conn.close()
-
-        self.assertTrue(dispatcher[0])
-
-    def test_get_publisher(self):
-        connection = app_or_default().broker_connection()
-        p = IncrementCounterTask.get_publisher(connection, auto_declare=False,
-                                               exchange="foo")
-        self.assertEqual(p.exchange.name, "foo")
-        p = IncrementCounterTask.get_publisher(connection, auto_declare=False,
-                                               exchange_type="fanout")
-        self.assertEqual(p.exchange.type, "fanout")
-
-    def test_update_state(self):
-
-        @task_dec
-        def yyy():
-            pass
-
-        tid = uuid()
-        yyy.update_state(tid, "FROBULATING", {"fooz": "baaz"})
-        self.assertEqual(yyy.AsyncResult(tid).status, "FROBULATING")
-        self.assertDictEqual(yyy.AsyncResult(tid).result, {"fooz": "baaz"})
-
-        yyy.request.id = tid
-        yyy.update_state(state="FROBUZATING", meta={"fooz": "baaz"})
-        self.assertEqual(yyy.AsyncResult(tid).status, "FROBUZATING")
-        self.assertDictEqual(yyy.AsyncResult(tid).result, {"fooz": "baaz"})
-
-    def test_repr(self):
-
-        @task_dec
-        def task_test_repr():
-            pass
-
-        self.assertIn("task_test_repr", repr(task_test_repr))
-
-    def test_has___name__(self):
-
-        @task_dec
-        def yyy2():
-            pass
-
-        self.assertTrue(yyy2.__name__)
-
-    def test_get_logger(self):
-        T1 = self.createTaskCls("T1", "c.unittest.t.t1")
-        t1 = T1()
-        logfh = StringIO()
-        logger = t1.get_logger(logfile=logfh, loglevel=0)
-        self.assertTrue(logger)
-
-        T1.request.loglevel = 3
-        logger = t1.get_logger(logfile=logfh, loglevel=None)
-        self.assertTrue(logger)
-
-
-class TestTaskSet(unittest.TestCase):
-
-    @with_eager_tasks
-    def test_function_taskset(self):
-        subtasks = [return_True_task.subtask([i]) for i in range(1, 6)]
-        ts = task.TaskSet(subtasks)
-        res = ts.apply_async()
-        self.assertListEqual(res.join(), [True, True, True, True, True])
-
-    def test_counter_taskset(self):
-        IncrementCounterTask.count = 0
-        ts = task.TaskSet(tasks=[
-            IncrementCounterTask.subtask((), {}),
-            IncrementCounterTask.subtask((), {"increment_by": 2}),
-            IncrementCounterTask.subtask((), {"increment_by": 3}),
-            IncrementCounterTask.subtask((), {"increment_by": 4}),
-            IncrementCounterTask.subtask((), {"increment_by": 5}),
-            IncrementCounterTask.subtask((), {"increment_by": 6}),
-            IncrementCounterTask.subtask((), {"increment_by": 7}),
-            IncrementCounterTask.subtask((), {"increment_by": 8}),
-            IncrementCounterTask.subtask((), {"increment_by": 9}),
-        ])
-        self.assertEqual(ts.total, 9)
-
-        consumer = IncrementCounterTask().get_consumer()
-        consumer.purge()
-        consumer.close()
-        taskset_res = ts.apply_async()
-        subtasks = taskset_res.subtasks
-        taskset_id = taskset_res.taskset_id
-        consumer = IncrementCounterTask().get_consumer()
-        for subtask in subtasks:
-            m = consumer.fetch().payload
-            self.assertDictContainsSubset({"taskset": taskset_id,
-                                           "task": IncrementCounterTask.name,
-                                           "id": subtask.task_id}, m)
-            IncrementCounterTask().run(
-                    increment_by=m.get("kwargs", {}).get("increment_by"))
-        self.assertEqual(IncrementCounterTask.count, sum(xrange(1, 10)))
-
-    def test_named_taskset(self):
-        prefix = "test_named_taskset-"
-        ts = task.TaskSet([return_True_task.subtask([1])])
-        res = ts.apply(taskset_id=prefix + uuid())
-        self.assertTrue(res.taskset_id.startswith(prefix))
-
-
-class TestTaskApply(unittest.TestCase):
-
-    def test_apply_throw(self):
-        self.assertRaises(KeyError, RaisingTask.apply, throw=True)
-
-    def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
-        RaisingTask.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
-        try:
-            self.assertRaises(KeyError, RaisingTask.apply)
-        finally:
-            RaisingTask.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = False
-
-    def test_apply(self):
-        IncrementCounterTask.count = 0
-
-        e = IncrementCounterTask.apply()
-        self.assertIsInstance(e, EagerResult)
-        self.assertEqual(e.get(), 1)
-
-        e = IncrementCounterTask.apply(args=[1])
-        self.assertEqual(e.get(), 2)
-
-        e = IncrementCounterTask.apply(kwargs={"increment_by": 4})
-        self.assertEqual(e.get(), 6)
-
-        self.assertTrue(e.successful())
-        self.assertTrue(e.ready())
-        self.assertTrue(repr(e).startswith("<EagerResult:"))
-
-        f = RaisingTask.apply()
-        self.assertTrue(f.ready())
-        self.assertFalse(f.successful())
-        self.assertTrue(f.traceback)
-        self.assertRaises(KeyError, f.get)
-
-
-class MyPeriodic(task.PeriodicTask):
-    run_every = timedelta(hours=1)
-
-
-class TestPeriodicTask(unittest.TestCase):
-
-    def test_must_have_run_every(self):
-        self.assertRaises(NotImplementedError, type, "Foo",
-            (task.PeriodicTask, ), {"__module__": __name__})
-
-    def test_remaining_estimate(self):
-        self.assertIsInstance(
-            MyPeriodic().remaining_estimate(datetime.now()),
-            timedelta)
-
-    def test_is_due_not_due(self):
-        due, remaining = MyPeriodic().is_due(datetime.now())
-        self.assertFalse(due)
-        # This assertion may fail if executed in the
-        # first minute of an hour, thus 59 instead of 60
-        self.assertGreater(remaining, 59)
-
-    def test_is_due(self):
-        p = MyPeriodic()
-        due, remaining = p.is_due(datetime.now() - p.run_every.run_every)
-        self.assertTrue(due)
-        self.assertEqual(remaining,
-                         p.timedelta_seconds(p.run_every.run_every))
-
-    def test_schedule_repr(self):
-        p = MyPeriodic()
-        self.assertTrue(repr(p.run_every))
-
-
-class EveryMinutePeriodic(task.PeriodicTask):
-    run_every = crontab()
-
-
-class QuarterlyPeriodic(task.PeriodicTask):
-    run_every = crontab(minute="*/15")
-
-
-class HourlyPeriodic(task.PeriodicTask):
-    run_every = crontab(minute=30)
-
-
-class DailyPeriodic(task.PeriodicTask):
-    run_every = crontab(hour=7, minute=30)
-
-
-class WeeklyPeriodic(task.PeriodicTask):
-    run_every = crontab(hour=7, minute=30, day_of_week="thursday")
-
-
-def patch_crontab_nowfun(cls, retval):
-
-    def create_patcher(fun):
-
-        @wraps(fun)
-        def __inner(*args, **kwargs):
-            prev_nowfun = cls.run_every.nowfun
-            cls.run_every.nowfun = lambda: retval
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                cls.run_every.nowfun = prev_nowfun
-
-        return __inner
-
-    return create_patcher
-
-
-class test_crontab_parser(unittest.TestCase):
-
-    def test_parse_star(self):
-        self.assertEquals(crontab_parser(24).parse('*'), set(range(24)))
-        self.assertEquals(crontab_parser(60).parse('*'), set(range(60)))
-        self.assertEquals(crontab_parser(7).parse('*'), set(range(7)))
-
-    def test_parse_range(self):
-        self.assertEquals(crontab_parser(60).parse('1-10'),
-                          set(range(1, 10 + 1)))
-        self.assertEquals(crontab_parser(24).parse('0-20'),
-                          set(range(0, 20 + 1)))
-        self.assertEquals(crontab_parser().parse('2-10'),
-                          set(range(2, 10 + 1)))
-
-    def test_parse_groups(self):
-        self.assertEquals(crontab_parser().parse('1,2,3,4'),
-                          set([1, 2, 3, 4]))
-        self.assertEquals(crontab_parser().parse('0,15,30,45'),
-                          set([0, 15, 30, 45]))
-
-    def test_parse_steps(self):
-        self.assertEquals(crontab_parser(8).parse('*/2'),
-                          set([0, 2, 4, 6]))
-        self.assertEquals(crontab_parser().parse('*/2'),
-                          set(i * 2 for i in xrange(30)))
-        self.assertEquals(crontab_parser().parse('*/3'),
-                          set(i * 3 for i in xrange(20)))
-
-    def test_parse_composite(self):
-        self.assertEquals(crontab_parser(8).parse('*/2'), set([0, 2, 4, 6]))
-        self.assertEquals(crontab_parser().parse('2-9/5'), set([5]))
-        self.assertEquals(crontab_parser().parse('2-10/5'), set([5, 10]))
-        self.assertEquals(crontab_parser().parse('2-11/5,3'), set([3, 5, 10]))
-        self.assertEquals(crontab_parser().parse('2-4/3,*/5,0-21/4'),
-                set([0, 3, 4, 5, 8, 10, 12, 15, 16,
-                    20, 25, 30, 35, 40, 45, 50, 55]))
-
-    def test_parse_errors_on_empty_string(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '')
-
-    def test_parse_errors_on_empty_group(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '1,,2')
-
-    def test_parse_errors_on_empty_steps(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '*/')
-
-    def test_parse_errors_on_negative_number(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '-20')
-
-    def test_expand_cronspec_eats_iterables(self):
-        self.assertEqual(crontab._expand_cronspec(iter([1, 2, 3]), 100),
-                         set([1, 2, 3]))
-
-    def test_expand_cronspec_invalid_type(self):
-        self.assertRaises(TypeError, crontab._expand_cronspec, object(), 100)
-
-    def test_repr(self):
-        self.assertIn("*", repr(crontab("*")))
-
-    def test_eq(self):
-        self.assertEqual(crontab(day_of_week="1, 2"),
-                         crontab(day_of_week="1-2"))
-        self.assertEqual(crontab(minute="1", hour="2", day_of_week="5"),
-                         crontab(minute="1", hour="2", day_of_week="5"))
-        self.assertNotEqual(crontab(minute="1"), crontab(minute="2"))
-        self.assertFalse(object() == crontab(minute="1"))
-        self.assertFalse(crontab(minute="1") == object())
-
-
-class test_crontab_remaining_estimate(unittest.TestCase):
-
-    def next_ocurrance(self, crontab, now):
-        crontab.nowfun = lambda: now
-        return now + crontab.remaining_estimate(now)
-
-    def test_next_minute(self):
-        next = self.next_ocurrance(crontab(),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEquals(next, datetime(2010, 9, 11, 14, 31))
-
-    def test_not_next_minute(self):
-        next = self.next_ocurrance(crontab(),
-                                   datetime(2010, 9, 11, 14, 59, 15))
-        self.assertEquals(next, datetime(2010, 9, 11, 15, 0))
-
-    def test_this_hour(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEquals(next, datetime(2010, 9, 11, 14, 42))
-
-    def test_not_this_hour(self):
-        next = self.next_ocurrance(crontab(minute=[5, 10, 15]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEquals(next, datetime(2010, 9, 11, 15, 5))
-
-    def test_today(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12, 17]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEquals(next, datetime(2010, 9, 11, 17, 5))
-
-    def test_not_today(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEquals(next, datetime(2010, 9, 12, 12, 5))
-
-    def test_weekday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_week="sat"),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEquals(next, datetime(2010, 9, 18, 14, 30))
-
-    def test_not_weekday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week="mon-fri"),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEquals(next, datetime(2010, 9, 13, 0, 5))
-
-
-class test_crontab_is_due(unittest.TestCase):
-
-    def setUp(self):
-        self.now = datetime.now()
-        self.next_minute = 60 - self.now.second - 1e-6 * self.now.microsecond
-
-    def test_default_crontab_spec(self):
-        c = crontab()
-        self.assertEquals(c.minute, set(range(60)))
-        self.assertEquals(c.hour, set(range(24)))
-        self.assertEquals(c.day_of_week, set(range(7)))
-
-    def test_simple_crontab_spec(self):
-        c = crontab(minute=30)
-        self.assertEquals(c.minute, set([30]))
-        self.assertEquals(c.hour, set(range(24)))
-        self.assertEquals(c.day_of_week, set(range(7)))
-
-    def test_crontab_spec_minute_formats(self):
-        c = crontab(minute=30)
-        self.assertEquals(c.minute, set([30]))
-        c = crontab(minute='30')
-        self.assertEquals(c.minute, set([30]))
-        c = crontab(minute=(30, 40, 50))
-        self.assertEquals(c.minute, set([30, 40, 50]))
-        c = crontab(minute=set([30, 40, 50]))
-        self.assertEquals(c.minute, set([30, 40, 50]))
-
-    def test_crontab_spec_invalid_minute(self):
-        self.assertRaises(ValueError, crontab, minute=60)
-        self.assertRaises(ValueError, crontab, minute='0-100')
-
-    def test_crontab_spec_hour_formats(self):
-        c = crontab(hour=6)
-        self.assertEquals(c.hour, set([6]))
-        c = crontab(hour='5')
-        self.assertEquals(c.hour, set([5]))
-        c = crontab(hour=(4, 8, 12))
-        self.assertEquals(c.hour, set([4, 8, 12]))
-
-    def test_crontab_spec_invalid_hour(self):
-        self.assertRaises(ValueError, crontab, hour=24)
-        self.assertRaises(ValueError, crontab, hour='0-30')
-
-    def test_crontab_spec_dow_formats(self):
-        c = crontab(day_of_week=5)
-        self.assertEquals(c.day_of_week, set([5]))
-        c = crontab(day_of_week='5')
-        self.assertEquals(c.day_of_week, set([5]))
-        c = crontab(day_of_week='fri')
-        self.assertEquals(c.day_of_week, set([5]))
-        c = crontab(day_of_week='tuesday,sunday,fri')
-        self.assertEquals(c.day_of_week, set([0, 2, 5]))
-        c = crontab(day_of_week='mon-fri')
-        self.assertEquals(c.day_of_week, set([1, 2, 3, 4, 5]))
-        c = crontab(day_of_week='*/2')
-        self.assertEquals(c.day_of_week, set([0, 2, 4, 6]))
-
-    def test_crontab_spec_invalid_dow(self):
-        self.assertRaises(ValueError, crontab, day_of_week='fooday-barday')
-        self.assertRaises(ValueError, crontab, day_of_week='1,4,foo')
-        self.assertRaises(ValueError, crontab, day_of_week='7')
-        self.assertRaises(ValueError, crontab, day_of_week='12')
-
-    def test_every_minute_execution_is_due(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = EveryMinutePeriodic().is_due(last_ran)
-        self.assertTrue(due)
-        self.assertAlmostEquals(remaining, self.next_minute, 1)
-
-    def test_every_minute_execution_is_not_due(self):
-        last_ran = self.now - timedelta(seconds=self.now.second)
-        due, remaining = EveryMinutePeriodic().is_due(last_ran)
-        self.assertFalse(due)
-        self.assertAlmostEquals(remaining, self.next_minute, 1)
-
-    # 29th of May 2010 is a saturday
-    @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 29, 10, 30))
-    def test_execution_is_due_on_saturday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = EveryMinutePeriodic().is_due(last_ran)
-        self.assertTrue(due)
-        self.assertAlmostEquals(remaining, self.next_minute, 1)
-
-    # 30th of May 2010 is a sunday
-    @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 30, 10, 30))
-    def test_execution_is_due_on_sunday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = EveryMinutePeriodic().is_due(last_ran)
-        self.assertTrue(due)
-        self.assertAlmostEquals(remaining, self.next_minute, 1)
-
-    # 31st of May 2010 is a monday
-    @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 31, 10, 30))
-    def test_execution_is_due_on_monday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = EveryMinutePeriodic().is_due(last_ran)
-        self.assertTrue(due)
-        self.assertAlmostEquals(remaining, self.next_minute, 1)
-
-    @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 10, 10, 30))
-    def test_every_hour_execution_is_due(self):
-        due, remaining = HourlyPeriodic().is_due(datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEquals(remaining, 60 * 60)
-
-    @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 10, 10, 29))
-    def test_every_hour_execution_is_not_due(self):
-        due, remaining = HourlyPeriodic().is_due(datetime(2010, 5, 10, 9, 30))
-        self.assertFalse(due)
-        self.assertEquals(remaining, 60)
-
-    @patch_crontab_nowfun(QuarterlyPeriodic, datetime(2010, 5, 10, 10, 15))
-    def test_first_quarter_execution_is_due(self):
-        due, remaining = QuarterlyPeriodic().is_due(
-                            datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEquals(remaining, 15 * 60)
-
-    @patch_crontab_nowfun(QuarterlyPeriodic, datetime(2010, 5, 10, 10, 30))
-    def test_second_quarter_execution_is_due(self):
-        due, remaining = QuarterlyPeriodic().is_due(
-                            datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEquals(remaining, 15 * 60)
-
-    @patch_crontab_nowfun(QuarterlyPeriodic, datetime(2010, 5, 10, 10, 14))
-    def test_first_quarter_execution_is_not_due(self):
-        due, remaining = QuarterlyPeriodic().is_due(
-                            datetime(2010, 5, 10, 10, 0))
-        self.assertFalse(due)
-        self.assertEquals(remaining, 60)
-
-    @patch_crontab_nowfun(QuarterlyPeriodic, datetime(2010, 5, 10, 10, 29))
-    def test_second_quarter_execution_is_not_due(self):
-        due, remaining = QuarterlyPeriodic().is_due(
-                            datetime(2010, 5, 10, 10, 15))
-        self.assertFalse(due)
-        self.assertEquals(remaining, 60)
-
-    @patch_crontab_nowfun(DailyPeriodic, datetime(2010, 5, 10, 7, 30))
-    def test_daily_execution_is_due(self):
-        due, remaining = DailyPeriodic().is_due(datetime(2010, 5, 9, 7, 30))
-        self.assertTrue(due)
-        self.assertEquals(remaining, 24 * 60 * 60)
-
-    @patch_crontab_nowfun(DailyPeriodic, datetime(2010, 5, 10, 10, 30))
-    def test_daily_execution_is_not_due(self):
-        due, remaining = DailyPeriodic().is_due(datetime(2010, 5, 10, 7, 30))
-        self.assertFalse(due)
-        self.assertEquals(remaining, 21 * 60 * 60)
-
-    @patch_crontab_nowfun(WeeklyPeriodic, datetime(2010, 5, 6, 7, 30))
-    def test_weekly_execution_is_due(self):
-        due, remaining = WeeklyPeriodic().is_due(datetime(2010, 4, 30, 7, 30))
-        self.assertTrue(due)
-        self.assertEquals(remaining, 7 * 24 * 60 * 60)
-
-    @patch_crontab_nowfun(WeeklyPeriodic, datetime(2010, 5, 7, 10, 30))
-    def test_weekly_execution_is_not_due(self):
-        due, remaining = WeeklyPeriodic().is_due(datetime(2010, 5, 6, 7, 30))
-        self.assertFalse(due)
-        self.assertEquals(remaining, 6 * 24 * 60 * 60 - 3 * 60 * 60)

+ 168 - 0
celery/tests/test_utils/__init__.py

@@ -0,0 +1,168 @@
+import pickle
+from celery.tests.utils import unittest
+
+from celery import utils
+from celery.utils import promise, mpromise, maybe_promise
+
+
+def double(x):
+    return x * 2
+
+
+class test_chunks(unittest.TestCase):
+
+    def test_chunks(self):
+
+        # n == 2
+        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
+        self.assertListEqual(list(x),
+            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]])
+
+        # n == 3
+        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
+        self.assertListEqual(list(x),
+            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]])
+
+        # n == 2 (exact)
+        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 2)
+        self.assertListEqual(list(x),
+            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
+
+
+class test_utils(unittest.TestCase):
+
+    def test_get_full_cls_name(self):
+        Class = type("Fox", (object, ), {"__module__": "quick.brown"})
+        self.assertEqual(utils.get_full_cls_name(Class), "quick.brown.Fox")
+
+    def test_is_iterable(self):
+        for a in "f", ["f"], ("f", ), {"f": "f"}:
+            self.assertTrue(utils.is_iterable(a))
+        for b in object(), 1:
+            self.assertFalse(utils.is_iterable(b))
+
+    def test_padlist(self):
+        self.assertListEqual(utils.padlist(["George", "Costanza", "NYC"], 3),
+                ["George", "Costanza", "NYC"])
+        self.assertListEqual(utils.padlist(["George", "Costanza"], 3),
+                ["George", "Costanza", None])
+        self.assertListEqual(utils.padlist(["George", "Costanza", "NYC"], 4,
+                                           default="Earth"),
+                ["George", "Costanza", "NYC", "Earth"])
+
+    def test_firstmethod_AttributeError(self):
+        self.assertIsNone(utils.firstmethod("foo")([object()]))
+
+    def test_firstmethod_promises(self):
+
+        class A(object):
+
+            def __init__(self, value=None):
+                self.value = value
+
+            def m(self):
+                return self.value
+
+        self.assertEqual("four", utils.firstmethod("m")([
+            A(), A(), A(), A("four"), A("five")]))
+        self.assertEqual("four", utils.firstmethod("m")([
+            A(), A(), A(), promise(lambda: A("four")), A("five")]))
+
+    def test_first(self):
+        iterations = [0]
+
+        def predicate(value):
+            iterations[0] += 1
+            if value == 5:
+                return True
+            return False
+
+        self.assertEqual(5, utils.first(predicate, xrange(10)))
+        self.assertEqual(iterations[0], 6)
+
+        iterations[0] = 0
+        self.assertIsNone(utils.first(predicate, xrange(10, 20)))
+        self.assertEqual(iterations[0], 10)
+
+    def test_get_cls_by_name__instance_returns_instance(self):
+        instance = object()
+        self.assertIs(utils.get_cls_by_name(instance), instance)
+
+    def test_truncate_text(self):
+        self.assertEqual(utils.truncate_text("ABCDEFGHI", 3), "ABC...")
+        self.assertEqual(utils.truncate_text("ABCDEFGHI", 10), "ABCDEFGHI")
+
+    def test_abbr(self):
+        self.assertEqual(utils.abbr(None, 3), "???")
+        self.assertEqual(utils.abbr("ABCDEFGHI", 6), "ABC...")
+        self.assertEqual(utils.abbr("ABCDEFGHI", 20), "ABCDEFGHI")
+        self.assertEqual(utils.abbr("ABCDEFGHI", 6, None), "ABCDEF")
+
+    def test_abbrtask(self):
+        self.assertEqual(utils.abbrtask(None, 3), "???")
+        self.assertEqual(utils.abbrtask("feeds.tasks.refresh", 10),
+                                        "[.]refresh")
+        self.assertEqual(utils.abbrtask("feeds.tasks.refresh", 30),
+                                        "feeds.tasks.refresh")
+
+    def test_cached_property(self):
+
+        def fun(obj):
+            return fun.value
+
+        x = utils.cached_property(fun)
+        self.assertIs(x.__get__(None), x)
+        self.assertIs(x.__set__(None, None), x)
+        self.assertIs(x.__delete__(None), x)
+
+
+class test_promise(unittest.TestCase):
+
+    def test__str__(self):
+        self.assertEqual(str(promise(lambda: "the quick brown fox")),
+                "the quick brown fox")
+
+    def test__repr__(self):
+        self.assertEqual(repr(promise(lambda: "fi fa fo")),
+                "'fi fa fo'")
+
+    def test_evaluate(self):
+        self.assertEqual(promise(lambda: 2 + 2)(), 4)
+        self.assertEqual(promise(lambda x: x * 4, 2), 8)
+        self.assertEqual(promise(lambda x: x * 8, 2)(), 16)
+
+    def test_cmp(self):
+        self.assertEqual(promise(lambda: 10), promise(lambda: 10))
+        self.assertNotEqual(promise(lambda: 10), promise(lambda: 20))
+
+    def test__reduce__(self):
+        x = promise(double, 4)
+        y = pickle.loads(pickle.dumps(x))
+        self.assertEqual(x(), y())
+
+    def test__deepcopy__(self):
+        from copy import deepcopy
+        x = promise(double, 4)
+        y = deepcopy(x)
+        self.assertEqual(x._fun, y._fun)
+        self.assertEqual(x._args, y._args)
+        self.assertEqual(x(), y())
+
+
+class test_mpromise(unittest.TestCase):
+
+    def test_is_memoized(self):
+
+        it = iter(xrange(20, 30))
+        p = mpromise(it.next)
+        self.assertEqual(p(), 20)
+        self.assertTrue(p.evaluated)
+        self.assertEqual(p(), 20)
+        self.assertEqual(repr(p), "20")
+
+
+class test_maybe_promise(unittest.TestCase):
+
+    def test_evaluates(self):
+        self.assertEqual(maybe_promise(promise(lambda: 10)), 10)
+        self.assertEqual(maybe_promise(20), 20)

+ 0 - 168
celery/tests/test_utils/test_utils.py

@@ -1,168 +0,0 @@
-import pickle
-from celery.tests.utils import unittest
-
-from celery import utils
-from celery.utils import promise, mpromise, maybe_promise
-
-
-def double(x):
-    return x * 2
-
-
-class test_chunks(unittest.TestCase):
-
-    def test_chunks(self):
-
-        # n == 2
-        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
-        self.assertListEqual(list(x),
-            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]])
-
-        # n == 3
-        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
-        self.assertListEqual(list(x),
-            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]])
-
-        # n == 2 (exact)
-        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 2)
-        self.assertListEqual(list(x),
-            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
-
-
-class test_utils(unittest.TestCase):
-
-    def test_get_full_cls_name(self):
-        Class = type("Fox", (object, ), {"__module__": "quick.brown"})
-        self.assertEqual(utils.get_full_cls_name(Class), "quick.brown.Fox")
-
-    def test_is_iterable(self):
-        for a in "f", ["f"], ("f", ), {"f": "f"}:
-            self.assertTrue(utils.is_iterable(a))
-        for b in object(), 1:
-            self.assertFalse(utils.is_iterable(b))
-
-    def test_padlist(self):
-        self.assertListEqual(utils.padlist(["George", "Costanza", "NYC"], 3),
-                ["George", "Costanza", "NYC"])
-        self.assertListEqual(utils.padlist(["George", "Costanza"], 3),
-                ["George", "Costanza", None])
-        self.assertListEqual(utils.padlist(["George", "Costanza", "NYC"], 4,
-                                           default="Earth"),
-                ["George", "Costanza", "NYC", "Earth"])
-
-    def test_firstmethod_AttributeError(self):
-        self.assertIsNone(utils.firstmethod("foo")([object()]))
-
-    def test_firstmethod_promises(self):
-
-        class A(object):
-
-            def __init__(self, value=None):
-                self.value = value
-
-            def m(self):
-                return self.value
-
-        self.assertEqual("four", utils.firstmethod("m")([
-            A(), A(), A(), A("four"), A("five")]))
-        self.assertEqual("four", utils.firstmethod("m")([
-            A(), A(), A(), promise(lambda: A("four")), A("five")]))
-
-    def test_first(self):
-        iterations = [0]
-
-        def predicate(value):
-            iterations[0] += 1
-            if value == 5:
-                return True
-            return False
-
-        self.assertEqual(5, utils.first(predicate, xrange(10)))
-        self.assertEqual(iterations[0], 6)
-
-        iterations[0] = 0
-        self.assertIsNone(utils.first(predicate, xrange(10, 20)))
-        self.assertEqual(iterations[0], 10)
-
-    def test_get_cls_by_name__instance_returns_instance(self):
-        instance = object()
-        self.assertIs(utils.get_cls_by_name(instance), instance)
-
-    def test_truncate_text(self):
-        self.assertEqual(utils.truncate_text("ABCDEFGHI", 3), "ABC...")
-        self.assertEqual(utils.truncate_text("ABCDEFGHI", 10), "ABCDEFGHI")
-
-    def test_abbr(self):
-        self.assertEqual(utils.abbr(None, 3), "???")
-        self.assertEqual(utils.abbr("ABCDEFGHI", 6), "ABC...")
-        self.assertEqual(utils.abbr("ABCDEFGHI", 20), "ABCDEFGHI")
-        self.assertEqual(utils.abbr("ABCDEFGHI", 6, None), "ABCDEF")
-
-    def test_abbrtask(self):
-        self.assertEqual(utils.abbrtask(None, 3), "???")
-        self.assertEqual(utils.abbrtask("feeds.tasks.refresh", 10),
-                                        "[.]refresh")
-        self.assertEqual(utils.abbrtask("feeds.tasks.refresh", 30),
-                                        "feeds.tasks.refresh")
-
-    def test_cached_property(self):
-
-        def fun(obj):
-            return fun.value
-
-        x = utils.cached_property(fun)
-        self.assertIs(x.__get__(None), x)
-        self.assertIs(x.__set__(None, None), x)
-        self.assertIs(x.__delete__(None), x)
-
-
-class test_promise(unittest.TestCase):
-
-    def test__str__(self):
-        self.assertEqual(str(promise(lambda: "the quick brown fox")),
-                "the quick brown fox")
-
-    def test__repr__(self):
-        self.assertEqual(repr(promise(lambda: "fi fa fo")),
-                "'fi fa fo'")
-
-    def test_evaluate(self):
-        self.assertEqual(promise(lambda: 2 + 2)(), 4)
-        self.assertEqual(promise(lambda x: x * 4, 2), 8)
-        self.assertEqual(promise(lambda x: x * 8, 2)(), 16)
-
-    def test_cmp(self):
-        self.assertEqual(promise(lambda: 10), promise(lambda: 10))
-        self.assertNotEqual(promise(lambda: 10), promise(lambda: 20))
-
-    def test__reduce__(self):
-        x = promise(double, 4)
-        y = pickle.loads(pickle.dumps(x))
-        self.assertEqual(x(), y())
-
-    def test__deepcopy__(self):
-        from copy import deepcopy
-        x = promise(double, 4)
-        y = deepcopy(x)
-        self.assertEqual(x._fun, y._fun)
-        self.assertEqual(x._args, y._args)
-        self.assertEqual(x(), y())
-
-
-class test_mpromise(unittest.TestCase):
-
-    def test_is_memoized(self):
-
-        it = iter(xrange(20, 30))
-        p = mpromise(it.next)
-        self.assertEqual(p(), 20)
-        self.assertTrue(p.evaluated)
-        self.assertEqual(p(), 20)
-        self.assertEqual(repr(p), "20")
-
-
-class test_maybe_promise(unittest.TestCase):
-
-    def test_evaluates(self):
-        self.assertEqual(maybe_promise(promise(lambda: 10)), 10)
-        self.assertEqual(maybe_promise(20), 20)

+ 889 - 0
celery/tests/test_worker/__init__.py

@@ -0,0 +1,889 @@
+from __future__ import with_statement
+
+import socket
+import sys
+
+from collections import deque
+from datetime import datetime, timedelta
+from Queue import Empty
+
+from kombu.transport.base import Message
+from kombu.connection import BrokerConnection
+from mock import Mock, patch
+
+from celery import current_app
+from celery.concurrency.base import BasePool
+from celery.exceptions import SystemTerminate
+from celery.task import task as task_dec
+from celery.task import periodic_task as periodic_task_dec
+from celery.utils import uuid
+from celery.worker import WorkController
+from celery.worker.buckets import FastQueue
+from celery.worker.job import TaskRequest
+from celery.worker.consumer import Consumer as MainConsumer
+from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
+from celery.utils.serialization import pickle
+from celery.utils.timer2 import Timer
+
+from celery.tests.compat import catch_warnings
+from celery.tests.utils import unittest
+from celery.tests.utils import AppCase
+
+
+class PlaceHolder(object):
+        pass
+
+
+class MyKombuConsumer(MainConsumer):
+    broadcast_consumer = Mock()
+    task_consumer = Mock()
+
+    def __init__(self, *args, **kwargs):
+        kwargs.setdefault("pool", BasePool(2))
+        super(MyKombuConsumer, self).__init__(*args, **kwargs)
+
+    def restart_heartbeat(self):
+        self.heart = None
+
+
+class MockNode(object):
+    commands = []
+
+    def handle_message(self, body, message):
+        self.commands.append(body.pop("command", None))
+
+
+class MockEventDispatcher(object):
+    sent = []
+    closed = False
+    flushed = False
+    _outbound_buffer = []
+
+    def send(self, event, *args, **kwargs):
+        self.sent.append(event)
+
+    def close(self):
+        self.closed = True
+
+    def flush(self):
+        self.flushed = True
+
+
+class MockHeart(object):
+    closed = False
+
+    def stop(self):
+        self.closed = True
+
+
+@task_dec()
+def foo_task(x, y, z, **kwargs):
+    return x * y * z
+
+
+@periodic_task_dec(run_every=60)
+def foo_periodic_task():
+    return "foo"
+
+
+def create_message(channel, **data):
+    data.setdefault("id", uuid())
+    channel.no_ack_consumers = set()
+    return Message(channel, body=pickle.dumps(dict(**data)),
+                   content_type="application/x-python-serialize",
+                   content_encoding="binary",
+                   delivery_info={"consumer_tag": "mock"})
+
+
+class test_QoS(unittest.TestCase):
+
+    class _QoS(QoS):
+        def __init__(self, value):
+            self.value = value
+            QoS.__init__(self, None, value, None)
+
+        def set(self, value):
+            return value
+
+    def test_qos_increment_decrement(self):
+        qos = self._QoS(10)
+        self.assertEqual(qos.increment(), 11)
+        self.assertEqual(qos.increment(3), 14)
+        self.assertEqual(qos.increment(-30), 14)
+        self.assertEqual(qos.decrement(7), 7)
+        self.assertEqual(qos.decrement(), 6)
+        self.assertRaises(AssertionError, qos.decrement, 10)
+
+    def test_qos_disabled_increment_decrement(self):
+        qos = self._QoS(0)
+        self.assertEqual(qos.increment(), 0)
+        self.assertEqual(qos.increment(3), 0)
+        self.assertEqual(qos.increment(-30), 0)
+        self.assertEqual(qos.decrement(7), 0)
+        self.assertEqual(qos.decrement(), 0)
+        self.assertEqual(qos.decrement(10), 0)
+
+    def test_qos_thread_safe(self):
+        qos = self._QoS(10)
+
+        def add():
+            for i in xrange(1000):
+                qos.increment()
+
+        def sub():
+            for i in xrange(1000):
+                qos.decrement_eventually()
+
+        def threaded(funs):
+            from threading import Thread
+            threads = [Thread(target=fun) for fun in funs]
+            for thread in threads:
+                thread.start()
+            for thread in threads:
+                thread.join()
+
+        threaded([add, add])
+        self.assertEqual(qos.value, 2010)
+
+        qos.value = 1000
+        threaded([add, sub])  # n = 2
+        self.assertEqual(qos.value, 1000)
+
+    def test_exceeds_short(self):
+        qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1,
+                current_app.log.get_default_logger())
+        qos.update()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
+        qos.increment()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
+        qos.increment()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
+        qos.decrement()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
+        qos.decrement()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
+
+    def test_consumer_increment_decrement(self):
+        consumer = Mock()
+        qos = QoS(consumer, 10, current_app.log.get_default_logger())
+        qos.update()
+        self.assertEqual(qos.value, 10)
+        self.assertIn({"prefetch_count": 10}, consumer.qos.call_args)
+        qos.decrement()
+        self.assertEqual(qos.value, 9)
+        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
+        qos.decrement_eventually()
+        self.assertEqual(qos.value, 8)
+        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
+
+        # Does not decrement 0 value
+        qos.value = 0
+        qos.decrement()
+        self.assertEqual(qos.value, 0)
+        qos.increment()
+        self.assertEqual(qos.value, 0)
+
+    def test_consumer_decrement_eventually(self):
+        consumer = Mock()
+        qos = QoS(consumer, 10, current_app.log.get_default_logger())
+        qos.decrement_eventually()
+        self.assertEqual(qos.value, 9)
+        qos.value = 0
+        qos.decrement_eventually()
+        self.assertEqual(qos.value, 0)
+
+    def test_set(self):
+        consumer = Mock()
+        qos = QoS(consumer, 10, current_app.log.get_default_logger())
+        qos.set(12)
+        self.assertEqual(qos.prev, 12)
+        qos.set(qos.prev)
+
+
+class test_Consumer(unittest.TestCase):
+
+    def setUp(self):
+        self.ready_queue = FastQueue()
+        self.eta_schedule = Timer()
+        self.logger = current_app.log.get_default_logger()
+        self.logger.setLevel(0)
+
+    def tearDown(self):
+        self.eta_schedule.stop()
+
+    def test_info(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                           send_events=False)
+        l.qos = QoS(l.task_consumer, 10, l.logger)
+        info = l.info
+        self.assertEqual(info["prefetch_count"], 10)
+        self.assertFalse(info["broker"])
+
+        l.connection = current_app.broker_connection()
+        info = l.info
+        self.assertTrue(info["broker"])
+
+    def test_start_when_closed(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                            send_events=False)
+        l._state = CLOSE
+        l.start()
+
+    def test_connection(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                           send_events=False)
+
+        l.reset_connection()
+        self.assertIsInstance(l.connection, BrokerConnection)
+
+        l._state = RUN
+        l.event_dispatcher = None
+        l.stop_consumers(close_connection=False)
+        self.assertTrue(l.connection)
+
+        l._state = RUN
+        l.stop_consumers()
+        self.assertIsNone(l.connection)
+        self.assertIsNone(l.task_consumer)
+
+        l.reset_connection()
+        self.assertIsInstance(l.connection, BrokerConnection)
+        l.stop_consumers()
+
+        l.stop()
+        l.close_connection()
+        self.assertIsNone(l.connection)
+        self.assertIsNone(l.task_consumer)
+
+    def test_close_connection(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                           send_events=False)
+        l._state = RUN
+        l.close_connection()
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                           send_events=False)
+        eventer = l.event_dispatcher = Mock()
+        eventer.enabled = True
+        heart = l.heart = MockHeart()
+        l._state = RUN
+        l.stop_consumers()
+        self.assertTrue(eventer.close.call_count)
+        self.assertTrue(heart.closed)
+
+    def test_receive_message_unknown(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                           send_events=False)
+        backend = Mock()
+        m = create_message(backend, unknown={"baz": "!!!"})
+        l.event_dispatcher = Mock()
+        l.pidbox_node = MockNode()
+
+        with catch_warnings(record=True) as log:
+            l.receive_message(m.decode(), m)
+            self.assertTrue(log)
+            self.assertIn("unknown message", log[0].message.args[0])
+
+    @patch("celery.utils.timer2.to_timestamp")
+    def test_receive_message_eta_OverflowError(self, to_timestamp):
+        to_timestamp.side_effect = OverflowError()
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        m = create_message(Mock(), task=foo_task.name,
+                                   args=("2, 2"),
+                                   kwargs={},
+                                   eta=datetime.now().isoformat())
+        l.event_dispatcher = Mock()
+        l.pidbox_node = MockNode()
+
+        l.receive_message(m.decode(), m)
+        self.assertTrue(m.acknowledged)
+        self.assertTrue(to_timestamp.call_count)
+
+    def test_receive_message_InvalidTaskError(self):
+        logger = Mock()
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
+                           send_events=False)
+        m = create_message(Mock(), task=foo_task.name,
+                           args=(1, 2), kwargs="foobarbaz", id=1)
+        l.event_dispatcher = Mock()
+        l.pidbox_node = MockNode()
+
+        l.receive_message(m.decode(), m)
+        self.assertIn("Received invalid task message",
+                      logger.error.call_args[0][0])
+
+    def test_on_decode_error(self):
+        logger = Mock()
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
+                           send_events=False)
+
+        class MockMessage(Mock):
+            content_type = "application/x-msgpack"
+            content_encoding = "binary"
+            body = "foobarbaz"
+
+        message = MockMessage()
+        l.on_decode_error(message, KeyError("foo"))
+        self.assertTrue(message.ack.call_count)
+        self.assertIn("Can't decode message body",
+                      logger.critical.call_args[0][0])
+
+    def test_receieve_message(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                           send_events=False)
+        m = create_message(Mock(), task=foo_task.name,
+                           args=[2, 4, 8], kwargs={})
+
+        l.event_dispatcher = Mock()
+        l.receive_message(m.decode(), m)
+
+        in_bucket = self.ready_queue.get_nowait()
+        self.assertIsInstance(in_bucket, TaskRequest)
+        self.assertEqual(in_bucket.task_name, foo_task.name)
+        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
+        self.assertTrue(self.eta_schedule.empty())
+
+    def test_start_connection_error(self):
+
+        class MockConsumer(MainConsumer):
+            iterations = 0
+
+            def consume_messages(self):
+                if not self.iterations:
+                    self.iterations = 1
+                    raise KeyError("foo")
+                raise SyntaxError("bar")
+
+        l = MockConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False, pool=BasePool())
+        l.connection_errors = (KeyError, )
+        self.assertRaises(SyntaxError, l.start)
+        l.heart.stop()
+        l.priority_timer.stop()
+
+    def test_consume_messages_ignores_socket_timeout(self):
+
+        class Connection(current_app.broker_connection().__class__):
+            obj = None
+
+            def drain_events(self, **kwargs):
+                self.obj.connection = None
+                raise socket.timeout(10)
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                            send_events=False)
+        l.connection = Connection()
+        l.task_consumer = Mock()
+        l.connection.obj = l
+        l.qos = QoS(l.task_consumer, 10, l.logger)
+        l.consume_messages()
+
+    def test_consume_messages_when_socket_error(self):
+
+        class Connection(current_app.broker_connection().__class__):
+            obj = None
+
+            def drain_events(self, **kwargs):
+                self.obj.connection = None
+                raise socket.error("foo")
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                            send_events=False)
+        l._state = RUN
+        c = l.connection = Connection()
+        l.connection.obj = l
+        l.task_consumer = Mock()
+        l.qos = QoS(l.task_consumer, 10, l.logger)
+        with self.assertRaises(socket.error):
+            l.consume_messages()
+
+        l._state = CLOSE
+        l.connection = c
+        l.consume_messages()
+
+    def test_consume_messages(self):
+
+        class Connection(current_app.broker_connection().__class__):
+            obj = None
+
+            def drain_events(self, **kwargs):
+                self.obj.connection = None
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        l.connection = Connection()
+        l.connection.obj = l
+        l.task_consumer = Mock()
+        l.qos = QoS(l.task_consumer, 10, l.logger)
+
+        l.consume_messages()
+        l.consume_messages()
+        self.assertTrue(l.task_consumer.consume.call_count)
+        l.task_consumer.qos.assert_called_with(prefetch_count=10)
+        l.qos.decrement()
+        l.consume_messages()
+        l.task_consumer.qos.assert_called_with(prefetch_count=9)
+
+    def test_maybe_conn_error(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        l.connection_errors = (KeyError, )
+        l.channel_errors = (SyntaxError, )
+        l.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
+        l.maybe_conn_error(Mock(side_effect=KeyError("foo")))
+        l.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
+        self.assertRaises(IndexError, l.maybe_conn_error,
+                Mock(side_effect=IndexError("foo")))
+
+    def test_apply_eta_task(self):
+        from celery.worker import state
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        l.qos = QoS(None, 10, l.logger)
+
+        task = object()
+        qos = l.qos.value
+        l.apply_eta_task(task)
+        self.assertIn(task, state.reserved_requests)
+        self.assertEqual(l.qos.value, qos - 1)
+        self.assertIs(self.ready_queue.get_nowait(), task)
+
+    def test_receieve_message_eta_isoformat(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        m = create_message(Mock(), task=foo_task.name,
+                           eta=datetime.now().isoformat(),
+                           args=[2, 4, 8], kwargs={})
+
+        l.task_consumer = Mock()
+        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
+        l.event_dispatcher = Mock()
+        l.enabled = False
+        l.receive_message(m.decode(), m)
+        l.eta_schedule.stop()
+
+        items = [entry[2] for entry in self.eta_schedule.queue]
+        found = 0
+        for item in items:
+            if item.args[0].task_name == foo_task.name:
+                found = True
+        self.assertTrue(found)
+        self.assertTrue(l.task_consumer.qos.call_count)
+        l.eta_schedule.stop()
+
+    def test_on_control(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        l.pidbox_node = Mock()
+        l.reset_pidbox_node = Mock()
+
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+
+        l.pidbox_node = Mock()
+        l.pidbox_node.handle_message.side_effect = KeyError("foo")
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+
+        l.pidbox_node = Mock()
+        l.pidbox_node.handle_message.side_effect = ValueError("foo")
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+        l.reset_pidbox_node.assert_called_with()
+
+    def test_revoke(self):
+        ready_queue = FastQueue()
+        l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
+                           send_events=False)
+        backend = Mock()
+        id = uuid()
+        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
+                           kwargs={}, id=id)
+        from celery.worker.state import revoked
+        revoked.add(id)
+
+        l.receive_message(t.decode(), t)
+        self.assertTrue(ready_queue.empty())
+
+    def test_receieve_message_not_registered(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        backend = Mock()
+        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
+
+        l.event_dispatcher = Mock()
+        self.assertFalse(l.receive_message(m.decode(), m))
+        self.assertRaises(Empty, self.ready_queue.get_nowait)
+        self.assertTrue(self.eta_schedule.empty())
+
+    def test_receieve_message_ack_raises(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        backend = Mock()
+        m = create_message(backend, args=[2, 4, 8], kwargs={})
+
+        l.event_dispatcher = Mock()
+        l.connection_errors = (socket.error, )
+        l.logger = Mock()
+        m.ack = Mock()
+        m.ack.side_effect = socket.error("foo")
+        with catch_warnings(record=True) as log:
+            self.assertFalse(l.receive_message(m.decode(), m))
+            self.assertTrue(log)
+            self.assertIn("unknown message", log[0].message.args[0])
+        self.assertRaises(Empty, self.ready_queue.get_nowait)
+        self.assertTrue(self.eta_schedule.empty())
+        m.ack.assert_called_with()
+        self.assertTrue(l.logger.critical.call_count)
+
+    def test_receieve_message_eta(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        l.event_dispatcher = Mock()
+        l.event_dispatcher._outbound_buffer = deque()
+        backend = Mock()
+        m = create_message(backend, task=foo_task.name,
+                           args=[2, 4, 8], kwargs={},
+                           eta=(datetime.now() +
+                               timedelta(days=1)).isoformat())
+
+        l.reset_connection()
+        p = l.app.conf.BROKER_CONNECTION_RETRY
+        l.app.conf.BROKER_CONNECTION_RETRY = False
+        try:
+            l.reset_connection()
+        finally:
+            l.app.conf.BROKER_CONNECTION_RETRY = p
+        l.stop_consumers()
+        l.event_dispatcher = Mock()
+        l.receive_message(m.decode(), m)
+        l.eta_schedule.stop()
+        in_hold = self.eta_schedule.queue[0]
+        self.assertEqual(len(in_hold), 3)
+        eta, priority, entry = in_hold
+        task = entry.args[0]
+        self.assertIsInstance(task, TaskRequest)
+        self.assertEqual(task.task_name, foo_task.name)
+        self.assertEqual(task.execute(), 2 * 4 * 8)
+        self.assertRaises(Empty, self.ready_queue.get_nowait)
+
+    def test_reset_pidbox_node(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        l.pidbox_node = Mock()
+        chan = l.pidbox_node.channel = Mock()
+        l.connection = Mock()
+        chan.close.side_effect = socket.error("foo")
+        l.connection_errors = (socket.error, )
+        l.reset_pidbox_node()
+        chan.close.assert_called_with()
+
+    def test_reset_pidbox_node_green(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        l.pool = Mock()
+        l.pool.is_green = True
+        l.reset_pidbox_node()
+        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)
+
+    def test__green_pidbox_node(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        l.pidbox_node = Mock()
+
+        connections = []
+
+        class Connection(object):
+
+            def __init__(self, obj):
+                connections.append(self)
+                self.obj = obj
+                self.closed = False
+
+            def channel(self):
+                return Mock()
+
+            def drain_events(self):
+                self.obj.connection = None
+
+            def close(self):
+                self.closed = True
+
+        l.connection = Mock()
+        l._open_connection = lambda: Connection(obj=l)
+        l._green_pidbox_node()
+
+        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
+        self.assertTrue(l.broadcast_consumer)
+        l.broadcast_consumer.consume.assert_called_with()
+
+        self.assertIsNone(l.connection)
+        self.assertTrue(connections[0].closed)
+
+    def test_start__consume_messages(self):
+
+        class _QoS(object):
+            prev = 3
+            value = 4
+
+            def update(self):
+                self.prev = self.value
+
+        class _Consumer(MyKombuConsumer):
+            iterations = 0
+
+            def reset_connection(self):
+                if self.iterations >= 1:
+                    raise KeyError("foo")
+
+        init_callback = Mock()
+        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
+                      send_events=False, init_callback=init_callback)
+        l.task_consumer = Mock()
+        l.broadcast_consumer = Mock()
+        l.qos = _QoS()
+        l.connection = BrokerConnection()
+        l.iterations = 0
+
+        def raises_KeyError(limit=None):
+            l.iterations += 1
+            if l.qos.prev != l.qos.value:
+                l.qos.update()
+            if l.iterations >= 2:
+                raise KeyError("foo")
+
+        l.consume_messages = raises_KeyError
+        self.assertRaises(KeyError, l.start)
+        self.assertTrue(init_callback.call_count)
+        self.assertEqual(l.iterations, 1)
+        self.assertEqual(l.qos.prev, l.qos.value)
+
+        init_callback.reset_mock()
+        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
+                      send_events=False, init_callback=init_callback)
+        l.qos = _QoS()
+        l.task_consumer = Mock()
+        l.broadcast_consumer = Mock()
+        l.connection = BrokerConnection()
+        l.consume_messages = Mock(side_effect=socket.error("foo"))
+        self.assertRaises(socket.error, l.start)
+        self.assertTrue(init_callback.call_count)
+        self.assertTrue(l.consume_messages.call_count)
+
+    def test_reset_connection_with_no_node(self):
+
+        l = MainConsumer(self.ready_queue, self.eta_schedule, self.logger)
+        self.assertEqual(None, l.pool)
+        l.reset_connection()
+
+
+class test_WorkController(AppCase):
+
+    def setup(self):
+        self.worker = self.create_worker()
+
+    def create_worker(self, **kw):
+        worker = WorkController(concurrency=1, loglevel=0, **kw)
+        worker.logger = Mock()
+        return worker
+
+    @patch("celery.platforms.signals")
+    @patch("celery.platforms.set_mp_process_title")
+    def test_process_initializer(self, set_mp_process_title, _signals):
+        from celery import Celery
+        from celery import signals
+        from celery.app import _tls
+        from celery.worker import process_initializer
+        from celery.worker import WORKER_SIGRESET, WORKER_SIGIGNORE
+
+        def on_worker_process_init(**kwargs):
+            on_worker_process_init.called = True
+        on_worker_process_init.called = False
+        signals.worker_process_init.connect(on_worker_process_init)
+
+        app = Celery(loader=Mock(), set_as_current=False)
+        process_initializer(app, "awesome.worker.com")
+        self.assertIn((tuple(WORKER_SIGIGNORE), {}),
+                      _signals.ignore.call_args_list)
+        self.assertIn((tuple(WORKER_SIGRESET), {}),
+                      _signals.reset.call_args_list)
+        self.assertTrue(app.loader.init_worker.call_count)
+        self.assertTrue(on_worker_process_init.called)
+        self.assertIs(_tls.current_app, app)
+        set_mp_process_title.assert_called_with("celeryd",
+                        hostname="awesome.worker.com")
+
+    def test_with_rate_limits_disabled(self):
+        worker = WorkController(concurrency=1, loglevel=0,
+                                disable_rate_limits=True)
+        self.assertTrue(hasattr(worker.ready_queue, "put"))
+
+    def test_attrs(self):
+        worker = self.worker
+        self.assertIsInstance(worker.scheduler, Timer)
+        self.assertTrue(worker.scheduler)
+        self.assertTrue(worker.pool)
+        self.assertTrue(worker.consumer)
+        self.assertTrue(worker.mediator)
+        self.assertTrue(worker.components)
+
+    def test_with_embedded_celerybeat(self):
+        worker = WorkController(concurrency=1, loglevel=0,
+                                embed_clockservice=True)
+        self.assertTrue(worker.beat)
+        self.assertIn(worker.beat, worker.components)
+
+    def test_with_autoscaler(self):
+        worker = self.create_worker(autoscale=[10, 3], send_events=False,
+                                eta_scheduler_cls="celery.utils.timer2.Timer")
+        self.assertTrue(worker.autoscaler)
+
+    def test_dont_stop_or_terminate(self):
+        worker = WorkController(concurrency=1, loglevel=0)
+        worker.stop()
+        self.assertNotEqual(worker._state, worker.CLOSE)
+        worker.terminate()
+        self.assertNotEqual(worker._state, worker.CLOSE)
+
+        sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
+        try:
+            worker._state = worker.RUN
+            worker.stop(in_sighandler=True)
+            self.assertNotEqual(worker._state, worker.CLOSE)
+            worker.terminate(in_sighandler=True)
+            self.assertNotEqual(worker._state, worker.CLOSE)
+        finally:
+            worker.pool.signal_safe = sigsafe
+
+    def test_on_timer_error(self):
+        worker = WorkController(concurrency=1, loglevel=0)
+        worker.logger = Mock()
+
+        try:
+            raise KeyError("foo")
+        except KeyError:
+            exc_info = sys.exc_info()
+
+        worker.on_timer_error(exc_info)
+        msg, args = worker.logger.error.call_args[0]
+        self.assertIn("KeyError", msg % args)
+
+    def test_on_timer_tick(self):
+        worker = WorkController(concurrency=1, loglevel=10)
+        worker.logger = Mock()
+        worker.timer_debug = worker.logger.debug
+
+        worker.on_timer_tick(30.0)
+        logged = worker.logger.debug.call_args[0][0]
+        self.assertIn("30.0", logged)
+
+    def test_process_task(self):
+        worker = self.worker
+        worker.pool = Mock()
+        backend = Mock()
+        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+                           kwargs={})
+        task = TaskRequest.from_message(m, m.decode())
+        worker.process_task(task)
+        self.assertEqual(worker.pool.apply_async.call_count, 1)
+        worker.pool.stop()
+
+    def test_process_task_raise_base(self):
+        worker = self.worker
+        worker.pool = Mock()
+        worker.pool.apply_async.side_effect = KeyboardInterrupt("Ctrl+C")
+        backend = Mock()
+        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+                           kwargs={})
+        task = TaskRequest.from_message(m, m.decode())
+        worker.components = []
+        worker._state = worker.RUN
+        self.assertRaises(KeyboardInterrupt, worker.process_task, task)
+        self.assertEqual(worker._state, worker.TERMINATE)
+
+    def test_process_task_raise_SystemTerminate(self):
+        worker = self.worker
+        worker.pool = Mock()
+        worker.pool.apply_async.side_effect = SystemTerminate()
+        backend = Mock()
+        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+                           kwargs={})
+        task = TaskRequest.from_message(m, m.decode())
+        worker.components = []
+        worker._state = worker.RUN
+        self.assertRaises(SystemExit, worker.process_task, task)
+        self.assertEqual(worker._state, worker.TERMINATE)
+
+    def test_process_task_raise_regular(self):
+        worker = self.worker
+        worker.pool = Mock()
+        worker.pool.apply_async.side_effect = KeyError("some exception")
+        backend = Mock()
+        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+                           kwargs={})
+        task = TaskRequest.from_message(m, m.decode())
+        worker.process_task(task)
+        worker.pool.stop()
+
+    def test_start_catches_base_exceptions(self):
+        worker1 = self.create_worker()
+        stc = Mock()
+        stc.start.side_effect = SystemTerminate()
+        worker1.components = [stc]
+        self.assertRaises(SystemExit, worker1.start)
+        self.assertTrue(stc.terminate.call_count)
+
+        worker2 = self.create_worker()
+        sec = Mock()
+        sec.start.side_effect = SystemExit()
+        sec.terminate = None
+        worker2.components = [sec]
+        self.assertRaises(SystemExit, worker2.start)
+        self.assertTrue(sec.stop.call_count)
+
+    def test_state_db(self):
+        from celery.worker import state
+        Persistent = state.Persistent
+
+        state.Persistent = Mock()
+        try:
+            worker = self.create_worker(db="statefilename")
+            self.assertTrue(worker._finalize_db)
+            worker._finalize_db.cancel()
+        finally:
+            state.Persistent = Persistent
+
+    def test_disable_rate_limits(self):
+        from celery.worker.buckets import FastQueue
+        worker = self.create_worker(disable_rate_limits=True)
+        self.assertIsInstance(worker.ready_queue, FastQueue)
+        self.assertIsNone(worker.mediator)
+        self.assertEqual(worker.ready_queue.put, worker.process_task)
+
+    def test_start__stop(self):
+        worker = self.worker
+        worker.components = [Mock(), Mock(), Mock(), Mock()]
+
+        worker.start()
+        for w in worker.components:
+            self.assertTrue(w.start.call_count)
+        worker.stop()
+        for component in worker.components:
+            self.assertTrue(w.stop.call_count)
+
+    def test_start__terminate(self):
+        worker = self.worker
+        worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
+        for component in worker.components[:3]:
+            component.terminate = None
+
+        worker.start()
+        for w in worker.components[:3]:
+            self.assertTrue(w.start.call_count)
+        self.assertTrue(worker._running, len(worker.components))
+        self.assertEqual(worker._state, RUN)
+        worker.terminate()
+        for component in worker.components[:3]:
+            self.assertTrue(component.stop.call_count)
+        self.assertTrue(worker.components[4].terminate.call_count)

+ 0 - 889
celery/tests/test_worker/test_worker.py

@@ -1,889 +0,0 @@
-from __future__ import with_statement
-
-import socket
-import sys
-
-from collections import deque
-from datetime import datetime, timedelta
-from Queue import Empty
-
-from kombu.transport.base import Message
-from kombu.connection import BrokerConnection
-from mock import Mock, patch
-
-from celery import current_app
-from celery.concurrency.base import BasePool
-from celery.exceptions import SystemTerminate
-from celery.task import task as task_dec
-from celery.task import periodic_task as periodic_task_dec
-from celery.utils import uuid
-from celery.worker import WorkController
-from celery.worker.buckets import FastQueue
-from celery.worker.job import TaskRequest
-from celery.worker.consumer import Consumer as MainConsumer
-from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
-from celery.utils.serialization import pickle
-from celery.utils.timer2 import Timer
-
-from celery.tests.compat import catch_warnings
-from celery.tests.utils import unittest
-from celery.tests.utils import AppCase
-
-
-class PlaceHolder(object):
-        pass
-
-
-class MyKombuConsumer(MainConsumer):
-    broadcast_consumer = Mock()
-    task_consumer = Mock()
-
-    def __init__(self, *args, **kwargs):
-        kwargs.setdefault("pool", BasePool(2))
-        super(MyKombuConsumer, self).__init__(*args, **kwargs)
-
-    def restart_heartbeat(self):
-        self.heart = None
-
-
-class MockNode(object):
-    commands = []
-
-    def handle_message(self, body, message):
-        self.commands.append(body.pop("command", None))
-
-
-class MockEventDispatcher(object):
-    sent = []
-    closed = False
-    flushed = False
-    _outbound_buffer = []
-
-    def send(self, event, *args, **kwargs):
-        self.sent.append(event)
-
-    def close(self):
-        self.closed = True
-
-    def flush(self):
-        self.flushed = True
-
-
-class MockHeart(object):
-    closed = False
-
-    def stop(self):
-        self.closed = True
-
-
-@task_dec()
-def foo_task(x, y, z, **kwargs):
-    return x * y * z
-
-
-@periodic_task_dec(run_every=60)
-def foo_periodic_task():
-    return "foo"
-
-
-def create_message(channel, **data):
-    data.setdefault("id", uuid())
-    channel.no_ack_consumers = set()
-    return Message(channel, body=pickle.dumps(dict(**data)),
-                   content_type="application/x-python-serialize",
-                   content_encoding="binary",
-                   delivery_info={"consumer_tag": "mock"})
-
-
-class test_QoS(unittest.TestCase):
-
-    class _QoS(QoS):
-        def __init__(self, value):
-            self.value = value
-            QoS.__init__(self, None, value, None)
-
-        def set(self, value):
-            return value
-
-    def test_qos_increment_decrement(self):
-        qos = self._QoS(10)
-        self.assertEqual(qos.increment(), 11)
-        self.assertEqual(qos.increment(3), 14)
-        self.assertEqual(qos.increment(-30), 14)
-        self.assertEqual(qos.decrement(7), 7)
-        self.assertEqual(qos.decrement(), 6)
-        self.assertRaises(AssertionError, qos.decrement, 10)
-
-    def test_qos_disabled_increment_decrement(self):
-        qos = self._QoS(0)
-        self.assertEqual(qos.increment(), 0)
-        self.assertEqual(qos.increment(3), 0)
-        self.assertEqual(qos.increment(-30), 0)
-        self.assertEqual(qos.decrement(7), 0)
-        self.assertEqual(qos.decrement(), 0)
-        self.assertEqual(qos.decrement(10), 0)
-
-    def test_qos_thread_safe(self):
-        qos = self._QoS(10)
-
-        def add():
-            for i in xrange(1000):
-                qos.increment()
-
-        def sub():
-            for i in xrange(1000):
-                qos.decrement_eventually()
-
-        def threaded(funs):
-            from threading import Thread
-            threads = [Thread(target=fun) for fun in funs]
-            for thread in threads:
-                thread.start()
-            for thread in threads:
-                thread.join()
-
-        threaded([add, add])
-        self.assertEqual(qos.value, 2010)
-
-        qos.value = 1000
-        threaded([add, sub])  # n = 2
-        self.assertEqual(qos.value, 1000)
-
-    def test_exceeds_short(self):
-        qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1,
-                current_app.log.get_default_logger())
-        qos.update()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
-        qos.increment()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
-        qos.increment()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
-        qos.decrement()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
-        qos.decrement()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
-
-    def test_consumer_increment_decrement(self):
-        consumer = Mock()
-        qos = QoS(consumer, 10, current_app.log.get_default_logger())
-        qos.update()
-        self.assertEqual(qos.value, 10)
-        self.assertIn({"prefetch_count": 10}, consumer.qos.call_args)
-        qos.decrement()
-        self.assertEqual(qos.value, 9)
-        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 8)
-        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
-
-        # Does not decrement 0 value
-        qos.value = 0
-        qos.decrement()
-        self.assertEqual(qos.value, 0)
-        qos.increment()
-        self.assertEqual(qos.value, 0)
-
-    def test_consumer_decrement_eventually(self):
-        consumer = Mock()
-        qos = QoS(consumer, 10, current_app.log.get_default_logger())
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 9)
-        qos.value = 0
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 0)
-
-    def test_set(self):
-        consumer = Mock()
-        qos = QoS(consumer, 10, current_app.log.get_default_logger())
-        qos.set(12)
-        self.assertEqual(qos.prev, 12)
-        qos.set(qos.prev)
-
-
-class test_Consumer(unittest.TestCase):
-
-    def setUp(self):
-        self.ready_queue = FastQueue()
-        self.eta_schedule = Timer()
-        self.logger = current_app.log.get_default_logger()
-        self.logger.setLevel(0)
-
-    def tearDown(self):
-        self.eta_schedule.stop()
-
-    def test_info(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                           send_events=False)
-        l.qos = QoS(l.task_consumer, 10, l.logger)
-        info = l.info
-        self.assertEqual(info["prefetch_count"], 10)
-        self.assertFalse(info["broker"])
-
-        l.connection = current_app.broker_connection()
-        info = l.info
-        self.assertTrue(info["broker"])
-
-    def test_start_when_closed(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                            send_events=False)
-        l._state = CLOSE
-        l.start()
-
-    def test_connection(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                           send_events=False)
-
-        l.reset_connection()
-        self.assertIsInstance(l.connection, BrokerConnection)
-
-        l._state = RUN
-        l.event_dispatcher = None
-        l.stop_consumers(close_connection=False)
-        self.assertTrue(l.connection)
-
-        l._state = RUN
-        l.stop_consumers()
-        self.assertIsNone(l.connection)
-        self.assertIsNone(l.task_consumer)
-
-        l.reset_connection()
-        self.assertIsInstance(l.connection, BrokerConnection)
-        l.stop_consumers()
-
-        l.stop()
-        l.close_connection()
-        self.assertIsNone(l.connection)
-        self.assertIsNone(l.task_consumer)
-
-    def test_close_connection(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                           send_events=False)
-        l._state = RUN
-        l.close_connection()
-
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                           send_events=False)
-        eventer = l.event_dispatcher = Mock()
-        eventer.enabled = True
-        heart = l.heart = MockHeart()
-        l._state = RUN
-        l.stop_consumers()
-        self.assertTrue(eventer.close.call_count)
-        self.assertTrue(heart.closed)
-
-    def test_receive_message_unknown(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                           send_events=False)
-        backend = Mock()
-        m = create_message(backend, unknown={"baz": "!!!"})
-        l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
-
-        with catch_warnings(record=True) as log:
-            l.receive_message(m.decode(), m)
-            self.assertTrue(log)
-            self.assertIn("unknown message", log[0].message.args[0])
-
-    @patch("celery.utils.timer2.to_timestamp")
-    def test_receive_message_eta_OverflowError(self, to_timestamp):
-        to_timestamp.side_effect = OverflowError()
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                             send_events=False)
-        m = create_message(Mock(), task=foo_task.name,
-                                   args=("2, 2"),
-                                   kwargs={},
-                                   eta=datetime.now().isoformat())
-        l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
-
-        l.receive_message(m.decode(), m)
-        self.assertTrue(m.acknowledged)
-        self.assertTrue(to_timestamp.call_count)
-
-    def test_receive_message_InvalidTaskError(self):
-        logger = Mock()
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
-                           send_events=False)
-        m = create_message(Mock(), task=foo_task.name,
-                           args=(1, 2), kwargs="foobarbaz", id=1)
-        l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
-
-        l.receive_message(m.decode(), m)
-        self.assertIn("Received invalid task message",
-                      logger.error.call_args[0][0])
-
-    def test_on_decode_error(self):
-        logger = Mock()
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
-                           send_events=False)
-
-        class MockMessage(Mock):
-            content_type = "application/x-msgpack"
-            content_encoding = "binary"
-            body = "foobarbaz"
-
-        message = MockMessage()
-        l.on_decode_error(message, KeyError("foo"))
-        self.assertTrue(message.ack.call_count)
-        self.assertIn("Can't decode message body",
-                      logger.critical.call_args[0][0])
-
-    def test_receieve_message(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                           send_events=False)
-        m = create_message(Mock(), task=foo_task.name,
-                           args=[2, 4, 8], kwargs={})
-
-        l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
-
-        in_bucket = self.ready_queue.get_nowait()
-        self.assertIsInstance(in_bucket, TaskRequest)
-        self.assertEqual(in_bucket.task_name, foo_task.name)
-        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
-        self.assertTrue(self.eta_schedule.empty())
-
-    def test_start_connection_error(self):
-
-        class MockConsumer(MainConsumer):
-            iterations = 0
-
-            def consume_messages(self):
-                if not self.iterations:
-                    self.iterations = 1
-                    raise KeyError("foo")
-                raise SyntaxError("bar")
-
-        l = MockConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                             send_events=False, pool=BasePool())
-        l.connection_errors = (KeyError, )
-        self.assertRaises(SyntaxError, l.start)
-        l.heart.stop()
-        l.priority_timer.stop()
-
-    def test_consume_messages_ignores_socket_timeout(self):
-
-        class Connection(current_app.broker_connection().__class__):
-            obj = None
-
-            def drain_events(self, **kwargs):
-                self.obj.connection = None
-                raise socket.timeout(10)
-
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                            send_events=False)
-        l.connection = Connection()
-        l.task_consumer = Mock()
-        l.connection.obj = l
-        l.qos = QoS(l.task_consumer, 10, l.logger)
-        l.consume_messages()
-
-    def test_consume_messages_when_socket_error(self):
-
-        class Connection(current_app.broker_connection().__class__):
-            obj = None
-
-            def drain_events(self, **kwargs):
-                self.obj.connection = None
-                raise socket.error("foo")
-
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                            send_events=False)
-        l._state = RUN
-        c = l.connection = Connection()
-        l.connection.obj = l
-        l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, 10, l.logger)
-        with self.assertRaises(socket.error):
-            l.consume_messages()
-
-        l._state = CLOSE
-        l.connection = c
-        l.consume_messages()
-
-    def test_consume_messages(self):
-
-        class Connection(current_app.broker_connection().__class__):
-            obj = None
-
-            def drain_events(self, **kwargs):
-                self.obj.connection = None
-
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                             send_events=False)
-        l.connection = Connection()
-        l.connection.obj = l
-        l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, 10, l.logger)
-
-        l.consume_messages()
-        l.consume_messages()
-        self.assertTrue(l.task_consumer.consume.call_count)
-        l.task_consumer.qos.assert_called_with(prefetch_count=10)
-        l.qos.decrement()
-        l.consume_messages()
-        l.task_consumer.qos.assert_called_with(prefetch_count=9)
-
-    def test_maybe_conn_error(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                             send_events=False)
-        l.connection_errors = (KeyError, )
-        l.channel_errors = (SyntaxError, )
-        l.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
-        l.maybe_conn_error(Mock(side_effect=KeyError("foo")))
-        l.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
-        self.assertRaises(IndexError, l.maybe_conn_error,
-                Mock(side_effect=IndexError("foo")))
-
-    def test_apply_eta_task(self):
-        from celery.worker import state
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                             send_events=False)
-        l.qos = QoS(None, 10, l.logger)
-
-        task = object()
-        qos = l.qos.value
-        l.apply_eta_task(task)
-        self.assertIn(task, state.reserved_requests)
-        self.assertEqual(l.qos.value, qos - 1)
-        self.assertIs(self.ready_queue.get_nowait(), task)
-
-    def test_receieve_message_eta_isoformat(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                             send_events=False)
-        m = create_message(Mock(), task=foo_task.name,
-                           eta=datetime.now().isoformat(),
-                           args=[2, 4, 8], kwargs={})
-
-        l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
-        l.event_dispatcher = Mock()
-        l.enabled = False
-        l.receive_message(m.decode(), m)
-        l.eta_schedule.stop()
-
-        items = [entry[2] for entry in self.eta_schedule.queue]
-        found = 0
-        for item in items:
-            if item.args[0].task_name == foo_task.name:
-                found = True
-        self.assertTrue(found)
-        self.assertTrue(l.task_consumer.qos.call_count)
-        l.eta_schedule.stop()
-
-    def test_on_control(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                             send_events=False)
-        l.pidbox_node = Mock()
-        l.reset_pidbox_node = Mock()
-
-        l.on_control("foo", "bar")
-        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
-
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = KeyError("foo")
-        l.on_control("foo", "bar")
-        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
-
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = ValueError("foo")
-        l.on_control("foo", "bar")
-        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
-        l.reset_pidbox_node.assert_called_with()
-
-    def test_revoke(self):
-        ready_queue = FastQueue()
-        l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
-                           send_events=False)
-        backend = Mock()
-        id = uuid()
-        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
-                           kwargs={}, id=id)
-        from celery.worker.state import revoked
-        revoked.add(id)
-
-        l.receive_message(t.decode(), t)
-        self.assertTrue(ready_queue.empty())
-
-    def test_receieve_message_not_registered(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                          send_events=False)
-        backend = Mock()
-        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
-
-        l.event_dispatcher = Mock()
-        self.assertFalse(l.receive_message(m.decode(), m))
-        self.assertRaises(Empty, self.ready_queue.get_nowait)
-        self.assertTrue(self.eta_schedule.empty())
-
-    def test_receieve_message_ack_raises(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                          send_events=False)
-        backend = Mock()
-        m = create_message(backend, args=[2, 4, 8], kwargs={})
-
-        l.event_dispatcher = Mock()
-        l.connection_errors = (socket.error, )
-        l.logger = Mock()
-        m.ack = Mock()
-        m.ack.side_effect = socket.error("foo")
-        with catch_warnings(record=True) as log:
-            self.assertFalse(l.receive_message(m.decode(), m))
-            self.assertTrue(log)
-            self.assertIn("unknown message", log[0].message.args[0])
-        self.assertRaises(Empty, self.ready_queue.get_nowait)
-        self.assertTrue(self.eta_schedule.empty())
-        m.ack.assert_called_with()
-        self.assertTrue(l.logger.critical.call_count)
-
-    def test_receieve_message_eta(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                          send_events=False)
-        l.event_dispatcher = Mock()
-        l.event_dispatcher._outbound_buffer = deque()
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name,
-                           args=[2, 4, 8], kwargs={},
-                           eta=(datetime.now() +
-                               timedelta(days=1)).isoformat())
-
-        l.reset_connection()
-        p = l.app.conf.BROKER_CONNECTION_RETRY
-        l.app.conf.BROKER_CONNECTION_RETRY = False
-        try:
-            l.reset_connection()
-        finally:
-            l.app.conf.BROKER_CONNECTION_RETRY = p
-        l.stop_consumers()
-        l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
-        l.eta_schedule.stop()
-        in_hold = self.eta_schedule.queue[0]
-        self.assertEqual(len(in_hold), 3)
-        eta, priority, entry = in_hold
-        task = entry.args[0]
-        self.assertIsInstance(task, TaskRequest)
-        self.assertEqual(task.task_name, foo_task.name)
-        self.assertEqual(task.execute(), 2 * 4 * 8)
-        self.assertRaises(Empty, self.ready_queue.get_nowait)
-
-    def test_reset_pidbox_node(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                          send_events=False)
-        l.pidbox_node = Mock()
-        chan = l.pidbox_node.channel = Mock()
-        l.connection = Mock()
-        chan.close.side_effect = socket.error("foo")
-        l.connection_errors = (socket.error, )
-        l.reset_pidbox_node()
-        chan.close.assert_called_with()
-
-    def test_reset_pidbox_node_green(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                          send_events=False)
-        l.pool = Mock()
-        l.pool.is_green = True
-        l.reset_pidbox_node()
-        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)
-
-    def test__green_pidbox_node(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
-                          send_events=False)
-        l.pidbox_node = Mock()
-
-        connections = []
-
-        class Connection(object):
-
-            def __init__(self, obj):
-                connections.append(self)
-                self.obj = obj
-                self.closed = False
-
-            def channel(self):
-                return Mock()
-
-            def drain_events(self):
-                self.obj.connection = None
-
-            def close(self):
-                self.closed = True
-
-        l.connection = Mock()
-        l._open_connection = lambda: Connection(obj=l)
-        l._green_pidbox_node()
-
-        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
-        self.assertTrue(l.broadcast_consumer)
-        l.broadcast_consumer.consume.assert_called_with()
-
-        self.assertIsNone(l.connection)
-        self.assertTrue(connections[0].closed)
-
-    def test_start__consume_messages(self):
-
-        class _QoS(object):
-            prev = 3
-            value = 4
-
-            def update(self):
-                self.prev = self.value
-
-        class _Consumer(MyKombuConsumer):
-            iterations = 0
-
-            def reset_connection(self):
-                if self.iterations >= 1:
-                    raise KeyError("foo")
-
-        init_callback = Mock()
-        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
-                      send_events=False, init_callback=init_callback)
-        l.task_consumer = Mock()
-        l.broadcast_consumer = Mock()
-        l.qos = _QoS()
-        l.connection = BrokerConnection()
-        l.iterations = 0
-
-        def raises_KeyError(limit=None):
-            l.iterations += 1
-            if l.qos.prev != l.qos.value:
-                l.qos.update()
-            if l.iterations >= 2:
-                raise KeyError("foo")
-
-        l.consume_messages = raises_KeyError
-        self.assertRaises(KeyError, l.start)
-        self.assertTrue(init_callback.call_count)
-        self.assertEqual(l.iterations, 1)
-        self.assertEqual(l.qos.prev, l.qos.value)
-
-        init_callback.reset_mock()
-        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
-                      send_events=False, init_callback=init_callback)
-        l.qos = _QoS()
-        l.task_consumer = Mock()
-        l.broadcast_consumer = Mock()
-        l.connection = BrokerConnection()
-        l.consume_messages = Mock(side_effect=socket.error("foo"))
-        self.assertRaises(socket.error, l.start)
-        self.assertTrue(init_callback.call_count)
-        self.assertTrue(l.consume_messages.call_count)
-
-    def test_reset_connection_with_no_node(self):
-
-        l = MainConsumer(self.ready_queue, self.eta_schedule, self.logger)
-        self.assertEqual(None, l.pool)
-        l.reset_connection()
-
-
-class test_WorkController(AppCase):
-
-    def setup(self):
-        self.worker = self.create_worker()
-
-    def create_worker(self, **kw):
-        worker = WorkController(concurrency=1, loglevel=0, **kw)
-        worker.logger = Mock()
-        return worker
-
-    @patch("celery.platforms.signals")
-    @patch("celery.platforms.set_mp_process_title")
-    def test_process_initializer(self, set_mp_process_title, _signals):
-        from celery import Celery
-        from celery import signals
-        from celery.app import _tls
-        from celery.worker import process_initializer
-        from celery.worker import WORKER_SIGRESET, WORKER_SIGIGNORE
-
-        def on_worker_process_init(**kwargs):
-            on_worker_process_init.called = True
-        on_worker_process_init.called = False
-        signals.worker_process_init.connect(on_worker_process_init)
-
-        app = Celery(loader=Mock(), set_as_current=False)
-        process_initializer(app, "awesome.worker.com")
-        self.assertIn((tuple(WORKER_SIGIGNORE), {}),
-                      _signals.ignore.call_args_list)
-        self.assertIn((tuple(WORKER_SIGRESET), {}),
-                      _signals.reset.call_args_list)
-        self.assertTrue(app.loader.init_worker.call_count)
-        self.assertTrue(on_worker_process_init.called)
-        self.assertIs(_tls.current_app, app)
-        set_mp_process_title.assert_called_with("celeryd",
-                        hostname="awesome.worker.com")
-
-    def test_with_rate_limits_disabled(self):
-        worker = WorkController(concurrency=1, loglevel=0,
-                                disable_rate_limits=True)
-        self.assertTrue(hasattr(worker.ready_queue, "put"))
-
-    def test_attrs(self):
-        worker = self.worker
-        self.assertIsInstance(worker.scheduler, Timer)
-        self.assertTrue(worker.scheduler)
-        self.assertTrue(worker.pool)
-        self.assertTrue(worker.consumer)
-        self.assertTrue(worker.mediator)
-        self.assertTrue(worker.components)
-
-    def test_with_embedded_celerybeat(self):
-        worker = WorkController(concurrency=1, loglevel=0,
-                                embed_clockservice=True)
-        self.assertTrue(worker.beat)
-        self.assertIn(worker.beat, worker.components)
-
-    def test_with_autoscaler(self):
-        worker = self.create_worker(autoscale=[10, 3], send_events=False,
-                                eta_scheduler_cls="celery.utils.timer2.Timer")
-        self.assertTrue(worker.autoscaler)
-
-    def test_dont_stop_or_terminate(self):
-        worker = WorkController(concurrency=1, loglevel=0)
-        worker.stop()
-        self.assertNotEqual(worker._state, worker.CLOSE)
-        worker.terminate()
-        self.assertNotEqual(worker._state, worker.CLOSE)
-
-        sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
-        try:
-            worker._state = worker.RUN
-            worker.stop(in_sighandler=True)
-            self.assertNotEqual(worker._state, worker.CLOSE)
-            worker.terminate(in_sighandler=True)
-            self.assertNotEqual(worker._state, worker.CLOSE)
-        finally:
-            worker.pool.signal_safe = sigsafe
-
-    def test_on_timer_error(self):
-        worker = WorkController(concurrency=1, loglevel=0)
-        worker.logger = Mock()
-
-        try:
-            raise KeyError("foo")
-        except KeyError:
-            exc_info = sys.exc_info()
-
-        worker.on_timer_error(exc_info)
-        msg, args = worker.logger.error.call_args[0]
-        self.assertIn("KeyError", msg % args)
-
-    def test_on_timer_tick(self):
-        worker = WorkController(concurrency=1, loglevel=10)
-        worker.logger = Mock()
-        worker.timer_debug = worker.logger.debug
-
-        worker.on_timer_tick(30.0)
-        logged = worker.logger.debug.call_args[0][0]
-        self.assertIn("30.0", logged)
-
-    def test_process_task(self):
-        worker = self.worker
-        worker.pool = Mock()
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = TaskRequest.from_message(m, m.decode())
-        worker.process_task(task)
-        self.assertEqual(worker.pool.apply_async.call_count, 1)
-        worker.pool.stop()
-
-    def test_process_task_raise_base(self):
-        worker = self.worker
-        worker.pool = Mock()
-        worker.pool.apply_async.side_effect = KeyboardInterrupt("Ctrl+C")
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = TaskRequest.from_message(m, m.decode())
-        worker.components = []
-        worker._state = worker.RUN
-        self.assertRaises(KeyboardInterrupt, worker.process_task, task)
-        self.assertEqual(worker._state, worker.TERMINATE)
-
-    def test_process_task_raise_SystemTerminate(self):
-        worker = self.worker
-        worker.pool = Mock()
-        worker.pool.apply_async.side_effect = SystemTerminate()
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = TaskRequest.from_message(m, m.decode())
-        worker.components = []
-        worker._state = worker.RUN
-        self.assertRaises(SystemExit, worker.process_task, task)
-        self.assertEqual(worker._state, worker.TERMINATE)
-
-    def test_process_task_raise_regular(self):
-        worker = self.worker
-        worker.pool = Mock()
-        worker.pool.apply_async.side_effect = KeyError("some exception")
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = TaskRequest.from_message(m, m.decode())
-        worker.process_task(task)
-        worker.pool.stop()
-
-    def test_start_catches_base_exceptions(self):
-        worker1 = self.create_worker()
-        stc = Mock()
-        stc.start.side_effect = SystemTerminate()
-        worker1.components = [stc]
-        self.assertRaises(SystemExit, worker1.start)
-        self.assertTrue(stc.terminate.call_count)
-
-        worker2 = self.create_worker()
-        sec = Mock()
-        sec.start.side_effect = SystemExit()
-        sec.terminate = None
-        worker2.components = [sec]
-        self.assertRaises(SystemExit, worker2.start)
-        self.assertTrue(sec.stop.call_count)
-
-    def test_state_db(self):
-        from celery.worker import state
-        Persistent = state.Persistent
-
-        state.Persistent = Mock()
-        try:
-            worker = self.create_worker(db="statefilename")
-            self.assertTrue(worker._finalize_db)
-            worker._finalize_db.cancel()
-        finally:
-            state.Persistent = Persistent
-
-    def test_disable_rate_limits(self):
-        from celery.worker.buckets import FastQueue
-        worker = self.create_worker(disable_rate_limits=True)
-        self.assertIsInstance(worker.ready_queue, FastQueue)
-        self.assertIsNone(worker.mediator)
-        self.assertEqual(worker.ready_queue.put, worker.process_task)
-
-    def test_start__stop(self):
-        worker = self.worker
-        worker.components = [Mock(), Mock(), Mock(), Mock()]
-
-        worker.start()
-        for w in worker.components:
-            self.assertTrue(w.start.call_count)
-        worker.stop()
-        for component in worker.components:
-            self.assertTrue(w.stop.call_count)
-
-    def test_start__terminate(self):
-        worker = self.worker
-        worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
-        for component in worker.components[:3]:
-            component.terminate = None
-
-        worker.start()
-        for w in worker.components[:3]:
-            self.assertTrue(w.start.call_count)
-        self.assertTrue(worker._running, len(worker.components))
-        self.assertEqual(worker._state, RUN)
-        worker.terminate()
-        for component in worker.components[:3]:
-            self.assertTrue(component.stop.call_count)
-        self.assertTrue(worker.components[4].terminate.call_count)