소스 검색

100% coverage for celery.worker.*

Ask Solem 13 년 전
부모
커밋
ae06ac0eae

+ 0 - 333
celery/tests/app/__init__.py

@@ -1,333 +0,0 @@
-from __future__ import absolute_import
-from __future__ import with_statement
-
-import os
-
-from mock import Mock, patch
-from pickle import loads, dumps
-
-from celery import Celery
-from celery import app as _app
-from celery.app import defaults
-from celery.app import state
-from celery.loaders.base import BaseLoader
-from celery.platforms import pyimplementation
-from celery.utils.serialization import pickle
-
-from celery.tests import config
-from celery.tests.utils import (Case, mask_modules, platform_pyimp,
-                                sys_platform, pypy_version)
-from celery.utils import uuid
-from celery.utils.mail import ErrorMail
-
-THIS_IS_A_KEY = "this is a value"
-
-
-class Object(object):
-
-    def __init__(self, **kwargs):
-        for key, value in kwargs.items():
-            setattr(self, key, value)
-
-
-def _get_test_config():
-    return dict((key, getattr(config, key))
-                    for key in dir(config)
-                        if key.isupper() and not key.startswith("_"))
-
-test_config = _get_test_config()
-
-
-class test_module(Case):
-
-    def test_default_app(self):
-        self.assertEqual(_app.default_app, state.default_app)
-
-    def test_bugreport(self):
-        self.assertTrue(_app.bugreport())
-
-
-class test_App(Case):
-
-    def setUp(self):
-        self.app = Celery(set_as_current=False)
-        self.app.conf.update(test_config)
-
-    def test_task(self):
-        app = Celery("foozibari", set_as_current=False)
-
-        def fun():
-            pass
-
-        fun.__module__ = "__main__"
-        task = app.task(fun)
-        self.assertEqual(task.name, app.main + ".fun")
-
-    def test_with_broker(self):
-        app = Celery(set_as_current=False, broker="foo://baribaz")
-        self.assertEqual(app.conf.BROKER_HOST, "foo://baribaz")
-
-    def test_repr(self):
-        self.assertTrue(repr(self.app))
-
-    def test_TaskSet(self):
-        ts = self.app.TaskSet()
-        self.assertListEqual(ts.tasks, [])
-        self.assertIs(ts.app, self.app)
-
-    def test_pickle_app(self):
-        changes = dict(THE_FOO_BAR="bars",
-                       THE_MII_MAR="jars")
-        self.app.conf.update(changes)
-        saved = pickle.dumps(self.app)
-        self.assertLess(len(saved), 2048)
-        restored = pickle.loads(saved)
-        self.assertDictContainsSubset(changes, restored.conf)
-
-    def test_worker_main(self):
-        from celery.bin import celeryd
-
-        class WorkerCommand(celeryd.WorkerCommand):
-
-            def execute_from_commandline(self, argv):
-                return argv
-
-        prev, celeryd.WorkerCommand = celeryd.WorkerCommand, WorkerCommand
-        try:
-            ret = self.app.worker_main(argv=["--version"])
-            self.assertListEqual(ret, ["--version"])
-        finally:
-            celeryd.WorkerCommand = prev
-
-    def test_config_from_envvar(self):
-        os.environ["CELERYTEST_CONFIG_OBJECT"] = "celery.tests.test_app"
-        self.app.config_from_envvar("CELERYTEST_CONFIG_OBJECT")
-        self.assertEqual(self.app.conf.THIS_IS_A_KEY, "this is a value")
-
-    def test_config_from_object(self):
-
-        class Object(object):
-            LEAVE_FOR_WORK = True
-            MOMENT_TO_STOP = True
-            CALL_ME_BACK = 123456789
-            WANT_ME_TO = False
-            UNDERSTAND_ME = True
-
-        self.app.config_from_object(Object())
-
-        self.assertTrue(self.app.conf.LEAVE_FOR_WORK)
-        self.assertTrue(self.app.conf.MOMENT_TO_STOP)
-        self.assertEqual(self.app.conf.CALL_ME_BACK, 123456789)
-        self.assertFalse(self.app.conf.WANT_ME_TO)
-        self.assertTrue(self.app.conf.UNDERSTAND_ME)
-
-    def test_config_from_cmdline(self):
-        cmdline = [".always_eager=no",
-                   ".result_backend=/dev/null",
-                   '.task_error_whitelist=(list)["a", "b", "c"]',
-                   "celeryd.prefetch_multiplier=368",
-                   ".foobarstring=(string)300",
-                   ".foobarint=(int)300",
-                   '.result_engine_options=(dict){"foo": "bar"}']
-        self.app.config_from_cmdline(cmdline, namespace="celery")
-        self.assertFalse(self.app.conf.CELERY_ALWAYS_EAGER)
-        self.assertEqual(self.app.conf.CELERY_RESULT_BACKEND, "/dev/null")
-        self.assertEqual(self.app.conf.CELERYD_PREFETCH_MULTIPLIER, 368)
-        self.assertListEqual(self.app.conf.CELERY_TASK_ERROR_WHITELIST,
-                             ["a", "b", "c"])
-        self.assertEqual(self.app.conf.CELERY_FOOBARSTRING, "300")
-        self.assertEqual(self.app.conf.CELERY_FOOBARINT, 300)
-        self.assertDictEqual(self.app.conf.CELERY_RESULT_ENGINE_OPTIONS,
-                             {"foo": "bar"})
-
-    def test_compat_setting_CELERY_BACKEND(self):
-
-        self.app.config_from_object(Object(CELERY_BACKEND="set_by_us"))
-        self.assertEqual(self.app.conf.CELERY_RESULT_BACKEND, "set_by_us")
-
-    def test_setting_BROKER_TRANSPORT_OPTIONS(self):
-
-        _args = {'foo': 'bar', 'spam': 'baz'}
-
-        self.app.config_from_object(Object())
-        self.assertEqual(self.app.conf.BROKER_TRANSPORT_OPTIONS, {})
-
-        self.app.config_from_object(Object(BROKER_TRANSPORT_OPTIONS=_args))
-        self.assertEqual(self.app.conf.BROKER_TRANSPORT_OPTIONS, _args)
-
-    def test_Windows_log_color_disabled(self):
-        self.app.IS_WINDOWS = True
-        self.assertFalse(self.app.log.supports_color())
-
-    def test_compat_setting_CARROT_BACKEND(self):
-        self.app.config_from_object(Object(CARROT_BACKEND="set_by_us"))
-        self.assertEqual(self.app.conf.BROKER_TRANSPORT, "set_by_us")
-
-    def test_WorkController(self):
-        x = self.app.Worker()
-        self.assertIs(x.app, self.app)
-
-    def test_AsyncResult(self):
-        x = self.app.AsyncResult("1")
-        self.assertIs(x.app, self.app)
-        r = loads(dumps(x))
-        # not set as current, so ends up as default app after reduce
-        self.assertIs(r.app, state.default_app)
-
-    @patch("celery.bin.celery.CeleryCommand.execute_from_commandline")
-    def test_start(self, execute):
-        self.app.start()
-        self.assertTrue(execute.called)
-
-    def test_mail_admins(self):
-
-        class Loader(BaseLoader):
-
-            def mail_admins(*args, **kwargs):
-                return args, kwargs
-
-        self.app.loader = Loader()
-        self.app.conf.ADMINS = None
-        self.assertFalse(self.app.mail_admins("Subject", "Body"))
-        self.app.conf.ADMINS = [("George Costanza", "george@vandelay.com")]
-        self.assertTrue(self.app.mail_admins("Subject", "Body"))
-
-    def test_amqp_get_broker_info(self):
-        self.assertDictContainsSubset({"hostname": "localhost",
-                                       "userid": "guest",
-                                       "password": "guest",
-                                       "virtual_host": "/"},
-                                      self.app.broker_connection(
-                                          transport="amqplib").info())
-        self.app.conf.BROKER_PORT = 1978
-        self.app.conf.BROKER_VHOST = "foo"
-        self.assertDictContainsSubset({"port": 1978,
-                                       "virtual_host": "foo"},
-                                      self.app.broker_connection(
-                                          transport="amqplib").info())
-        conn = self.app.broker_connection(virtual_host="/value")
-        self.assertDictContainsSubset({"virtual_host": "/value"},
-                                      conn.info())
-
-    def test_BROKER_BACKEND_alias(self):
-        self.assertEqual(self.app.conf.BROKER_BACKEND,
-                         self.app.conf.BROKER_TRANSPORT)
-
-    def test_with_default_connection(self):
-
-        @self.app.with_default_connection
-        def handler(connection=None, foo=None):
-            return connection, foo
-
-        connection, foo = handler(foo=42)
-        self.assertEqual(foo, 42)
-        self.assertTrue(connection)
-
-    def test_after_fork(self):
-        p = self.app._pool = Mock()
-        self.app._after_fork(self.app)
-        p.force_close_all.assert_called_with()
-        self.assertIsNone(self.app._pool)
-        self.app._after_fork(self.app)
-
-    def test_pool_no_multiprocessing(self):
-        with mask_modules("multiprocessing.util"):
-            pool = self.app.pool
-            self.assertIs(pool, self.app._pool)
-
-    def test_bugreport(self):
-        self.assertTrue(self.app.bugreport())
-
-    def test_send_task_sent_event(self):
-
-        class Dispatcher(object):
-            sent = []
-
-            def send(self, type, **fields):
-                self.sent.append((type, fields))
-
-        conn = self.app.broker_connection()
-        chan = conn.channel()
-        try:
-            for e in ("foo_exchange", "moo_exchange", "bar_exchange"):
-                chan.exchange_declare(e, "direct", durable=True)
-                chan.queue_declare(e, durable=True)
-                chan.queue_bind(e, e, e)
-        finally:
-            chan.close()
-        assert conn.transport_cls == "memory"
-
-        entities = conn.declared_entities
-
-        pub = self.app.amqp.TaskPublisher(conn, exchange="foo_exchange")
-        self.assertNotIn(pub._get_exchange("foo_exchange"), entities)
-
-        dispatcher = Dispatcher()
-        self.assertTrue(pub.delay_task("footask", (), {},
-                                       exchange="moo_exchange",
-                                       routing_key="moo_exchange",
-                                       event_dispatcher=dispatcher))
-        self.assertIn(pub._get_exchange("moo_exchange"), entities)
-        self.assertTrue(dispatcher.sent)
-        self.assertEqual(dispatcher.sent[0][0], "task-sent")
-        self.assertTrue(pub.delay_task("footask", (), {},
-                                       event_dispatcher=dispatcher,
-                                       exchange="bar_exchange",
-                                       routing_key="bar_exchange"))
-        self.assertIn(pub._get_exchange("bar_exchange"), entities)
-
-    def test_error_mail_sender(self):
-        x = ErrorMail.subject % {"name": "task_name",
-                                 "id": uuid(),
-                                 "exc": "FOOBARBAZ",
-                                 "hostname": "lana"}
-        self.assertTrue(x)
-
-
-class test_defaults(Case):
-
-    def test_str_to_bool(self):
-        for s in ("false", "no", "0"):
-            self.assertFalse(defaults.str_to_bool(s))
-        for s in ("true", "yes", "1"):
-            self.assertTrue(defaults.str_to_bool(s))
-        with self.assertRaises(TypeError):
-            defaults.str_to_bool("unsure")
-
-
-class test_debugging_utils(Case):
-
-    def test_enable_disable_trace(self):
-        try:
-            _app.enable_trace()
-            self.assertEqual(_app.app_or_default, _app._app_or_default_trace)
-            _app.disable_trace()
-            self.assertEqual(_app.app_or_default, _app._app_or_default)
-        finally:
-            _app.disable_trace()
-
-
-class test_pyimplementation(Case):
-
-    def test_platform_python_implementation(self):
-        with platform_pyimp(lambda: "Xython"):
-            self.assertEqual(pyimplementation(), "Xython")
-
-    def test_platform_jython(self):
-        with platform_pyimp():
-            with sys_platform("java 1.6.51"):
-                self.assertIn("Jython", pyimplementation())
-
-    def test_platform_pypy(self):
-        with platform_pyimp():
-            with sys_platform("darwin"):
-                with pypy_version((1, 4, 3)):
-                    self.assertIn("PyPy", pyimplementation())
-                with pypy_version((1, 4, 3, "a4")):
-                    self.assertIn("PyPy", pyimplementation())
-
-    def test_platform_fallback(self):
-        with platform_pyimp():
-            with sys_platform("darwin"):
-                with pypy_version():
-                    self.assertEqual("CPython", pyimplementation())

+ 333 - 0
celery/tests/app/test_app.py

@@ -0,0 +1,333 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+import os
+
+from mock import Mock, patch
+from pickle import loads, dumps
+
+from celery import Celery
+from celery import app as _app
+from celery.app import defaults
+from celery.app import state
+from celery.loaders.base import BaseLoader
+from celery.platforms import pyimplementation
+from celery.utils.serialization import pickle
+
+from celery.tests import config
+from celery.tests.utils import (Case, mask_modules, platform_pyimp,
+                                sys_platform, pypy_version)
+from celery.utils import uuid
+from celery.utils.mail import ErrorMail
+
+THIS_IS_A_KEY = "this is a value"
+
+
+class Object(object):
+
+    def __init__(self, **kwargs):
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+
+def _get_test_config():
+    return dict((key, getattr(config, key))
+                    for key in dir(config)
+                        if key.isupper() and not key.startswith("_"))
+
+test_config = _get_test_config()
+
+
+class test_module(Case):
+
+    def test_default_app(self):
+        self.assertEqual(_app.default_app, state.default_app)
+
+    def test_bugreport(self):
+        self.assertTrue(_app.bugreport())
+
+
+class test_App(Case):
+
+    def setUp(self):
+        self.app = Celery(set_as_current=False)
+        self.app.conf.update(test_config)
+
+    def test_task(self):
+        app = Celery("foozibari", set_as_current=False)
+
+        def fun():
+            pass
+
+        fun.__module__ = "__main__"
+        task = app.task(fun)
+        self.assertEqual(task.name, app.main + ".fun")
+
+    def test_with_broker(self):
+        app = Celery(set_as_current=False, broker="foo://baribaz")
+        self.assertEqual(app.conf.BROKER_HOST, "foo://baribaz")
+
+    def test_repr(self):
+        self.assertTrue(repr(self.app))
+
+    def test_TaskSet(self):
+        ts = self.app.TaskSet()
+        self.assertListEqual(ts.tasks, [])
+        self.assertIs(ts.app, self.app)
+
+    def test_pickle_app(self):
+        changes = dict(THE_FOO_BAR="bars",
+                       THE_MII_MAR="jars")
+        self.app.conf.update(changes)
+        saved = pickle.dumps(self.app)
+        self.assertLess(len(saved), 2048)
+        restored = pickle.loads(saved)
+        self.assertDictContainsSubset(changes, restored.conf)
+
+    def test_worker_main(self):
+        from celery.bin import celeryd
+
+        class WorkerCommand(celeryd.WorkerCommand):
+
+            def execute_from_commandline(self, argv):
+                return argv
+
+        prev, celeryd.WorkerCommand = celeryd.WorkerCommand, WorkerCommand
+        try:
+            ret = self.app.worker_main(argv=["--version"])
+            self.assertListEqual(ret, ["--version"])
+        finally:
+            celeryd.WorkerCommand = prev
+
+    def test_config_from_envvar(self):
+        os.environ["CELERYTEST_CONFIG_OBJECT"] = "celery.tests.app.test_app"
+        self.app.config_from_envvar("CELERYTEST_CONFIG_OBJECT")
+        self.assertEqual(self.app.conf.THIS_IS_A_KEY, "this is a value")
+
+    def test_config_from_object(self):
+
+        class Object(object):
+            LEAVE_FOR_WORK = True
+            MOMENT_TO_STOP = True
+            CALL_ME_BACK = 123456789
+            WANT_ME_TO = False
+            UNDERSTAND_ME = True
+
+        self.app.config_from_object(Object())
+
+        self.assertTrue(self.app.conf.LEAVE_FOR_WORK)
+        self.assertTrue(self.app.conf.MOMENT_TO_STOP)
+        self.assertEqual(self.app.conf.CALL_ME_BACK, 123456789)
+        self.assertFalse(self.app.conf.WANT_ME_TO)
+        self.assertTrue(self.app.conf.UNDERSTAND_ME)
+
+    def test_config_from_cmdline(self):
+        cmdline = [".always_eager=no",
+                   ".result_backend=/dev/null",
+                   '.task_error_whitelist=(list)["a", "b", "c"]',
+                   "celeryd.prefetch_multiplier=368",
+                   ".foobarstring=(string)300",
+                   ".foobarint=(int)300",
+                   '.result_engine_options=(dict){"foo": "bar"}']
+        self.app.config_from_cmdline(cmdline, namespace="celery")
+        self.assertFalse(self.app.conf.CELERY_ALWAYS_EAGER)
+        self.assertEqual(self.app.conf.CELERY_RESULT_BACKEND, "/dev/null")
+        self.assertEqual(self.app.conf.CELERYD_PREFETCH_MULTIPLIER, 368)
+        self.assertListEqual(self.app.conf.CELERY_TASK_ERROR_WHITELIST,
+                             ["a", "b", "c"])
+        self.assertEqual(self.app.conf.CELERY_FOOBARSTRING, "300")
+        self.assertEqual(self.app.conf.CELERY_FOOBARINT, 300)
+        self.assertDictEqual(self.app.conf.CELERY_RESULT_ENGINE_OPTIONS,
+                             {"foo": "bar"})
+
+    def test_compat_setting_CELERY_BACKEND(self):
+
+        self.app.config_from_object(Object(CELERY_BACKEND="set_by_us"))
+        self.assertEqual(self.app.conf.CELERY_RESULT_BACKEND, "set_by_us")
+
+    def test_setting_BROKER_TRANSPORT_OPTIONS(self):
+
+        _args = {'foo': 'bar', 'spam': 'baz'}
+
+        self.app.config_from_object(Object())
+        self.assertEqual(self.app.conf.BROKER_TRANSPORT_OPTIONS, {})
+
+        self.app.config_from_object(Object(BROKER_TRANSPORT_OPTIONS=_args))
+        self.assertEqual(self.app.conf.BROKER_TRANSPORT_OPTIONS, _args)
+
+    def test_Windows_log_color_disabled(self):
+        self.app.IS_WINDOWS = True
+        self.assertFalse(self.app.log.supports_color())
+
+    def test_compat_setting_CARROT_BACKEND(self):
+        self.app.config_from_object(Object(CARROT_BACKEND="set_by_us"))
+        self.assertEqual(self.app.conf.BROKER_TRANSPORT, "set_by_us")
+
+    def test_WorkController(self):
+        x = self.app.Worker()
+        self.assertIs(x.app, self.app)
+
+    def test_AsyncResult(self):
+        x = self.app.AsyncResult("1")
+        self.assertIs(x.app, self.app)
+        r = loads(dumps(x))
+        # not set as current, so ends up as default app after reduce
+        self.assertIs(r.app, state.default_app)
+
+    @patch("celery.bin.celery.CeleryCommand.execute_from_commandline")
+    def test_start(self, execute):
+        self.app.start()
+        self.assertTrue(execute.called)
+
+    def test_mail_admins(self):
+
+        class Loader(BaseLoader):
+
+            def mail_admins(*args, **kwargs):
+                return args, kwargs
+
+        self.app.loader = Loader()
+        self.app.conf.ADMINS = None
+        self.assertFalse(self.app.mail_admins("Subject", "Body"))
+        self.app.conf.ADMINS = [("George Costanza", "george@vandelay.com")]
+        self.assertTrue(self.app.mail_admins("Subject", "Body"))
+
+    def test_amqp_get_broker_info(self):
+        self.assertDictContainsSubset({"hostname": "localhost",
+                                       "userid": "guest",
+                                       "password": "guest",
+                                       "virtual_host": "/"},
+                                      self.app.broker_connection(
+                                          transport="amqplib").info())
+        self.app.conf.BROKER_PORT = 1978
+        self.app.conf.BROKER_VHOST = "foo"
+        self.assertDictContainsSubset({"port": 1978,
+                                       "virtual_host": "foo"},
+                                      self.app.broker_connection(
+                                          transport="amqplib").info())
+        conn = self.app.broker_connection(virtual_host="/value")
+        self.assertDictContainsSubset({"virtual_host": "/value"},
+                                      conn.info())
+
+    def test_BROKER_BACKEND_alias(self):
+        self.assertEqual(self.app.conf.BROKER_BACKEND,
+                         self.app.conf.BROKER_TRANSPORT)
+
+    def test_with_default_connection(self):
+
+        @self.app.with_default_connection
+        def handler(connection=None, foo=None):
+            return connection, foo
+
+        connection, foo = handler(foo=42)
+        self.assertEqual(foo, 42)
+        self.assertTrue(connection)
+
+    def test_after_fork(self):
+        p = self.app._pool = Mock()
+        self.app._after_fork(self.app)
+        p.force_close_all.assert_called_with()
+        self.assertIsNone(self.app._pool)
+        self.app._after_fork(self.app)
+
+    def test_pool_no_multiprocessing(self):
+        with mask_modules("multiprocessing.util"):
+            pool = self.app.pool
+            self.assertIs(pool, self.app._pool)
+
+    def test_bugreport(self):
+        self.assertTrue(self.app.bugreport())
+
+    def test_send_task_sent_event(self):
+
+        class Dispatcher(object):
+            sent = []
+
+            def send(self, type, **fields):
+                self.sent.append((type, fields))
+
+        conn = self.app.broker_connection()
+        chan = conn.channel()
+        try:
+            for e in ("foo_exchange", "moo_exchange", "bar_exchange"):
+                chan.exchange_declare(e, "direct", durable=True)
+                chan.queue_declare(e, durable=True)
+                chan.queue_bind(e, e, e)
+        finally:
+            chan.close()
+        assert conn.transport_cls == "memory"
+
+        entities = conn.declared_entities
+
+        pub = self.app.amqp.TaskPublisher(conn, exchange="foo_exchange")
+        self.assertNotIn(pub._get_exchange("foo_exchange"), entities)
+
+        dispatcher = Dispatcher()
+        self.assertTrue(pub.delay_task("footask", (), {},
+                                       exchange="moo_exchange",
+                                       routing_key="moo_exchange",
+                                       event_dispatcher=dispatcher))
+        self.assertIn(pub._get_exchange("moo_exchange"), entities)
+        self.assertTrue(dispatcher.sent)
+        self.assertEqual(dispatcher.sent[0][0], "task-sent")
+        self.assertTrue(pub.delay_task("footask", (), {},
+                                       event_dispatcher=dispatcher,
+                                       exchange="bar_exchange",
+                                       routing_key="bar_exchange"))
+        self.assertIn(pub._get_exchange("bar_exchange"), entities)
+
+    def test_error_mail_sender(self):
+        x = ErrorMail.subject % {"name": "task_name",
+                                 "id": uuid(),
+                                 "exc": "FOOBARBAZ",
+                                 "hostname": "lana"}
+        self.assertTrue(x)
+
+
+class test_defaults(Case):
+
+    def test_str_to_bool(self):
+        for s in ("false", "no", "0"):
+            self.assertFalse(defaults.str_to_bool(s))
+        for s in ("true", "yes", "1"):
+            self.assertTrue(defaults.str_to_bool(s))
+        with self.assertRaises(TypeError):
+            defaults.str_to_bool("unsure")
+
+
+class test_debugging_utils(Case):
+
+    def test_enable_disable_trace(self):
+        try:
+            _app.enable_trace()
+            self.assertEqual(_app.app_or_default, _app._app_or_default_trace)
+            _app.disable_trace()
+            self.assertEqual(_app.app_or_default, _app._app_or_default)
+        finally:
+            _app.disable_trace()
+
+
+class test_pyimplementation(Case):
+
+    def test_platform_python_implementation(self):
+        with platform_pyimp(lambda: "Xython"):
+            self.assertEqual(pyimplementation(), "Xython")
+
+    def test_platform_jython(self):
+        with platform_pyimp():
+            with sys_platform("java 1.6.51"):
+                self.assertIn("Jython", pyimplementation())
+
+    def test_platform_pypy(self):
+        with platform_pyimp():
+            with sys_platform("darwin"):
+                with pypy_version((1, 4, 3)):
+                    self.assertIn("PyPy", pyimplementation())
+                with pypy_version((1, 4, 3, "a4")):
+                    self.assertIn("PyPy", pyimplementation())
+
+    def test_platform_fallback(self):
+        with platform_pyimp():
+            with sys_platform("darwin"):
+                with pypy_version():
+                    self.assertEqual("CPython", pyimplementation())

+ 0 - 40
celery/tests/backends/__init__.py

@@ -1,40 +0,0 @@
-from __future__ import absolute_import
-from __future__ import with_statement
-
-from celery import current_app
-from celery import backends
-from celery.backends.amqp import AMQPBackend
-from celery.backends.cache import CacheBackend
-from celery.tests.utils import Case
-
-
-class test_backends(Case):
-
-    def test_get_backend_aliases(self):
-        expects = [("amqp", AMQPBackend),
-                   ("cache", CacheBackend)]
-        for expect_name, expect_cls in expects:
-            self.assertIsInstance(backends.get_backend_cls(expect_name)(),
-                                  expect_cls)
-
-    def test_get_backend_cache(self):
-        backends.get_backend_cls.clear()
-        hits = backends.get_backend_cls.hits
-        misses = backends.get_backend_cls.misses
-        self.assertTrue(backends.get_backend_cls("amqp"))
-        self.assertEqual(backends.get_backend_cls.misses, misses + 1)
-        self.assertTrue(backends.get_backend_cls("amqp"))
-        self.assertEqual(backends.get_backend_cls.hits, hits + 1)
-
-    def test_unknown_backend(self):
-        with self.assertRaises(ValueError):
-            backends.get_backend_cls("fasodaopjeqijwqe")
-
-    def test_default_backend(self):
-        self.assertEqual(backends.default_backend, current_app.backend)
-
-    def test_backend_by_url(self, url="redis://localhost/1"):
-        from celery.backends.redis import RedisBackend
-        backend, url_ = backends.get_backend_by_url(url)
-        self.assertIs(backend, RedisBackend)
-        self.assertEqual(url_, url)

+ 40 - 0
celery/tests/backends/test_backends.py

@@ -0,0 +1,40 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+from celery import current_app
+from celery import backends
+from celery.backends.amqp import AMQPBackend
+from celery.backends.cache import CacheBackend
+from celery.tests.utils import Case
+
+
+class test_backends(Case):
+
+    def test_get_backend_aliases(self):
+        expects = [("amqp", AMQPBackend),
+                   ("cache", CacheBackend)]
+        for expect_name, expect_cls in expects:
+            self.assertIsInstance(backends.get_backend_cls(expect_name)(),
+                                  expect_cls)
+
+    def test_get_backend_cache(self):
+        backends.get_backend_cls.clear()
+        hits = backends.get_backend_cls.hits
+        misses = backends.get_backend_cls.misses
+        self.assertTrue(backends.get_backend_cls("amqp"))
+        self.assertEqual(backends.get_backend_cls.misses, misses + 1)
+        self.assertTrue(backends.get_backend_cls("amqp"))
+        self.assertEqual(backends.get_backend_cls.hits, hits + 1)
+
+    def test_unknown_backend(self):
+        with self.assertRaises(ValueError):
+            backends.get_backend_cls("fasodaopjeqijwqe")
+
+    def test_default_backend(self):
+        self.assertEqual(backends.default_backend, current_app.backend)
+
+    def test_backend_by_url(self, url="redis://localhost/1"):
+        from celery.backends.redis import RedisBackend
+        backend, url_ = backends.get_backend_by_url(url)
+        self.assertIs(backend, RedisBackend)
+        self.assertEqual(url_, url)

+ 0 - 91
celery/tests/bin/__init__.py

@@ -1,91 +0,0 @@
-from __future__ import absolute_import
-from __future__ import with_statement
-
-import os
-
-from celery.bin.base import Command
-from celery.tests.utils import AppCase, override_stdouts
-
-
-class Object(object):
-    pass
-
-
-class MyApp(object):
-    pass
-
-APP = MyApp()  # <-- Used by test_with_custom_app
-
-
-class MockCommand(Command):
-    mock_args = ("arg1", "arg2", "arg3")
-
-    def parse_options(self, prog_name, arguments):
-        options = Object()
-        options.foo = "bar"
-        options.prog_name = prog_name
-        return options, self.mock_args
-
-    def run(self, *args, **kwargs):
-        return args, kwargs
-
-
-class test_Command(AppCase):
-
-    def test_get_options(self):
-        cmd = Command()
-        cmd.option_list = (1, 2, 3)
-        self.assertTupleEqual(cmd.get_options(), (1, 2, 3))
-
-    def test_run_interface(self):
-        with self.assertRaises(NotImplementedError):
-            Command().run()
-
-    def test_execute_from_commandline(self):
-        cmd = MockCommand()
-        args1, kwargs1 = cmd.execute_from_commandline()     # sys.argv
-        self.assertTupleEqual(args1, cmd.mock_args)
-        self.assertDictContainsSubset({"foo": "bar"}, kwargs1)
-        self.assertTrue(kwargs1.get("prog_name"))
-        args2, kwargs2 = cmd.execute_from_commandline(["foo"])   # pass list
-        self.assertTupleEqual(args2, cmd.mock_args)
-        self.assertDictContainsSubset({"foo": "bar", "prog_name": "foo"},
-                                      kwargs2)
-
-    def test_with_bogus_args(self):
-        cmd = MockCommand()
-        cmd.supports_args = False
-        with override_stdouts() as (_, stderr):
-            with self.assertRaises(SystemExit):
-                cmd.execute_from_commandline(argv=["--bogus"])
-        self.assertTrue(stderr.getvalue())
-        self.assertIn("Unrecognized", stderr.getvalue())
-
-    def test_with_custom_config_module(self):
-        prev = os.environ.pop("CELERY_CONFIG_MODULE", None)
-        try:
-            cmd = MockCommand()
-            cmd.setup_app_from_commandline(["--config=foo.bar.baz"])
-            self.assertEqual(os.environ.get("CELERY_CONFIG_MODULE"),
-                             "foo.bar.baz")
-        finally:
-            if prev:
-                os.environ["CELERY_CONFIG_MODULE"] = prev
-
-    def test_with_custom_app(self):
-        cmd = MockCommand()
-        app = ".".join([__name__, "APP"])
-        cmd.setup_app_from_commandline(["--app=%s" % (app, ),
-                                        "--loglevel=INFO"])
-        self.assertIs(cmd.app, APP)
-
-    def test_with_cmdline_config(self):
-        cmd = MockCommand()
-        cmd.enable_config_from_cmdline = True
-        cmd.namespace = "celeryd"
-        rest = cmd.setup_app_from_commandline(argv=[
-            "--loglevel=INFO", "--", "broker.host=broker.example.com",
-            ".prefetch_multiplier=100"])
-        self.assertEqual(cmd.app.conf.BROKER_HOST, "broker.example.com")
-        self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
-        self.assertListEqual(rest, ["--loglevel=INFO"])

+ 91 - 0
celery/tests/bin/test_base.py

@@ -0,0 +1,91 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+import os
+
+from celery.bin.base import Command
+from celery.tests.utils import AppCase, override_stdouts
+
+
+class Object(object):
+    pass
+
+
+class MyApp(object):
+    pass
+
+APP = MyApp()  # <-- Used by test_with_custom_app
+
+
+class MockCommand(Command):
+    mock_args = ("arg1", "arg2", "arg3")
+
+    def parse_options(self, prog_name, arguments):
+        options = Object()
+        options.foo = "bar"
+        options.prog_name = prog_name
+        return options, self.mock_args
+
+    def run(self, *args, **kwargs):
+        return args, kwargs
+
+
+class test_Command(AppCase):
+
+    def test_get_options(self):
+        cmd = Command()
+        cmd.option_list = (1, 2, 3)
+        self.assertTupleEqual(cmd.get_options(), (1, 2, 3))
+
+    def test_run_interface(self):
+        with self.assertRaises(NotImplementedError):
+            Command().run()
+
+    def test_execute_from_commandline(self):
+        cmd = MockCommand()
+        args1, kwargs1 = cmd.execute_from_commandline()     # sys.argv
+        self.assertTupleEqual(args1, cmd.mock_args)
+        self.assertDictContainsSubset({"foo": "bar"}, kwargs1)
+        self.assertTrue(kwargs1.get("prog_name"))
+        args2, kwargs2 = cmd.execute_from_commandline(["foo"])   # pass list
+        self.assertTupleEqual(args2, cmd.mock_args)
+        self.assertDictContainsSubset({"foo": "bar", "prog_name": "foo"},
+                                      kwargs2)
+
+    def test_with_bogus_args(self):
+        cmd = MockCommand()
+        cmd.supports_args = False
+        with override_stdouts() as (_, stderr):
+            with self.assertRaises(SystemExit):
+                cmd.execute_from_commandline(argv=["--bogus"])
+        self.assertTrue(stderr.getvalue())
+        self.assertIn("Unrecognized", stderr.getvalue())
+
+    def test_with_custom_config_module(self):
+        prev = os.environ.pop("CELERY_CONFIG_MODULE", None)
+        try:
+            cmd = MockCommand()
+            cmd.setup_app_from_commandline(["--config=foo.bar.baz"])
+            self.assertEqual(os.environ.get("CELERY_CONFIG_MODULE"),
+                             "foo.bar.baz")
+        finally:
+            if prev:
+                os.environ["CELERY_CONFIG_MODULE"] = prev
+
+    def test_with_custom_app(self):
+        cmd = MockCommand()
+        app = ".".join([__name__, "APP"])
+        cmd.setup_app_from_commandline(["--app=%s" % (app, ),
+                                        "--loglevel=INFO"])
+        self.assertIs(cmd.app, APP)
+
+    def test_with_cmdline_config(self):
+        cmd = MockCommand()
+        cmd.enable_config_from_cmdline = True
+        cmd.namespace = "celeryd"
+        rest = cmd.setup_app_from_commandline(argv=[
+            "--loglevel=INFO", "--", "broker.host=broker.example.com",
+            ".prefetch_multiplier=100"])
+        self.assertEqual(cmd.app.conf.BROKER_HOST, "broker.example.com")
+        self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
+        self.assertListEqual(rest, ["--loglevel=INFO"])

+ 0 - 71
celery/tests/concurrency/__init__.py

@@ -1,71 +0,0 @@
-from __future__ import absolute_import
-from __future__ import with_statement
-
-import os
-
-from itertools import count
-
-from celery.concurrency.base import apply_target, BasePool
-from celery.tests.utils import Case
-
-
-class test_BasePool(Case):
-
-    def test_apply_target(self):
-
-        scratch = {}
-        counter = count(0).next
-
-        def gen_callback(name, retval=None):
-
-            def callback(*args):
-                scratch[name] = (counter(), args)
-                return retval
-
-            return callback
-
-        apply_target(gen_callback("target", 42),
-                     args=(8, 16),
-                     callback=gen_callback("callback"),
-                     accept_callback=gen_callback("accept_callback"))
-
-        self.assertDictContainsSubset({
-                              "target": (1, (8, 16)),
-                              "callback": (2, (42, ))}, scratch)
-        pa1 = scratch["accept_callback"]
-        self.assertEqual(0, pa1[0])
-        self.assertEqual(pa1[1][0], os.getpid())
-        self.assertTrue(pa1[1][1])
-
-        # No accept callback
-        scratch.clear()
-        apply_target(gen_callback("target", 42),
-                     args=(8, 16),
-                     callback=gen_callback("callback"),
-                     accept_callback=None)
-        self.assertDictEqual(scratch,
-                              {"target": (3, (8, 16)),
-                               "callback": (4, (42, ))})
-
-    def test_interface_on_start(self):
-        BasePool(10).on_start()
-
-    def test_interface_on_stop(self):
-        BasePool(10).on_stop()
-
-    def test_interface_on_apply(self):
-        BasePool(10).on_apply()
-
-    def test_interface_info(self):
-        self.assertDictEqual(BasePool(10).info, {})
-
-    def test_active(self):
-        p = BasePool(10)
-        self.assertFalse(p.active)
-        p._state = p.RUN
-        self.assertTrue(p.active)
-
-    def test_restart(self):
-        p = BasePool(10)
-        with self.assertRaises(NotImplementedError):
-            p.restart()

+ 71 - 0
celery/tests/concurrency/test_concurrency.py

@@ -0,0 +1,71 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+import os
+
+from itertools import count
+
+from celery.concurrency.base import apply_target, BasePool
+from celery.tests.utils import Case
+
+
+class test_BasePool(Case):
+
+    def test_apply_target(self):
+
+        scratch = {}
+        counter = count(0).next
+
+        def gen_callback(name, retval=None):
+
+            def callback(*args):
+                scratch[name] = (counter(), args)
+                return retval
+
+            return callback
+
+        apply_target(gen_callback("target", 42),
+                     args=(8, 16),
+                     callback=gen_callback("callback"),
+                     accept_callback=gen_callback("accept_callback"))
+
+        self.assertDictContainsSubset({
+                              "target": (1, (8, 16)),
+                              "callback": (2, (42, ))}, scratch)
+        pa1 = scratch["accept_callback"]
+        self.assertEqual(0, pa1[0])
+        self.assertEqual(pa1[1][0], os.getpid())
+        self.assertTrue(pa1[1][1])
+
+        # No accept callback
+        scratch.clear()
+        apply_target(gen_callback("target", 42),
+                     args=(8, 16),
+                     callback=gen_callback("callback"),
+                     accept_callback=None)
+        self.assertDictEqual(scratch,
+                              {"target": (3, (8, 16)),
+                               "callback": (4, (42, ))})
+
+    def test_interface_on_start(self):
+        BasePool(10).on_start()
+
+    def test_interface_on_stop(self):
+        BasePool(10).on_stop()
+
+    def test_interface_on_apply(self):
+        BasePool(10).on_apply()
+
+    def test_interface_info(self):
+        self.assertDictEqual(BasePool(10).info, {})
+
+    def test_active(self):
+        p = BasePool(10)
+        self.assertFalse(p.active)
+        p._state = p.RUN
+        self.assertTrue(p.active)
+
+    def test_restart(self):
+        p = BasePool(10)
+        with self.assertRaises(NotImplementedError):
+            p.restart()

+ 0 - 191
celery/tests/events/__init__.py

@@ -1,191 +0,0 @@
-from __future__ import absolute_import
-from __future__ import with_statement
-
-import socket
-
-from celery import events
-from celery.app import app_or_default
-from celery.tests.utils import Case
-
-
-class MockProducer(object):
-    raise_on_publish = False
-
-    def __init__(self, *args, **kwargs):
-        self.sent = []
-
-    def publish(self, msg, *args, **kwargs):
-        if self.raise_on_publish:
-            raise KeyError()
-        self.sent.append(msg)
-
-    def close(self):
-        pass
-
-    def has_event(self, kind):
-        for event in self.sent:
-            if event["type"] == kind:
-                return event
-        return False
-
-
-class test_Event(Case):
-
-    def test_constructor(self):
-        event = events.Event("world war II")
-        self.assertEqual(event["type"], "world war II")
-        self.assertTrue(event["timestamp"])
-
-
-class test_EventDispatcher(Case):
-
-    def setUp(self):
-        self.app = app_or_default()
-
-    def test_send(self):
-        producer = MockProducer()
-        eventer = self.app.events.Dispatcher(object(), enabled=False)
-        eventer.publisher = producer
-        eventer.enabled = True
-        eventer.send("World War II", ended=True)
-        self.assertTrue(producer.has_event("World War II"))
-        eventer.enabled = False
-        eventer.send("World War III")
-        self.assertFalse(producer.has_event("World War III"))
-
-        evs = ("Event 1", "Event 2", "Event 3")
-        eventer.enabled = True
-        eventer.publisher.raise_on_publish = True
-        eventer.buffer_while_offline = False
-        with self.assertRaises(KeyError):
-            eventer.send("Event X")
-        eventer.buffer_while_offline = True
-        for ev in evs:
-            eventer.send(ev)
-        eventer.publisher.raise_on_publish = False
-        eventer.flush()
-        for ev in evs:
-            self.assertTrue(producer.has_event(ev))
-
-    def test_enabled_disable(self):
-        connection = self.app.broker_connection()
-        channel = connection.channel()
-        try:
-            dispatcher = self.app.events.Dispatcher(connection,
-                                                    enabled=True)
-            dispatcher2 = self.app.events.Dispatcher(connection,
-                                                     enabled=True,
-                                                      channel=channel)
-            self.assertTrue(dispatcher.enabled)
-            self.assertTrue(dispatcher.publisher.channel)
-            self.assertEqual(dispatcher.publisher.serializer,
-                            self.app.conf.CELERY_EVENT_SERIALIZER)
-
-            created_channel = dispatcher.publisher.channel
-            dispatcher.disable()
-            dispatcher.disable()  # Disable with no active publisher
-            dispatcher2.disable()
-            self.assertFalse(dispatcher.enabled)
-            self.assertIsNone(dispatcher.publisher)
-            self.assertTrue(created_channel.closed)
-            self.assertFalse(dispatcher2.channel.closed,
-                             "does not close manually provided channel")
-
-            dispatcher.enable()
-            self.assertTrue(dispatcher.enabled)
-            self.assertTrue(dispatcher.publisher)
-        finally:
-            channel.close()
-            connection.close()
-
-
-class test_EventReceiver(Case):
-
-    def setUp(self):
-        self.app = app_or_default()
-
-    def test_process(self):
-
-        message = {"type": "world-war"}
-
-        got_event = [False]
-
-        def my_handler(event):
-            got_event[0] = True
-
-        r = events.EventReceiver(object(),
-                                 handlers={"world-war": my_handler},
-                                 node_id="celery.tests",
-                                 )
-        r._receive(message, object())
-        self.assertTrue(got_event[0])
-
-    def test_catch_all_event(self):
-
-        message = {"type": "world-war"}
-
-        got_event = [False]
-
-        def my_handler(event):
-            got_event[0] = True
-
-        r = events.EventReceiver(object(), node_id="celery.tests")
-        events.EventReceiver.handlers["*"] = my_handler
-        try:
-            r._receive(message, object())
-            self.assertTrue(got_event[0])
-        finally:
-            events.EventReceiver.handlers = {}
-
-    def test_itercapture(self):
-        connection = self.app.broker_connection()
-        try:
-            r = self.app.events.Receiver(connection, node_id="celery.tests")
-            it = r.itercapture(timeout=0.0001, wakeup=False)
-            consumer = it.next()
-            self.assertTrue(consumer.queues)
-            self.assertEqual(consumer.callbacks[0], r._receive)
-
-            with self.assertRaises(socket.timeout):
-                it.next()
-
-            with self.assertRaises(socket.timeout):
-                r.capture(timeout=0.00001)
-        finally:
-            connection.close()
-
-    def test_itercapture_limit(self):
-        connection = self.app.broker_connection()
-        channel = connection.channel()
-        try:
-            events_received = [0]
-
-            def handler(event):
-                events_received[0] += 1
-
-            producer = self.app.events.Dispatcher(connection,
-                                                  enabled=True,
-                                                  channel=channel)
-            r = self.app.events.Receiver(connection,
-                                         handlers={"*": handler},
-                                         node_id="celery.tests")
-            evs = ["ev1", "ev2", "ev3", "ev4", "ev5"]
-            for ev in evs:
-                producer.send(ev)
-            it = r.itercapture(limit=4, wakeup=True)
-            it.next()  # skip consumer (see itercapture)
-            list(it)
-            self.assertEqual(events_received[0], 4)
-        finally:
-            channel.close()
-            connection.close()
-
-
-class test_misc(Case):
-
-    def setUp(self):
-        self.app = app_or_default()
-
-    def test_State(self):
-        state = self.app.events.State()
-        self.assertDictEqual(dict(state.workers), {})

+ 191 - 0
celery/tests/events/test_events.py

@@ -0,0 +1,191 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+import socket
+
+from celery import events
+from celery.app import app_or_default
+from celery.tests.utils import Case
+
+
+class MockProducer(object):
+    raise_on_publish = False
+
+    def __init__(self, *args, **kwargs):
+        self.sent = []
+
+    def publish(self, msg, *args, **kwargs):
+        if self.raise_on_publish:
+            raise KeyError()
+        self.sent.append(msg)
+
+    def close(self):
+        pass
+
+    def has_event(self, kind):
+        for event in self.sent:
+            if event["type"] == kind:
+                return event
+        return False
+
+
+class test_Event(Case):
+
+    def test_constructor(self):
+        event = events.Event("world war II")
+        self.assertEqual(event["type"], "world war II")
+        self.assertTrue(event["timestamp"])
+
+
+class test_EventDispatcher(Case):
+
+    def setUp(self):
+        self.app = app_or_default()
+
+    def test_send(self):
+        producer = MockProducer()
+        eventer = self.app.events.Dispatcher(object(), enabled=False)
+        eventer.publisher = producer
+        eventer.enabled = True
+        eventer.send("World War II", ended=True)
+        self.assertTrue(producer.has_event("World War II"))
+        eventer.enabled = False
+        eventer.send("World War III")
+        self.assertFalse(producer.has_event("World War III"))
+
+        evs = ("Event 1", "Event 2", "Event 3")
+        eventer.enabled = True
+        eventer.publisher.raise_on_publish = True
+        eventer.buffer_while_offline = False
+        with self.assertRaises(KeyError):
+            eventer.send("Event X")
+        eventer.buffer_while_offline = True
+        for ev in evs:
+            eventer.send(ev)
+        eventer.publisher.raise_on_publish = False
+        eventer.flush()
+        for ev in evs:
+            self.assertTrue(producer.has_event(ev))
+
+    def test_enabled_disable(self):
+        connection = self.app.broker_connection()
+        channel = connection.channel()
+        try:
+            dispatcher = self.app.events.Dispatcher(connection,
+                                                    enabled=True)
+            dispatcher2 = self.app.events.Dispatcher(connection,
+                                                     enabled=True,
+                                                      channel=channel)
+            self.assertTrue(dispatcher.enabled)
+            self.assertTrue(dispatcher.publisher.channel)
+            self.assertEqual(dispatcher.publisher.serializer,
+                            self.app.conf.CELERY_EVENT_SERIALIZER)
+
+            created_channel = dispatcher.publisher.channel
+            dispatcher.disable()
+            dispatcher.disable()  # Disable with no active publisher
+            dispatcher2.disable()
+            self.assertFalse(dispatcher.enabled)
+            self.assertIsNone(dispatcher.publisher)
+            self.assertTrue(created_channel.closed)
+            self.assertFalse(dispatcher2.channel.closed,
+                             "does not close manually provided channel")
+
+            dispatcher.enable()
+            self.assertTrue(dispatcher.enabled)
+            self.assertTrue(dispatcher.publisher)
+        finally:
+            channel.close()
+            connection.close()
+
+
+class test_EventReceiver(Case):
+
+    def setUp(self):
+        self.app = app_or_default()
+
+    def test_process(self):
+
+        message = {"type": "world-war"}
+
+        got_event = [False]
+
+        def my_handler(event):
+            got_event[0] = True
+
+        r = events.EventReceiver(object(),
+                                 handlers={"world-war": my_handler},
+                                 node_id="celery.tests",
+                                 )
+        r._receive(message, object())
+        self.assertTrue(got_event[0])
+
+    def test_catch_all_event(self):
+
+        message = {"type": "world-war"}
+
+        got_event = [False]
+
+        def my_handler(event):
+            got_event[0] = True
+
+        r = events.EventReceiver(object(), node_id="celery.tests")
+        events.EventReceiver.handlers["*"] = my_handler
+        try:
+            r._receive(message, object())
+            self.assertTrue(got_event[0])
+        finally:
+            events.EventReceiver.handlers = {}
+
+    def test_itercapture(self):
+        connection = self.app.broker_connection()
+        try:
+            r = self.app.events.Receiver(connection, node_id="celery.tests")
+            it = r.itercapture(timeout=0.0001, wakeup=False)
+            consumer = it.next()
+            self.assertTrue(consumer.queues)
+            self.assertEqual(consumer.callbacks[0], r._receive)
+
+            with self.assertRaises(socket.timeout):
+                it.next()
+
+            with self.assertRaises(socket.timeout):
+                r.capture(timeout=0.00001)
+        finally:
+            connection.close()
+
+    def test_itercapture_limit(self):
+        connection = self.app.broker_connection()
+        channel = connection.channel()
+        try:
+            events_received = [0]
+
+            def handler(event):
+                events_received[0] += 1
+
+            producer = self.app.events.Dispatcher(connection,
+                                                  enabled=True,
+                                                  channel=channel)
+            r = self.app.events.Receiver(connection,
+                                         handlers={"*": handler},
+                                         node_id="celery.tests")
+            evs = ["ev1", "ev2", "ev3", "ev4", "ev5"]
+            for ev in evs:
+                producer.send(ev)
+            it = r.itercapture(limit=4, wakeup=True)
+            it.next()  # skip consumer (see itercapture)
+            list(it)
+            self.assertEqual(events_received[0], 4)
+        finally:
+            channel.close()
+            connection.close()
+
+
+class test_misc(Case):
+
+    def setUp(self):
+        self.app = app_or_default()
+
+    def test_State(self):
+        state = self.app.events.State()
+        self.assertDictEqual(dict(state.workers), {})

+ 0 - 96
celery/tests/security/__init__.py

@@ -1,33 +1,4 @@
-"""
-Keys and certificates for tests (KEY1 is a private key of CERT1, etc.)
-
-Generated with::
-
-    $ openssl genrsa -des3 -passout pass:test -out key1.key 1024
-    $ openssl req -new -key key1.key -out key1.csr -passin pass:test
-    $ cp key1.key key1.key.org
-    $ openssl rsa -in key1.key.org -out key1.key -passin pass:test
-    $ openssl x509 -req -days 365 -in cert1.csr \
-              -signkey key1.key -out cert1.crt
-    $ rm key1.key.org cert1.csr
-
-"""
 from __future__ import absolute_import
-from __future__ import with_statement
-
-import __builtin__
-
-from mock import Mock, patch
-
-from celery import current_app
-from celery.exceptions import ImproperlyConfigured
-from celery.security import setup_security, disable_untrusted_serializers
-from kombu.serialization import registry
-
-from .case import SecurityCase
-
-from celery.tests.utils import mock_open
-
 
 KEY1 = """-----BEGIN RSA PRIVATE KEY-----
 MIICXgIBAAKBgQDCsmLC+eqL4z6bhtv0nzbcnNXuQrZUoh827jGfDI3kxNZ2LbEy
@@ -88,70 +59,3 @@ AAOBgQBzaZ5vBkzksPhnWb2oobuy6Ne/LMEtdQ//qeVY4sKl2tOJUCSdWRen9fqP
 e+zYdEdkFCd8rp568Eiwkq/553uy4rlE927/AEqs/+KGYmAtibk/9vmi+/+iZXyS
 WWZybzzDZFncq1/N1C3Y/hrCBNDFO4TsnTLAhWtZ4c0vDAiacw==
 -----END CERTIFICATE-----"""
-
-
-class test_security(SecurityCase):
-
-    def tearDown(self):
-        registry._disabled_content_types.clear()
-
-    def test_disable_untrusted_serializers(self):
-        disabled = registry._disabled_content_types
-        self.assertEqual(0, len(disabled))
-
-        disable_untrusted_serializers(
-                ['application/json', 'application/x-python-serialize'])
-        self.assertIn('application/x-yaml', disabled)
-        self.assertNotIn('application/json', disabled)
-        self.assertNotIn('application/x-python-serialize', disabled)
-        disabled.clear()
-
-        disable_untrusted_serializers()
-        self.assertIn('application/x-yaml', disabled)
-        self.assertIn('application/json', disabled)
-        self.assertIn('application/x-python-serialize', disabled)
-
-    def test_setup_security(self):
-        disabled = registry._disabled_content_types
-        self.assertEqual(0, len(disabled))
-
-        current_app.conf.CELERY_TASK_SERIALIZER = 'json'
-
-        setup_security()
-        self.assertIn('application/x-python-serialize', disabled)
-        disabled.clear()
-
-    @patch("celery.security.register_auth")
-    @patch("celery.security.disable_untrusted_serializers")
-    def test_setup_registry_complete(self, dis, reg, key="KEY", cert="CERT"):
-        calls = [0]
-
-        def effect(*args):
-            try:
-                m = Mock()
-                m.read.return_value = "B" if calls[0] else "A"
-                return m
-            finally:
-                calls[0] += 1
-
-        with mock_open(side_effect=effect):
-            store = Mock()
-            setup_security(["json"], key, cert, store)
-            dis.assert_called_with(["json"])
-            reg.assert_called_with("A", "B", store)
-
-    def test_security_conf(self):
-        current_app.conf.CELERY_TASK_SERIALIZER = 'auth'
-
-        self.assertRaises(ImproperlyConfigured, setup_security)
-
-        _import = __builtin__.__import__
-
-        def import_hook(name, *args, **kwargs):
-            if name == 'OpenSSL':
-                raise ImportError
-            return _import(name, *args, **kwargs)
-
-        __builtin__.__import__ = import_hook
-        self.assertRaises(ImproperlyConfigured, setup_security)
-        __builtin__.__import__ = _import

+ 96 - 0
celery/tests/security/test_security.py

@@ -0,0 +1,96 @@
+"""
+Keys and certificates for tests (KEY1 is a private key of CERT1, etc.)
+
+Generated with::
+
+    $ openssl genrsa -des3 -passout pass:test -out key1.key 1024
+    $ openssl req -new -key key1.key -out key1.csr -passin pass:test
+    $ cp key1.key key1.key.org
+    $ openssl rsa -in key1.key.org -out key1.key -passin pass:test
+    $ openssl x509 -req -days 365 -in cert1.csr \
+              -signkey key1.key -out cert1.crt
+    $ rm key1.key.org cert1.csr
+
+"""
+from __future__ import absolute_import
+from __future__ import with_statement
+
+import __builtin__
+
+from mock import Mock, patch
+
+from celery import current_app
+from celery.exceptions import ImproperlyConfigured
+from celery.security import setup_security, disable_untrusted_serializers
+from kombu.serialization import registry
+
+from .case import SecurityCase
+
+from celery.tests.utils import mock_open
+
+
+class test_security(SecurityCase):
+
+    def tearDown(self):
+        registry._disabled_content_types.clear()
+
+    def test_disable_untrusted_serializers(self):
+        disabled = registry._disabled_content_types
+        self.assertEqual(0, len(disabled))
+
+        disable_untrusted_serializers(
+                ['application/json', 'application/x-python-serialize'])
+        self.assertIn('application/x-yaml', disabled)
+        self.assertNotIn('application/json', disabled)
+        self.assertNotIn('application/x-python-serialize', disabled)
+        disabled.clear()
+
+        disable_untrusted_serializers()
+        self.assertIn('application/x-yaml', disabled)
+        self.assertIn('application/json', disabled)
+        self.assertIn('application/x-python-serialize', disabled)
+
+    def test_setup_security(self):
+        disabled = registry._disabled_content_types
+        self.assertEqual(0, len(disabled))
+
+        current_app.conf.CELERY_TASK_SERIALIZER = 'json'
+
+        setup_security()
+        self.assertIn('application/x-python-serialize', disabled)
+        disabled.clear()
+
+    @patch("celery.security.register_auth")
+    @patch("celery.security.disable_untrusted_serializers")
+    def test_setup_registry_complete(self, dis, reg, key="KEY", cert="CERT"):
+        calls = [0]
+
+        def effect(*args):
+            try:
+                m = Mock()
+                m.read.return_value = "B" if calls[0] else "A"
+                return m
+            finally:
+                calls[0] += 1
+
+        with mock_open(side_effect=effect):
+            store = Mock()
+            setup_security(["json"], key, cert, store)
+            dis.assert_called_with(["json"])
+            reg.assert_called_with("A", "B", store)
+
+    def test_security_conf(self):
+        current_app.conf.CELERY_TASK_SERIALIZER = 'auth'
+
+        self.assertRaises(ImproperlyConfigured, setup_security)
+
+        _import = __builtin__.__import__
+
+        def import_hook(name, *args, **kwargs):
+            if name == 'OpenSSL':
+                raise ImportError
+            return _import(name, *args, **kwargs)
+
+        __builtin__.__import__ = import_hook
+        self.assertRaises(ImproperlyConfigured, setup_security)
+        __builtin__.__import__ = _import

+ 39 - 1
celery/tests/slow/test_buckets.py

@@ -6,6 +6,9 @@ import time
 
 from functools import partial
 from itertools import chain, izip
+from Queue import Empty
+
+from mock import Mock, patch
 
 from celery.app.registry import TaskRegistry
 from celery.task.base import Task
@@ -13,7 +16,7 @@ from celery.utils import timeutils
 from celery.utils import uuid
 from celery.worker import buckets
 
-from celery.tests.utils import Case, skip_if_environ
+from celery.tests.utils import Case, skip_if_environ, mock_context
 
 skip_if_disabled = partial(skip_if_environ("SKIP_RLIMITS"))
 
@@ -140,6 +143,41 @@ class test_TaskBucket(Case):
         with self.assertRaises(buckets.Empty):
             x.get_nowait()
 
+    @patch("celery.worker.buckets.sleep")
+    def test_get_block(self, sleep):
+        x = buckets.TaskBucket(task_registry=self.registry)
+        x.not_empty = Mock()
+        get = x._get = Mock()
+        calls = [0]
+        remaining = [0]
+
+        def effect():
+            try:
+                if not calls[0]:
+                    raise Empty()
+                rem = remaining[0]
+                remaining[0] = 0
+                return rem, Mock()
+            finally:
+                calls[0] += 1
+        get.side_effect = effect
+
+        with mock_context(Mock()) as context:
+            x.not_empty = context
+            x.wait = Mock()
+            x.get(block=True)
+
+            calls[0] = 0
+            remaining[0] = 1
+            x.get(block=True)
+
+    def test_get_raises_rate(self):
+        x = buckets.TaskBucket(task_registry=self.registry)
+        x.buckets = {1: Mock()}
+        x.buckets[1].get_nowait.side_effect = buckets.RateLimitExceeded()
+        x.buckets[1].expected_time.return_value = 0
+        x._get()
+
     @skip_if_disabled
     def test_refresh(self):
         reg = {}

+ 0 - 865
celery/tests/tasks/__init__.py

@@ -1,865 +0,0 @@
-from __future__ import absolute_import
-from __future__ import with_statement
-
-from datetime import datetime, timedelta
-from functools import wraps
-
-from celery import task
-from celery.task import current
-from celery.app import app_or_default
-from celery.task import task as task_dec
-from celery.exceptions import RetryTaskError
-from celery.execute import send_task
-from celery.result import EagerResult
-from celery.schedules import crontab, crontab_parser, ParseException
-from celery.utils import uuid
-from celery.utils.timeutils import parse_iso8601, timedelta_seconds
-
-from celery.tests.utils import Case, with_eager_tasks, WhateverIO
-
-
-def return_True(*args, **kwargs):
-    # Task run functions can't be closures/lambdas, as they're pickled.
-    return True
-
-
-return_True_task = task_dec()(return_True)
-
-
-def raise_exception(self, **kwargs):
-    raise Exception("%s error" % self.__class__)
-
-
-class MockApplyTask(task.Task):
-    applied = 0
-
-    def run(self, x, y):
-        return x * y
-
-    @classmethod
-    def apply_async(self, *args, **kwargs):
-        self.applied += 1
-
-
-@task.task(name="c.unittest.increment_counter_task", count=0)
-def increment_counter(increment_by=1):
-    increment_counter.count += increment_by or 1
-    return increment_counter.count
-
-
-@task.task(name="c.unittest.raising_task")
-def raising():
-    raise KeyError("foo")
-
-
-@task.task(max_retries=3, iterations=0)
-def retry_task(arg1, arg2, kwarg=1, max_retries=None, care=True):
-    current.iterations += 1
-    rmax = current.max_retries if max_retries is None else max_retries
-
-    retries = current.request.retries
-    if care and retries >= rmax:
-        return arg1
-    else:
-        return current.retry(countdown=0, max_retries=rmax)
-
-
-@task.task(max_retries=3, iterations=0)
-def retry_task_noargs(**kwargs):
-    current.iterations += 1
-
-    retries = kwargs["task_retries"]
-    if retries >= 3:
-        return 42
-    else:
-        return current.retry(countdown=0)
-
-
-@task.task(max_retries=3, iterations=0, base=MockApplyTask)
-def retry_task_mockapply(arg1, arg2, kwarg=1, **kwargs):
-    current.iterations += 1
-
-    retries = kwargs["task_retries"]
-    if retries >= 3:
-        return arg1
-    else:
-        kwargs.update(kwarg=kwarg)
-    return current.retry(countdown=0)
-
-
-class MyCustomException(Exception):
-    """Random custom exception."""
-
-
-@task.task(max_retries=3, iterations=0, accept_magic_kwargs=True)
-def retry_task_customexc(arg1, arg2, kwarg=1, **kwargs):
-    current.iterations += 1
-
-    retries = kwargs["task_retries"]
-    if retries >= 3:
-        return arg1 + kwarg
-    else:
-        try:
-            raise MyCustomException("Elaine Marie Benes")
-        except MyCustomException, exc:
-            kwargs.update(kwarg=kwarg)
-            return current.retry(countdown=0, exc=exc)
-
-
-class test_task_retries(Case):
-
-    def test_retry(self):
-        retry_task.__class__.max_retries = 3
-        retry_task.iterations = 0
-        result = retry_task.apply([0xFF, 0xFFFF])
-        self.assertEqual(result.get(), 0xFF)
-        self.assertEqual(retry_task.iterations, 4)
-
-        retry_task.__class__.max_retries = 3
-        retry_task.iterations = 0
-        result = retry_task.apply([0xFF, 0xFFFF], {"max_retries": 10})
-        self.assertEqual(result.get(), 0xFF)
-        self.assertEqual(retry_task.iterations, 11)
-
-    def test_retry_no_args(self):
-        retry_task_noargs.__class__.max_retries = 3
-        retry_task_noargs.iterations = 0
-        result = retry_task_noargs.apply()
-        self.assertEqual(result.get(), 42)
-        self.assertEqual(retry_task_noargs.iterations, 4)
-
-    def test_retry_kwargs_can_be_empty(self):
-        with self.assertRaises(RetryTaskError):
-            retry_task_mockapply.retry(args=[4, 4], kwargs=None)
-
-    def test_retry_not_eager(self):
-        retry_task_mockapply.request.called_directly = False
-        exc = Exception("baz")
-        try:
-            retry_task_mockapply.retry(args=[4, 4], kwargs={"task_retries": 0},
-                                       exc=exc, throw=False)
-            self.assertTrue(retry_task_mockapply.__class__.applied)
-        finally:
-            retry_task_mockapply.__class__.applied = 0
-
-        try:
-            with self.assertRaises(RetryTaskError):
-                retry_task_mockapply.retry(
-                    args=[4, 4], kwargs={"task_retries": 0},
-                    exc=exc, throw=True)
-            self.assertTrue(retry_task_mockapply.__class__.applied)
-        finally:
-            retry_task_mockapply.__class__.applied = 0
-
-    def test_retry_with_kwargs(self):
-        retry_task_customexc.__class__.max_retries = 3
-        retry_task_customexc.iterations = 0
-        result = retry_task_customexc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
-        self.assertEqual(result.get(), 0xFF + 0xF)
-        self.assertEqual(retry_task_customexc.iterations, 4)
-
-    def test_retry_with_custom_exception(self):
-        retry_task_customexc.__class__.max_retries = 2
-        retry_task_customexc.iterations = 0
-        result = retry_task_customexc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
-        with self.assertRaises(MyCustomException):
-            result.get()
-        self.assertEqual(retry_task_customexc.iterations, 3)
-
-    def test_max_retries_exceeded(self):
-        retry_task.__class__.max_retries = 2
-        retry_task.iterations = 0
-        result = retry_task.apply([0xFF, 0xFFFF], {"care": False})
-        with self.assertRaises(retry_task.MaxRetriesExceededError):
-            result.get()
-        self.assertEqual(retry_task.iterations, 3)
-
-        retry_task.__class__.max_retries = 1
-        retry_task.iterations = 0
-        result = retry_task.apply([0xFF, 0xFFFF], {"care": False})
-        with self.assertRaises(retry_task.MaxRetriesExceededError):
-            result.get()
-        self.assertEqual(retry_task.iterations, 2)
-
-
-class test_tasks(Case):
-
-    def test_unpickle_task(self):
-        import pickle
-
-        @task_dec
-        def xxx():
-            pass
-
-        self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])
-
-    def createTask(self, name):
-        return task.task(__module__=self.__module__, name=name)(return_True)
-
-    def test_AsyncResult(self):
-        task_id = uuid()
-        result = retry_task.AsyncResult(task_id)
-        self.assertEqual(result.backend, retry_task.backend)
-        self.assertEqual(result.id, task_id)
-
-    def assertNextTaskDataEqual(self, consumer, presult, task_name,
-            test_eta=False, test_expires=False, **kwargs):
-        next_task = consumer.fetch()
-        task_data = next_task.decode()
-        self.assertEqual(task_data["id"], presult.id)
-        self.assertEqual(task_data["task"], task_name)
-        task_kwargs = task_data.get("kwargs", {})
-        if test_eta:
-            self.assertIsInstance(task_data.get("eta"), basestring)
-            to_datetime = parse_iso8601(task_data.get("eta"))
-            self.assertIsInstance(to_datetime, datetime)
-        if test_expires:
-            self.assertIsInstance(task_data.get("expires"), basestring)
-            to_datetime = parse_iso8601(task_data.get("expires"))
-            self.assertIsInstance(to_datetime, datetime)
-        for arg_name, arg_value in kwargs.items():
-            self.assertEqual(task_kwargs.get(arg_name), arg_value)
-
-    def test_incomplete_task_cls(self):
-
-        class IncompleteTask(task.Task):
-            name = "c.unittest.t.itask"
-
-        with self.assertRaises(NotImplementedError):
-            IncompleteTask().run()
-
-    def test_task_kwargs_must_be_dictionary(self):
-        with self.assertRaises(ValueError):
-            increment_counter.apply_async([], "str")
-
-    def test_task_args_must_be_list(self):
-        with self.assertRaises(ValueError):
-            increment_counter.apply_async("str", {})
-
-    def test_regular_task(self):
-        T1 = self.createTask("c.unittest.t.t1")
-        self.assertIsInstance(T1, task.BaseTask)
-        self.assertTrue(T1.run())
-        self.assertTrue(callable(T1),
-                "Task class is callable()")
-        self.assertTrue(T1(),
-                "Task class runs run() when called")
-
-        consumer = T1.get_consumer()
-        with self.assertRaises(NotImplementedError):
-            consumer.receive("foo", "foo")
-        consumer.discard_all()
-        self.assertIsNone(consumer.fetch())
-
-        # Without arguments.
-        presult = T1.delay()
-        self.assertNextTaskDataEqual(consumer, presult, T1.name)
-
-        # With arguments.
-        presult2 = T1.apply_async(kwargs=dict(name="George Costanza"))
-        self.assertNextTaskDataEqual(consumer, presult2, T1.name,
-                name="George Costanza")
-
-        # send_task
-        sresult = send_task(T1.name, kwargs=dict(name="Elaine M. Benes"))
-        self.assertNextTaskDataEqual(consumer, sresult, T1.name,
-                name="Elaine M. Benes")
-
-        # With eta.
-        presult2 = T1.apply_async(kwargs=dict(name="George Costanza"),
-                            eta=datetime.utcnow() + timedelta(days=1),
-                            expires=datetime.utcnow() + timedelta(days=2))
-        self.assertNextTaskDataEqual(consumer, presult2, T1.name,
-                name="George Costanza", test_eta=True, test_expires=True)
-
-        # With countdown.
-        presult2 = T1.apply_async(kwargs=dict(name="George Costanza"),
-                                  countdown=10, expires=12)
-        self.assertNextTaskDataEqual(consumer, presult2, T1.name,
-                name="George Costanza", test_eta=True, test_expires=True)
-
-        # Discarding all tasks.
-        consumer.discard_all()
-        T1.apply_async()
-        self.assertEqual(consumer.discard_all(), 1)
-        self.assertIsNone(consumer.fetch())
-
-        self.assertFalse(presult.successful())
-        T1.backend.mark_as_done(presult.id, result=None)
-        self.assertTrue(presult.successful())
-
-        publisher = T1.get_publisher()
-        self.assertTrue(publisher.exchange)
-
-    def test_context_get(self):
-        request = self.createTask("c.unittest.t.c.g").request
-        request.foo = 32
-        self.assertEqual(request.get("foo"), 32)
-        self.assertEqual(request.get("bar", 36), 36)
-        request.clear()
-
-    def test_task_class_repr(self):
-        task = self.createTask("c.unittest.t.repr")
-        self.assertIn("class Task of", repr(task.app.Task))
-
-    def test_after_return(self):
-        task = self.createTask("c.unittest.t.after_return")
-        task.request.chord = return_True_task.s()
-        task.after_return("SUCCESS", 1.0, "foobar", (), {}, None)
-        task.request.clear()
-
-    def test_send_task_sent_event(self):
-        T1 = self.createTask("c.unittest.t.t1")
-        app = T1.app
-        conn = app.broker_connection()
-        chan = conn.channel()
-        app.conf.CELERY_SEND_TASK_SENT_EVENT = True
-        dispatcher = [None]
-
-        class Pub(object):
-            channel = chan
-
-            def delay_task(self, *args, **kwargs):
-                dispatcher[0] = kwargs.get("event_dispatcher")
-
-        try:
-            T1.apply_async(publisher=Pub())
-        finally:
-            app.conf.CELERY_SEND_TASK_SENT_EVENT = False
-            chan.close()
-            conn.close()
-
-        self.assertTrue(dispatcher[0])
-
-    def test_get_publisher(self):
-        connection = app_or_default().broker_connection()
-        p = increment_counter.get_publisher(connection, auto_declare=False,
-                                            exchange="foo")
-        self.assertEqual(p.exchange.name, "foo")
-        p = increment_counter.get_publisher(connection, auto_declare=False,
-                                            exchange_type="fanout")
-        self.assertEqual(p.exchange.type, "fanout")
-
-    def test_update_state(self):
-
-        @task_dec
-        def yyy():
-            pass
-
-        tid = uuid()
-        yyy.update_state(tid, "FROBULATING", {"fooz": "baaz"})
-        self.assertEqual(yyy.AsyncResult(tid).status, "FROBULATING")
-        self.assertDictEqual(yyy.AsyncResult(tid).result, {"fooz": "baaz"})
-
-        yyy.request.id = tid
-        yyy.update_state(state="FROBUZATING", meta={"fooz": "baaz"})
-        self.assertEqual(yyy.AsyncResult(tid).status, "FROBUZATING")
-        self.assertDictEqual(yyy.AsyncResult(tid).result, {"fooz": "baaz"})
-
-    def test_repr(self):
-
-        @task_dec
-        def task_test_repr():
-            pass
-
-        self.assertIn("task_test_repr", repr(task_test_repr))
-
-    def test_has___name__(self):
-
-        @task_dec
-        def yyy2():
-            pass
-
-        self.assertTrue(yyy2.__name__)
-
-    def test_get_logger(self):
-        t1 = self.createTask("c.unittest.t.t1")
-        logfh = WhateverIO()
-        logger = t1.get_logger(logfile=logfh, loglevel=0)
-        self.assertTrue(logger)
-
-        t1.request.loglevel = 3
-        logger = t1.get_logger(logfile=logfh, loglevel=None)
-        self.assertTrue(logger)
-
-
-class test_TaskSet(Case):
-
-    @with_eager_tasks
-    def test_function_taskset(self):
-        subtasks = [return_True_task.s(i) for i in range(1, 6)]
-        ts = task.TaskSet(subtasks)
-        res = ts.apply_async()
-        self.assertListEqual(res.join(), [True, True, True, True, True])
-
-    def test_counter_taskset(self):
-        increment_counter.count = 0
-        ts = task.TaskSet(tasks=[
-            increment_counter.s(),
-            increment_counter.s(increment_by=2),
-            increment_counter.s(increment_by=3),
-            increment_counter.s(increment_by=4),
-            increment_counter.s(increment_by=5),
-            increment_counter.s(increment_by=6),
-            increment_counter.s(increment_by=7),
-            increment_counter.s(increment_by=8),
-            increment_counter.s(increment_by=9),
-        ])
-        self.assertEqual(ts.total, 9)
-
-        consumer = increment_counter.get_consumer()
-        consumer.purge()
-        consumer.close()
-        taskset_res = ts.apply_async()
-        subtasks = taskset_res.subtasks
-        taskset_id = taskset_res.taskset_id
-        consumer = increment_counter.get_consumer()
-        for subtask in subtasks:
-            m = consumer.fetch().payload
-            self.assertDictContainsSubset({"taskset": taskset_id,
-                                           "task": increment_counter.name,
-                                           "id": subtask.id}, m)
-            increment_counter(
-                    increment_by=m.get("kwargs", {}).get("increment_by"))
-        self.assertEqual(increment_counter.count, sum(xrange(1, 10)))
-
-    def test_named_taskset(self):
-        prefix = "test_named_taskset-"
-        ts = task.TaskSet([return_True_task.subtask([1])])
-        res = ts.apply(taskset_id=prefix + uuid())
-        self.assertTrue(res.taskset_id.startswith(prefix))
-
-
-class test_apply_task(Case):
-
-    def test_apply_throw(self):
-        with self.assertRaises(KeyError):
-            raising.apply(throw=True)
-
-    def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
-        raising.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
-        try:
-            with self.assertRaises(KeyError):
-                raising.apply()
-        finally:
-            raising.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = False
-
-    def test_apply(self):
-        increment_counter.count = 0
-
-        e = increment_counter.apply()
-        self.assertIsInstance(e, EagerResult)
-        self.assertEqual(e.get(), 1)
-
-        e = increment_counter.apply(args=[1])
-        self.assertEqual(e.get(), 2)
-
-        e = increment_counter.apply(kwargs={"increment_by": 4})
-        self.assertEqual(e.get(), 6)
-
-        self.assertTrue(e.successful())
-        self.assertTrue(e.ready())
-        self.assertTrue(repr(e).startswith("<EagerResult:"))
-
-        f = raising.apply()
-        self.assertTrue(f.ready())
-        self.assertFalse(f.successful())
-        self.assertTrue(f.traceback)
-        with self.assertRaises(KeyError):
-            f.get()
-
-
-@task.periodic_task(run_every=timedelta(hours=1))
-def my_periodic():
-    pass
-
-
-class test_periodic_tasks(Case):
-
-    def test_must_have_run_every(self):
-        with self.assertRaises(NotImplementedError):
-            type("Foo", (task.PeriodicTask, ), {"__module__": __name__})
-
-    def test_remaining_estimate(self):
-        self.assertIsInstance(
-            my_periodic.run_every.remaining_estimate(datetime.utcnow()),
-            timedelta)
-
-    def test_is_due_not_due(self):
-        due, remaining = my_periodic.run_every.is_due(datetime.utcnow())
-        self.assertFalse(due)
-        # This assertion may fail if executed in the
-        # first minute of an hour, thus 59 instead of 60
-        self.assertGreater(remaining, 59)
-
-    def test_is_due(self):
-        p = my_periodic
-        due, remaining = p.run_every.is_due(
-                datetime.utcnow() - p.run_every.run_every)
-        self.assertTrue(due)
-        self.assertEqual(remaining,
-                         timedelta_seconds(p.run_every.run_every))
-
-    def test_schedule_repr(self):
-        p = my_periodic
-        self.assertTrue(repr(p.run_every))
-
-
-@task.periodic_task(run_every=crontab())
-def every_minute():
-    pass
-
-
-@task.periodic_task(run_every=crontab(minute="*/15"))
-def quarterly():
-    pass
-
-
-@task.periodic_task(run_every=crontab(minute=30))
-def hourly():
-    pass
-
-
-@task.periodic_task(run_every=crontab(hour=7, minute=30))
-def daily():
-    pass
-
-
-@task.periodic_task(run_every=crontab(hour=7, minute=30,
-                                      day_of_week="thursday"))
-def weekly():
-    pass
-
-
-def patch_crontab_nowfun(cls, retval):
-
-    def create_patcher(fun):
-
-        @wraps(fun)
-        def __inner(*args, **kwargs):
-            prev_nowfun = cls.run_every.nowfun
-            cls.run_every.nowfun = lambda: retval
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                cls.run_every.nowfun = prev_nowfun
-
-        return __inner
-
-    return create_patcher
-
-
-class test_crontab_parser(Case):
-
-    def test_parse_star(self):
-        self.assertEqual(crontab_parser(24).parse('*'), set(range(24)))
-        self.assertEqual(crontab_parser(60).parse('*'), set(range(60)))
-        self.assertEqual(crontab_parser(7).parse('*'), set(range(7)))
-
-    def test_parse_range(self):
-        self.assertEqual(crontab_parser(60).parse('1-10'),
-                          set(range(1, 10 + 1)))
-        self.assertEqual(crontab_parser(24).parse('0-20'),
-                          set(range(0, 20 + 1)))
-        self.assertEqual(crontab_parser().parse('2-10'),
-                          set(range(2, 10 + 1)))
-
-    def test_parse_groups(self):
-        self.assertEqual(crontab_parser().parse('1,2,3,4'),
-                          set([1, 2, 3, 4]))
-        self.assertEqual(crontab_parser().parse('0,15,30,45'),
-                          set([0, 15, 30, 45]))
-
-    def test_parse_steps(self):
-        self.assertEqual(crontab_parser(8).parse('*/2'),
-                          set([0, 2, 4, 6]))
-        self.assertEqual(crontab_parser().parse('*/2'),
-                          set(i * 2 for i in xrange(30)))
-        self.assertEqual(crontab_parser().parse('*/3'),
-                          set(i * 3 for i in xrange(20)))
-
-    def test_parse_composite(self):
-        self.assertEqual(crontab_parser(8).parse('*/2'), set([0, 2, 4, 6]))
-        self.assertEqual(crontab_parser().parse('2-9/5'), set([2, 7]))
-        self.assertEqual(crontab_parser().parse('2-10/5'), set([2, 7]))
-        self.assertEqual(crontab_parser().parse('2-11/5,3'), set([2, 3, 7]))
-        self.assertEqual(crontab_parser().parse('2-4/3,*/5,0-21/4'),
-                set([0, 2, 4, 5, 8, 10, 12, 15, 16,
-                     20, 25, 30, 35, 40, 45, 50, 55]))
-        self.assertEqual(crontab_parser().parse('1-9/2'),
-                set([1, 3, 5, 7, 9]))
-
-    def test_parse_errors_on_empty_string(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('')
-
-    def test_parse_errors_on_empty_group(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('1,,2')
-
-    def test_parse_errors_on_empty_steps(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('*/')
-
-    def test_parse_errors_on_negative_number(self):
-        with self.assertRaises(ParseException):
-            crontab_parser(60).parse('-20')
-
-    def test_expand_cronspec_eats_iterables(self):
-        self.assertEqual(crontab._expand_cronspec(iter([1, 2, 3]), 100),
-                         set([1, 2, 3]))
-
-    def test_expand_cronspec_invalid_type(self):
-        with self.assertRaises(TypeError):
-            crontab._expand_cronspec(object(), 100)
-
-    def test_repr(self):
-        self.assertIn("*", repr(crontab("*")))
-
-    def test_eq(self):
-        self.assertEqual(crontab(day_of_week="1, 2"),
-                         crontab(day_of_week="1-2"))
-        self.assertEqual(crontab(minute="1", hour="2", day_of_week="5"),
-                         crontab(minute="1", hour="2", day_of_week="5"))
-        self.assertNotEqual(crontab(minute="1"), crontab(minute="2"))
-        self.assertFalse(object() == crontab(minute="1"))
-        self.assertFalse(crontab(minute="1") == object())
-
-
-class test_crontab_remaining_estimate(Case):
-
-    def next_ocurrance(self, crontab, now):
-        crontab.nowfun = lambda: now
-        return now + crontab.remaining_estimate(now)
-
-    def test_next_minute(self):
-        next = self.next_ocurrance(crontab(),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 14, 31))
-
-    def test_not_next_minute(self):
-        next = self.next_ocurrance(crontab(),
-                                   datetime(2010, 9, 11, 14, 59, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 15, 0))
-
-    def test_this_hour(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 14, 42))
-
-    def test_not_this_hour(self):
-        next = self.next_ocurrance(crontab(minute=[5, 10, 15]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 15, 5))
-
-    def test_today(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12, 17]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 11, 17, 5))
-
-    def test_not_today(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12]),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 12, 12, 5))
-
-    def test_weekday(self):
-        next = self.next_ocurrance(crontab(minute=30,
-                                           hour=14,
-                                           day_of_week="sat"),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 18, 14, 30))
-
-    def test_not_weekday(self):
-        next = self.next_ocurrance(crontab(minute=[5, 42],
-                                           day_of_week="mon-fri"),
-                                   datetime(2010, 9, 11, 14, 30, 15))
-        self.assertEqual(next, datetime(2010, 9, 13, 0, 5))
-
-
-class test_crontab_is_due(Case):
-
-    def setUp(self):
-        self.now = datetime.utcnow()
-        self.next_minute = 60 - self.now.second - 1e-6 * self.now.microsecond
-
-    def test_default_crontab_spec(self):
-        c = crontab()
-        self.assertEqual(c.minute, set(range(60)))
-        self.assertEqual(c.hour, set(range(24)))
-        self.assertEqual(c.day_of_week, set(range(7)))
-
-    def test_simple_crontab_spec(self):
-        c = crontab(minute=30)
-        self.assertEqual(c.minute, set([30]))
-        self.assertEqual(c.hour, set(range(24)))
-        self.assertEqual(c.day_of_week, set(range(7)))
-
-    def test_crontab_spec_minute_formats(self):
-        c = crontab(minute=30)
-        self.assertEqual(c.minute, set([30]))
-        c = crontab(minute='30')
-        self.assertEqual(c.minute, set([30]))
-        c = crontab(minute=(30, 40, 50))
-        self.assertEqual(c.minute, set([30, 40, 50]))
-        c = crontab(minute=set([30, 40, 50]))
-        self.assertEqual(c.minute, set([30, 40, 50]))
-
-    def test_crontab_spec_invalid_minute(self):
-        with self.assertRaises(ValueError):
-            crontab(minute=60)
-        with self.assertRaises(ValueError):
-            crontab(minute='0-100')
-
-    def test_crontab_spec_hour_formats(self):
-        c = crontab(hour=6)
-        self.assertEqual(c.hour, set([6]))
-        c = crontab(hour='5')
-        self.assertEqual(c.hour, set([5]))
-        c = crontab(hour=(4, 8, 12))
-        self.assertEqual(c.hour, set([4, 8, 12]))
-
-    def test_crontab_spec_invalid_hour(self):
-        with self.assertRaises(ValueError):
-            crontab(hour=24)
-        with self.assertRaises(ValueError):
-            crontab(hour='0-30')
-
-    def test_crontab_spec_dow_formats(self):
-        c = crontab(day_of_week=5)
-        self.assertEqual(c.day_of_week, set([5]))
-        c = crontab(day_of_week='5')
-        self.assertEqual(c.day_of_week, set([5]))
-        c = crontab(day_of_week='fri')
-        self.assertEqual(c.day_of_week, set([5]))
-        c = crontab(day_of_week='tuesday,sunday,fri')
-        self.assertEqual(c.day_of_week, set([0, 2, 5]))
-        c = crontab(day_of_week='mon-fri')
-        self.assertEqual(c.day_of_week, set([1, 2, 3, 4, 5]))
-        c = crontab(day_of_week='*/2')
-        self.assertEqual(c.day_of_week, set([0, 2, 4, 6]))
-
-    def seconds_almost_equal(self, a, b, precision):
-        for index, skew in enumerate((+0.1, 0, -0.1)):
-            try:
-                self.assertAlmostEqual(a, b + skew, precision)
-            except AssertionError:
-                if index + 1 >= 3:
-                    raise
-            else:
-                break
-
-    def test_crontab_spec_invalid_dow(self):
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='fooday-barday')
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='1,4,foo')
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='7')
-        with self.assertRaises(ValueError):
-            crontab(day_of_week='12')
-
-    def test_every_minute_execution_is_due(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    def test_every_minute_execution_is_not_due(self):
-        last_ran = self.now - timedelta(seconds=self.now.second)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertFalse(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    # 29th of May 2010 is a saturday
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 29, 10, 30))
-    def test_execution_is_due_on_saturday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    # 30th of May 2010 is a sunday
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 30, 10, 30))
-    def test_execution_is_due_on_sunday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    # 31st of May 2010 is a monday
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 31, 10, 30))
-    def test_execution_is_due_on_monday(self):
-        last_ran = self.now - timedelta(seconds=61)
-        due, remaining = every_minute.run_every.is_due(last_ran)
-        self.assertTrue(due)
-        self.seconds_almost_equal(remaining, self.next_minute, 1)
-
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 10, 10, 30))
-    def test_every_hour_execution_is_due(self):
-        due, remaining = hourly.run_every.is_due(
-                datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 60 * 60)
-
-    @patch_crontab_nowfun(hourly, datetime(2010, 5, 10, 10, 29))
-    def test_every_hour_execution_is_not_due(self):
-        due, remaining = hourly.run_every.is_due(
-                datetime(2010, 5, 10, 9, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 15))
-    def test_first_quarter_execution_is_due(self):
-        due, remaining = quarterly.run_every.is_due(
-                            datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 15 * 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 30))
-    def test_second_quarter_execution_is_due(self):
-        due, remaining = quarterly.run_every.is_due(
-                            datetime(2010, 5, 10, 6, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 15 * 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 14))
-    def test_first_quarter_execution_is_not_due(self):
-        due, remaining = quarterly.run_every.is_due(
-                            datetime(2010, 5, 10, 10, 0))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 60)
-
-    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 29))
-    def test_second_quarter_execution_is_not_due(self):
-        due, remaining = quarterly.run_every.is_due(
-                            datetime(2010, 5, 10, 10, 15))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 60)
-
-    @patch_crontab_nowfun(daily, datetime(2010, 5, 10, 7, 30))
-    def test_daily_execution_is_due(self):
-        due, remaining = daily.run_every.is_due(
-                datetime(2010, 5, 9, 7, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 24 * 60 * 60)
-
-    @patch_crontab_nowfun(daily, datetime(2010, 5, 10, 10, 30))
-    def test_daily_execution_is_not_due(self):
-        due, remaining = daily.run_every.is_due(
-                datetime(2010, 5, 10, 7, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 21 * 60 * 60)
-
-    @patch_crontab_nowfun(weekly, datetime(2010, 5, 6, 7, 30))
-    def test_weekly_execution_is_due(self):
-        due, remaining = weekly.run_every.is_due(
-                datetime(2010, 4, 30, 7, 30))
-        self.assertTrue(due)
-        self.assertEqual(remaining, 7 * 24 * 60 * 60)
-
-    @patch_crontab_nowfun(weekly, datetime(2010, 5, 7, 10, 30))
-    def test_weekly_execution_is_not_due(self):
-        due, remaining = weekly.run_every.is_due(
-                datetime(2010, 5, 6, 7, 30))
-        self.assertFalse(due)
-        self.assertEqual(remaining, 6 * 24 * 60 * 60 - 3 * 60 * 60)

+ 865 - 0
celery/tests/tasks/test_tasks.py

@@ -0,0 +1,865 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+from datetime import datetime, timedelta
+from functools import wraps
+
+from celery import task
+from celery.task import current
+from celery.app import app_or_default
+from celery.task import task as task_dec
+from celery.exceptions import RetryTaskError
+from celery.execute import send_task
+from celery.result import EagerResult
+from celery.schedules import crontab, crontab_parser, ParseException
+from celery.utils import uuid
+from celery.utils.timeutils import parse_iso8601, timedelta_seconds
+
+from celery.tests.utils import Case, with_eager_tasks, WhateverIO
+
+
+def return_True(*args, **kwargs):
+    # Task run functions can't be closures/lambdas, as they're pickled.
+    return True
+
+
+return_True_task = task_dec()(return_True)
+
+
+def raise_exception(self, **kwargs):
+    raise Exception("%s error" % self.__class__)
+
+
+class MockApplyTask(task.Task):
+    applied = 0
+
+    def run(self, x, y):
+        return x * y
+
+    @classmethod
+    def apply_async(self, *args, **kwargs):
+        self.applied += 1
+
+
+@task.task(name="c.unittest.increment_counter_task", count=0)
+def increment_counter(increment_by=1):
+    increment_counter.count += increment_by or 1
+    return increment_counter.count
+
+
+@task.task(name="c.unittest.raising_task")
+def raising():
+    raise KeyError("foo")
+
+
+@task.task(max_retries=3, iterations=0)
+def retry_task(arg1, arg2, kwarg=1, max_retries=None, care=True):
+    current.iterations += 1
+    rmax = current.max_retries if max_retries is None else max_retries
+
+    retries = current.request.retries
+    if care and retries >= rmax:
+        return arg1
+    else:
+        return current.retry(countdown=0, max_retries=rmax)
+
+
+@task.task(max_retries=3, iterations=0)
+def retry_task_noargs(**kwargs):
+    current.iterations += 1
+
+    retries = kwargs["task_retries"]
+    if retries >= 3:
+        return 42
+    else:
+        return current.retry(countdown=0)
+
+
+@task.task(max_retries=3, iterations=0, base=MockApplyTask)
+def retry_task_mockapply(arg1, arg2, kwarg=1, **kwargs):
+    current.iterations += 1
+
+    retries = kwargs["task_retries"]
+    if retries >= 3:
+        return arg1
+    else:
+        kwargs.update(kwarg=kwarg)
+    return current.retry(countdown=0)
+
+
+class MyCustomException(Exception):
+    """Random custom exception."""
+
+
+@task.task(max_retries=3, iterations=0, accept_magic_kwargs=True)
+def retry_task_customexc(arg1, arg2, kwarg=1, **kwargs):
+    current.iterations += 1
+
+    retries = kwargs["task_retries"]
+    if retries >= 3:
+        return arg1 + kwarg
+    else:
+        try:
+            raise MyCustomException("Elaine Marie Benes")
+        except MyCustomException, exc:
+            kwargs.update(kwarg=kwarg)
+            return current.retry(countdown=0, exc=exc)
+
+
+class test_task_retries(Case):
+
+    def test_retry(self):
+        retry_task.__class__.max_retries = 3
+        retry_task.iterations = 0
+        result = retry_task.apply([0xFF, 0xFFFF])
+        self.assertEqual(result.get(), 0xFF)
+        self.assertEqual(retry_task.iterations, 4)
+
+        retry_task.__class__.max_retries = 3
+        retry_task.iterations = 0
+        result = retry_task.apply([0xFF, 0xFFFF], {"max_retries": 10})
+        self.assertEqual(result.get(), 0xFF)
+        self.assertEqual(retry_task.iterations, 11)
+
+    def test_retry_no_args(self):
+        retry_task_noargs.__class__.max_retries = 3
+        retry_task_noargs.iterations = 0
+        result = retry_task_noargs.apply()
+        self.assertEqual(result.get(), 42)
+        self.assertEqual(retry_task_noargs.iterations, 4)
+
+    def test_retry_kwargs_can_be_empty(self):
+        with self.assertRaises(RetryTaskError):
+            retry_task_mockapply.retry(args=[4, 4], kwargs=None)
+
+    def test_retry_not_eager(self):
+        retry_task_mockapply.request.called_directly = False
+        exc = Exception("baz")
+        try:
+            retry_task_mockapply.retry(args=[4, 4], kwargs={"task_retries": 0},
+                                       exc=exc, throw=False)
+            self.assertTrue(retry_task_mockapply.__class__.applied)
+        finally:
+            retry_task_mockapply.__class__.applied = 0
+
+        try:
+            with self.assertRaises(RetryTaskError):
+                retry_task_mockapply.retry(
+                    args=[4, 4], kwargs={"task_retries": 0},
+                    exc=exc, throw=True)
+            self.assertTrue(retry_task_mockapply.__class__.applied)
+        finally:
+            retry_task_mockapply.__class__.applied = 0
+
+    def test_retry_with_kwargs(self):
+        retry_task_customexc.__class__.max_retries = 3
+        retry_task_customexc.iterations = 0
+        result = retry_task_customexc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
+        self.assertEqual(result.get(), 0xFF + 0xF)
+        self.assertEqual(retry_task_customexc.iterations, 4)
+
+    def test_retry_with_custom_exception(self):
+        retry_task_customexc.__class__.max_retries = 2
+        retry_task_customexc.iterations = 0
+        result = retry_task_customexc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
+        with self.assertRaises(MyCustomException):
+            result.get()
+        self.assertEqual(retry_task_customexc.iterations, 3)
+
+    def test_max_retries_exceeded(self):
+        retry_task.__class__.max_retries = 2
+        retry_task.iterations = 0
+        result = retry_task.apply([0xFF, 0xFFFF], {"care": False})
+        with self.assertRaises(retry_task.MaxRetriesExceededError):
+            result.get()
+        self.assertEqual(retry_task.iterations, 3)
+
+        retry_task.__class__.max_retries = 1
+        retry_task.iterations = 0
+        result = retry_task.apply([0xFF, 0xFFFF], {"care": False})
+        with self.assertRaises(retry_task.MaxRetriesExceededError):
+            result.get()
+        self.assertEqual(retry_task.iterations, 2)
+
+
+class test_tasks(Case):
+
+    def test_unpickle_task(self):
+        import pickle
+
+        @task_dec
+        def xxx():
+            pass
+
+        self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])
+
+    def createTask(self, name):
+        return task.task(__module__=self.__module__, name=name)(return_True)
+
+    def test_AsyncResult(self):
+        task_id = uuid()
+        result = retry_task.AsyncResult(task_id)
+        self.assertEqual(result.backend, retry_task.backend)
+        self.assertEqual(result.id, task_id)
+
+    def assertNextTaskDataEqual(self, consumer, presult, task_name,
+            test_eta=False, test_expires=False, **kwargs):
+        next_task = consumer.fetch()
+        task_data = next_task.decode()
+        self.assertEqual(task_data["id"], presult.id)
+        self.assertEqual(task_data["task"], task_name)
+        task_kwargs = task_data.get("kwargs", {})
+        if test_eta:
+            self.assertIsInstance(task_data.get("eta"), basestring)
+            to_datetime = parse_iso8601(task_data.get("eta"))
+            self.assertIsInstance(to_datetime, datetime)
+        if test_expires:
+            self.assertIsInstance(task_data.get("expires"), basestring)
+            to_datetime = parse_iso8601(task_data.get("expires"))
+            self.assertIsInstance(to_datetime, datetime)
+        for arg_name, arg_value in kwargs.items():
+            self.assertEqual(task_kwargs.get(arg_name), arg_value)
+
+    def test_incomplete_task_cls(self):
+
+        class IncompleteTask(task.Task):
+            name = "c.unittest.t.itask"
+
+        with self.assertRaises(NotImplementedError):
+            IncompleteTask().run()
+
+    def test_task_kwargs_must_be_dictionary(self):
+        with self.assertRaises(ValueError):
+            increment_counter.apply_async([], "str")
+
+    def test_task_args_must_be_list(self):
+        with self.assertRaises(ValueError):
+            increment_counter.apply_async("str", {})
+
+    def test_regular_task(self):
+        T1 = self.createTask("c.unittest.t.t1")
+        self.assertIsInstance(T1, task.BaseTask)
+        self.assertTrue(T1.run())
+        self.assertTrue(callable(T1),
+                "Task class is callable()")
+        self.assertTrue(T1(),
+                "Task class runs run() when called")
+
+        consumer = T1.get_consumer()
+        with self.assertRaises(NotImplementedError):
+            consumer.receive("foo", "foo")
+        consumer.discard_all()
+        self.assertIsNone(consumer.fetch())
+
+        # Without arguments.
+        presult = T1.delay()
+        self.assertNextTaskDataEqual(consumer, presult, T1.name)
+
+        # With arguments.
+        presult2 = T1.apply_async(kwargs=dict(name="George Costanza"))
+        self.assertNextTaskDataEqual(consumer, presult2, T1.name,
+                name="George Costanza")
+
+        # send_task
+        sresult = send_task(T1.name, kwargs=dict(name="Elaine M. Benes"))
+        self.assertNextTaskDataEqual(consumer, sresult, T1.name,
+                name="Elaine M. Benes")
+
+        # With eta.
+        presult2 = T1.apply_async(kwargs=dict(name="George Costanza"),
+                            eta=datetime.utcnow() + timedelta(days=1),
+                            expires=datetime.utcnow() + timedelta(days=2))
+        self.assertNextTaskDataEqual(consumer, presult2, T1.name,
+                name="George Costanza", test_eta=True, test_expires=True)
+
+        # With countdown.
+        presult2 = T1.apply_async(kwargs=dict(name="George Costanza"),
+                                  countdown=10, expires=12)
+        self.assertNextTaskDataEqual(consumer, presult2, T1.name,
+                name="George Costanza", test_eta=True, test_expires=True)
+
+        # Discarding all tasks.
+        consumer.discard_all()
+        T1.apply_async()
+        self.assertEqual(consumer.discard_all(), 1)
+        self.assertIsNone(consumer.fetch())
+
+        self.assertFalse(presult.successful())
+        T1.backend.mark_as_done(presult.id, result=None)
+        self.assertTrue(presult.successful())
+
+        publisher = T1.get_publisher()
+        self.assertTrue(publisher.exchange)
+
+    def test_context_get(self):
+        request = self.createTask("c.unittest.t.c.g").request
+        request.foo = 32
+        self.assertEqual(request.get("foo"), 32)
+        self.assertEqual(request.get("bar", 36), 36)
+        request.clear()
+
+    def test_task_class_repr(self):
+        task = self.createTask("c.unittest.t.repr")
+        self.assertIn("class Task of", repr(task.app.Task))
+
+    def test_after_return(self):
+        task = self.createTask("c.unittest.t.after_return")
+        task.request.chord = return_True_task.s()
+        task.after_return("SUCCESS", 1.0, "foobar", (), {}, None)
+        task.request.clear()
+
+    def test_send_task_sent_event(self):
+        T1 = self.createTask("c.unittest.t.t1")
+        app = T1.app
+        conn = app.broker_connection()
+        chan = conn.channel()
+        app.conf.CELERY_SEND_TASK_SENT_EVENT = True
+        dispatcher = [None]
+
+        class Pub(object):
+            channel = chan
+
+            def delay_task(self, *args, **kwargs):
+                dispatcher[0] = kwargs.get("event_dispatcher")
+
+        try:
+            T1.apply_async(publisher=Pub())
+        finally:
+            app.conf.CELERY_SEND_TASK_SENT_EVENT = False
+            chan.close()
+            conn.close()
+
+        self.assertTrue(dispatcher[0])
+
+    def test_get_publisher(self):
+        connection = app_or_default().broker_connection()
+        p = increment_counter.get_publisher(connection, auto_declare=False,
+                                            exchange="foo")
+        self.assertEqual(p.exchange.name, "foo")
+        p = increment_counter.get_publisher(connection, auto_declare=False,
+                                            exchange_type="fanout")
+        self.assertEqual(p.exchange.type, "fanout")
+
+    def test_update_state(self):
+
+        @task_dec
+        def yyy():
+            pass
+
+        tid = uuid()
+        yyy.update_state(tid, "FROBULATING", {"fooz": "baaz"})
+        self.assertEqual(yyy.AsyncResult(tid).status, "FROBULATING")
+        self.assertDictEqual(yyy.AsyncResult(tid).result, {"fooz": "baaz"})
+
+        yyy.request.id = tid
+        yyy.update_state(state="FROBUZATING", meta={"fooz": "baaz"})
+        self.assertEqual(yyy.AsyncResult(tid).status, "FROBUZATING")
+        self.assertDictEqual(yyy.AsyncResult(tid).result, {"fooz": "baaz"})
+
+    def test_repr(self):
+
+        @task_dec
+        def task_test_repr():
+            pass
+
+        self.assertIn("task_test_repr", repr(task_test_repr))
+
+    def test_has___name__(self):
+
+        @task_dec
+        def yyy2():
+            pass
+
+        self.assertTrue(yyy2.__name__)
+
+    def test_get_logger(self):
+        t1 = self.createTask("c.unittest.t.t1")
+        logfh = WhateverIO()
+        logger = t1.get_logger(logfile=logfh, loglevel=0)
+        self.assertTrue(logger)
+
+        t1.request.loglevel = 3
+        logger = t1.get_logger(logfile=logfh, loglevel=None)
+        self.assertTrue(logger)
+
+
+class test_TaskSet(Case):
+
+    @with_eager_tasks
+    def test_function_taskset(self):
+        subtasks = [return_True_task.s(i) for i in range(1, 6)]
+        ts = task.TaskSet(subtasks)
+        res = ts.apply_async()
+        self.assertListEqual(res.join(), [True, True, True, True, True])
+
+    def test_counter_taskset(self):
+        increment_counter.count = 0
+        ts = task.TaskSet(tasks=[
+            increment_counter.s(),
+            increment_counter.s(increment_by=2),
+            increment_counter.s(increment_by=3),
+            increment_counter.s(increment_by=4),
+            increment_counter.s(increment_by=5),
+            increment_counter.s(increment_by=6),
+            increment_counter.s(increment_by=7),
+            increment_counter.s(increment_by=8),
+            increment_counter.s(increment_by=9),
+        ])
+        self.assertEqual(ts.total, 9)
+
+        consumer = increment_counter.get_consumer()
+        consumer.purge()
+        consumer.close()
+        taskset_res = ts.apply_async()
+        subtasks = taskset_res.subtasks
+        taskset_id = taskset_res.taskset_id
+        consumer = increment_counter.get_consumer()
+        for subtask in subtasks:
+            m = consumer.fetch().payload
+            self.assertDictContainsSubset({"taskset": taskset_id,
+                                           "task": increment_counter.name,
+                                           "id": subtask.id}, m)
+            increment_counter(
+                    increment_by=m.get("kwargs", {}).get("increment_by"))
+        self.assertEqual(increment_counter.count, sum(xrange(1, 10)))
+
+    def test_named_taskset(self):
+        prefix = "test_named_taskset-"
+        ts = task.TaskSet([return_True_task.subtask([1])])
+        res = ts.apply(taskset_id=prefix + uuid())
+        self.assertTrue(res.taskset_id.startswith(prefix))
+
+
+class test_apply_task(Case):
+
+    def test_apply_throw(self):
+        with self.assertRaises(KeyError):
+            raising.apply(throw=True)
+
+    def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
+        raising.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
+        try:
+            with self.assertRaises(KeyError):
+                raising.apply()
+        finally:
+            raising.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = False
+
+    def test_apply(self):
+        increment_counter.count = 0
+
+        e = increment_counter.apply()
+        self.assertIsInstance(e, EagerResult)
+        self.assertEqual(e.get(), 1)
+
+        e = increment_counter.apply(args=[1])
+        self.assertEqual(e.get(), 2)
+
+        e = increment_counter.apply(kwargs={"increment_by": 4})
+        self.assertEqual(e.get(), 6)
+
+        self.assertTrue(e.successful())
+        self.assertTrue(e.ready())
+        self.assertTrue(repr(e).startswith("<EagerResult:"))
+
+        f = raising.apply()
+        self.assertTrue(f.ready())
+        self.assertFalse(f.successful())
+        self.assertTrue(f.traceback)
+        with self.assertRaises(KeyError):
+            f.get()
+
+
+@task.periodic_task(run_every=timedelta(hours=1))
+def my_periodic():
+    pass
+
+
+class test_periodic_tasks(Case):
+
+    def test_must_have_run_every(self):
+        with self.assertRaises(NotImplementedError):
+            type("Foo", (task.PeriodicTask, ), {"__module__": __name__})
+
+    def test_remaining_estimate(self):
+        self.assertIsInstance(
+            my_periodic.run_every.remaining_estimate(datetime.utcnow()),
+            timedelta)
+
+    def test_is_due_not_due(self):
+        due, remaining = my_periodic.run_every.is_due(datetime.utcnow())
+        self.assertFalse(due)
+        # This assertion may fail if executed in the
+        # first minute of an hour, thus 59 instead of 60
+        self.assertGreater(remaining, 59)
+
+    def test_is_due(self):
+        p = my_periodic
+        due, remaining = p.run_every.is_due(
+                datetime.utcnow() - p.run_every.run_every)
+        self.assertTrue(due)
+        self.assertEqual(remaining,
+                         timedelta_seconds(p.run_every.run_every))
+
+    def test_schedule_repr(self):
+        p = my_periodic
+        self.assertTrue(repr(p.run_every))
+
+
+@task.periodic_task(run_every=crontab())
+def every_minute():
+    pass
+
+
+@task.periodic_task(run_every=crontab(minute="*/15"))
+def quarterly():
+    pass
+
+
+@task.periodic_task(run_every=crontab(minute=30))
+def hourly():
+    pass
+
+
+@task.periodic_task(run_every=crontab(hour=7, minute=30))
+def daily():
+    pass
+
+
+@task.periodic_task(run_every=crontab(hour=7, minute=30,
+                                      day_of_week="thursday"))
+def weekly():
+    pass
+
+
+def patch_crontab_nowfun(cls, retval):
+
+    def create_patcher(fun):
+
+        @wraps(fun)
+        def __inner(*args, **kwargs):
+            prev_nowfun = cls.run_every.nowfun
+            cls.run_every.nowfun = lambda: retval
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                cls.run_every.nowfun = prev_nowfun
+
+        return __inner
+
+    return create_patcher
+
+
+class test_crontab_parser(Case):
+
+    def test_parse_star(self):
+        self.assertEqual(crontab_parser(24).parse('*'), set(range(24)))
+        self.assertEqual(crontab_parser(60).parse('*'), set(range(60)))
+        self.assertEqual(crontab_parser(7).parse('*'), set(range(7)))
+
+    def test_parse_range(self):
+        self.assertEqual(crontab_parser(60).parse('1-10'),
+                          set(range(1, 10 + 1)))
+        self.assertEqual(crontab_parser(24).parse('0-20'),
+                          set(range(0, 20 + 1)))
+        self.assertEqual(crontab_parser().parse('2-10'),
+                          set(range(2, 10 + 1)))
+
+    def test_parse_groups(self):
+        self.assertEqual(crontab_parser().parse('1,2,3,4'),
+                          set([1, 2, 3, 4]))
+        self.assertEqual(crontab_parser().parse('0,15,30,45'),
+                          set([0, 15, 30, 45]))
+
+    def test_parse_steps(self):
+        self.assertEqual(crontab_parser(8).parse('*/2'),
+                          set([0, 2, 4, 6]))
+        self.assertEqual(crontab_parser().parse('*/2'),
+                          set(i * 2 for i in xrange(30)))
+        self.assertEqual(crontab_parser().parse('*/3'),
+                          set(i * 3 for i in xrange(20)))
+
+    def test_parse_composite(self):
+        self.assertEqual(crontab_parser(8).parse('*/2'), set([0, 2, 4, 6]))
+        self.assertEqual(crontab_parser().parse('2-9/5'), set([2, 7]))
+        self.assertEqual(crontab_parser().parse('2-10/5'), set([2, 7]))
+        self.assertEqual(crontab_parser().parse('2-11/5,3'), set([2, 3, 7]))
+        self.assertEqual(crontab_parser().parse('2-4/3,*/5,0-21/4'),
+                set([0, 2, 4, 5, 8, 10, 12, 15, 16,
+                     20, 25, 30, 35, 40, 45, 50, 55]))
+        self.assertEqual(crontab_parser().parse('1-9/2'),
+                set([1, 3, 5, 7, 9]))
+
+    def test_parse_errors_on_empty_string(self):
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('')
+
+    def test_parse_errors_on_empty_group(self):
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('1,,2')
+
+    def test_parse_errors_on_empty_steps(self):
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('*/')
+
+    def test_parse_errors_on_negative_number(self):
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('-20')
+
+    def test_expand_cronspec_eats_iterables(self):
+        self.assertEqual(crontab._expand_cronspec(iter([1, 2, 3]), 100),
+                         set([1, 2, 3]))
+
+    def test_expand_cronspec_invalid_type(self):
+        with self.assertRaises(TypeError):
+            crontab._expand_cronspec(object(), 100)
+
+    def test_repr(self):
+        self.assertIn("*", repr(crontab("*")))
+
+    def test_eq(self):
+        self.assertEqual(crontab(day_of_week="1, 2"),
+                         crontab(day_of_week="1-2"))
+        self.assertEqual(crontab(minute="1", hour="2", day_of_week="5"),
+                         crontab(minute="1", hour="2", day_of_week="5"))
+        self.assertNotEqual(crontab(minute="1"), crontab(minute="2"))
+        self.assertFalse(object() == crontab(minute="1"))
+        self.assertFalse(crontab(minute="1") == object())
+
+
+class test_crontab_remaining_estimate(Case):
+
+    def next_ocurrance(self, crontab, now):
+        crontab.nowfun = lambda: now
+        return now + crontab.remaining_estimate(now)
+
+    def test_next_minute(self):
+        next = self.next_ocurrance(crontab(),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEqual(next, datetime(2010, 9, 11, 14, 31))
+
+    def test_not_next_minute(self):
+        next = self.next_ocurrance(crontab(),
+                                   datetime(2010, 9, 11, 14, 59, 15))
+        self.assertEqual(next, datetime(2010, 9, 11, 15, 0))
+
+    def test_this_hour(self):
+        next = self.next_ocurrance(crontab(minute=[5, 42]),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEqual(next, datetime(2010, 9, 11, 14, 42))
+
+    def test_not_this_hour(self):
+        next = self.next_ocurrance(crontab(minute=[5, 10, 15]),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEqual(next, datetime(2010, 9, 11, 15, 5))
+
+    def test_today(self):
+        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12, 17]),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEqual(next, datetime(2010, 9, 11, 17, 5))
+
+    def test_not_today(self):
+        next = self.next_ocurrance(crontab(minute=[5, 42], hour=[12]),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEqual(next, datetime(2010, 9, 12, 12, 5))
+
+    def test_weekday(self):
+        next = self.next_ocurrance(crontab(minute=30,
+                                           hour=14,
+                                           day_of_week="sat"),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEqual(next, datetime(2010, 9, 18, 14, 30))
+
+    def test_not_weekday(self):
+        next = self.next_ocurrance(crontab(minute=[5, 42],
+                                           day_of_week="mon-fri"),
+                                   datetime(2010, 9, 11, 14, 30, 15))
+        self.assertEqual(next, datetime(2010, 9, 13, 0, 5))
+
+
+class test_crontab_is_due(Case):
+
+    def setUp(self):
+        self.now = datetime.utcnow()
+        self.next_minute = 60 - self.now.second - 1e-6 * self.now.microsecond
+
+    def test_default_crontab_spec(self):
+        c = crontab()
+        self.assertEqual(c.minute, set(range(60)))
+        self.assertEqual(c.hour, set(range(24)))
+        self.assertEqual(c.day_of_week, set(range(7)))
+
+    def test_simple_crontab_spec(self):
+        c = crontab(minute=30)
+        self.assertEqual(c.minute, set([30]))
+        self.assertEqual(c.hour, set(range(24)))
+        self.assertEqual(c.day_of_week, set(range(7)))
+
+    def test_crontab_spec_minute_formats(self):
+        c = crontab(minute=30)
+        self.assertEqual(c.minute, set([30]))
+        c = crontab(minute='30')
+        self.assertEqual(c.minute, set([30]))
+        c = crontab(minute=(30, 40, 50))
+        self.assertEqual(c.minute, set([30, 40, 50]))
+        c = crontab(minute=set([30, 40, 50]))
+        self.assertEqual(c.minute, set([30, 40, 50]))
+
+    def test_crontab_spec_invalid_minute(self):
+        with self.assertRaises(ValueError):
+            crontab(minute=60)
+        with self.assertRaises(ValueError):
+            crontab(minute='0-100')
+
+    def test_crontab_spec_hour_formats(self):
+        c = crontab(hour=6)
+        self.assertEqual(c.hour, set([6]))
+        c = crontab(hour='5')
+        self.assertEqual(c.hour, set([5]))
+        c = crontab(hour=(4, 8, 12))
+        self.assertEqual(c.hour, set([4, 8, 12]))
+
+    def test_crontab_spec_invalid_hour(self):
+        with self.assertRaises(ValueError):
+            crontab(hour=24)
+        with self.assertRaises(ValueError):
+            crontab(hour='0-30')
+
+    def test_crontab_spec_dow_formats(self):
+        c = crontab(day_of_week=5)
+        self.assertEqual(c.day_of_week, set([5]))
+        c = crontab(day_of_week='5')
+        self.assertEqual(c.day_of_week, set([5]))
+        c = crontab(day_of_week='fri')
+        self.assertEqual(c.day_of_week, set([5]))
+        c = crontab(day_of_week='tuesday,sunday,fri')
+        self.assertEqual(c.day_of_week, set([0, 2, 5]))
+        c = crontab(day_of_week='mon-fri')
+        self.assertEqual(c.day_of_week, set([1, 2, 3, 4, 5]))
+        c = crontab(day_of_week='*/2')
+        self.assertEqual(c.day_of_week, set([0, 2, 4, 6]))
+
+    def seconds_almost_equal(self, a, b, precision):
+        for index, skew in enumerate((+0.1, 0, -0.1)):
+            try:
+                self.assertAlmostEqual(a, b + skew, precision)
+            except AssertionError:
+                if index + 1 >= 3:
+                    raise
+            else:
+                break
+
+    def test_crontab_spec_invalid_dow(self):
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='fooday-barday')
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='1,4,foo')
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='7')
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='12')
+
+    def test_every_minute_execution_is_due(self):
+        last_ran = self.now - timedelta(seconds=61)
+        due, remaining = every_minute.run_every.is_due(last_ran)
+        self.assertTrue(due)
+        self.seconds_almost_equal(remaining, self.next_minute, 1)
+
+    def test_every_minute_execution_is_not_due(self):
+        last_ran = self.now - timedelta(seconds=self.now.second)
+        due, remaining = every_minute.run_every.is_due(last_ran)
+        self.assertFalse(due)
+        self.seconds_almost_equal(remaining, self.next_minute, 1)
+
+    # 29th of May 2010 is a saturday
+    @patch_crontab_nowfun(hourly, datetime(2010, 5, 29, 10, 30))
+    def test_execution_is_due_on_saturday(self):
+        last_ran = self.now - timedelta(seconds=61)
+        due, remaining = every_minute.run_every.is_due(last_ran)
+        self.assertTrue(due)
+        self.seconds_almost_equal(remaining, self.next_minute, 1)
+
+    # 30th of May 2010 is a sunday
+    @patch_crontab_nowfun(hourly, datetime(2010, 5, 30, 10, 30))
+    def test_execution_is_due_on_sunday(self):
+        last_ran = self.now - timedelta(seconds=61)
+        due, remaining = every_minute.run_every.is_due(last_ran)
+        self.assertTrue(due)
+        self.seconds_almost_equal(remaining, self.next_minute, 1)
+
+    # 31st of May 2010 is a monday
+    @patch_crontab_nowfun(hourly, datetime(2010, 5, 31, 10, 30))
+    def test_execution_is_due_on_monday(self):
+        last_ran = self.now - timedelta(seconds=61)
+        due, remaining = every_minute.run_every.is_due(last_ran)
+        self.assertTrue(due)
+        self.seconds_almost_equal(remaining, self.next_minute, 1)
+
+    @patch_crontab_nowfun(hourly, datetime(2010, 5, 10, 10, 30))
+    def test_every_hour_execution_is_due(self):
+        due, remaining = hourly.run_every.is_due(
+                datetime(2010, 5, 10, 6, 30))
+        self.assertTrue(due)
+        self.assertEqual(remaining, 60 * 60)
+
+    @patch_crontab_nowfun(hourly, datetime(2010, 5, 10, 10, 29))
+    def test_every_hour_execution_is_not_due(self):
+        due, remaining = hourly.run_every.is_due(
+                datetime(2010, 5, 10, 9, 30))
+        self.assertFalse(due)
+        self.assertEqual(remaining, 60)
+
+    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 15))
+    def test_first_quarter_execution_is_due(self):
+        due, remaining = quarterly.run_every.is_due(
+                            datetime(2010, 5, 10, 6, 30))
+        self.assertTrue(due)
+        self.assertEqual(remaining, 15 * 60)
+
+    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 30))
+    def test_second_quarter_execution_is_due(self):
+        due, remaining = quarterly.run_every.is_due(
+                            datetime(2010, 5, 10, 6, 30))
+        self.assertTrue(due)
+        self.assertEqual(remaining, 15 * 60)
+
+    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 14))
+    def test_first_quarter_execution_is_not_due(self):
+        due, remaining = quarterly.run_every.is_due(
+                            datetime(2010, 5, 10, 10, 0))
+        self.assertFalse(due)
+        self.assertEqual(remaining, 60)
+
+    @patch_crontab_nowfun(quarterly, datetime(2010, 5, 10, 10, 29))
+    def test_second_quarter_execution_is_not_due(self):
+        due, remaining = quarterly.run_every.is_due(
+                            datetime(2010, 5, 10, 10, 15))
+        self.assertFalse(due)
+        self.assertEqual(remaining, 60)
+
+    @patch_crontab_nowfun(daily, datetime(2010, 5, 10, 7, 30))
+    def test_daily_execution_is_due(self):
+        due, remaining = daily.run_every.is_due(
+                datetime(2010, 5, 9, 7, 30))
+        self.assertTrue(due)
+        self.assertEqual(remaining, 24 * 60 * 60)
+
+    @patch_crontab_nowfun(daily, datetime(2010, 5, 10, 10, 30))
+    def test_daily_execution_is_not_due(self):
+        due, remaining = daily.run_every.is_due(
+                datetime(2010, 5, 10, 7, 30))
+        self.assertFalse(due)
+        self.assertEqual(remaining, 21 * 60 * 60)
+
+    @patch_crontab_nowfun(weekly, datetime(2010, 5, 6, 7, 30))
+    def test_weekly_execution_is_due(self):
+        due, remaining = weekly.run_every.is_due(
+                datetime(2010, 4, 30, 7, 30))
+        self.assertTrue(due)
+        self.assertEqual(remaining, 7 * 24 * 60 * 60)
+
+    @patch_crontab_nowfun(weekly, datetime(2010, 5, 7, 10, 30))
+    def test_weekly_execution_is_not_due(self):
+        due, remaining = weekly.run_every.is_due(
+                datetime(2010, 5, 6, 7, 30))
+        self.assertFalse(due)
+        self.assertEqual(remaining, 6 * 24 * 60 * 60 - 3 * 60 * 60)

+ 0 - 135
celery/tests/utilities/__init__.py

@@ -1,135 +0,0 @@
-from __future__ import absolute_import
-from __future__ import with_statement
-
-from kombu.utils.functional import promise
-
-from celery import utils
-from celery.utils import text
-from celery.utils import functional
-from celery.utils.functional import mpromise
-from celery.utils.threads import bgThread
-from celery.tests.utils import Case
-
-
-def double(x):
-    return x * 2
-
-
-class test_bgThread_interface(Case):
-
-    def test_body(self):
-        x = bgThread()
-        with self.assertRaises(NotImplementedError):
-            x.body()
-
-
-class test_chunks(Case):
-
-    def test_chunks(self):
-
-        # n == 2
-        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
-        self.assertListEqual(list(x),
-            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]])
-
-        # n == 3
-        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
-        self.assertListEqual(list(x),
-            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]])
-
-        # n == 2 (exact)
-        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 2)
-        self.assertListEqual(list(x),
-            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
-
-
-class test_utils(Case):
-
-    def test_is_iterable(self):
-        for a in "f", ["f"], ("f", ), {"f": "f"}:
-            self.assertTrue(utils.is_iterable(a))
-        for b in object(), 1:
-            self.assertFalse(utils.is_iterable(b))
-
-    def test_padlist(self):
-        self.assertListEqual(functional.padlist(
-                ["George", "Costanza", "NYC"], 3),
-                ["George", "Costanza", "NYC"])
-        self.assertListEqual(functional.padlist(["George", "Costanza"], 3),
-                ["George", "Costanza", None])
-        self.assertListEqual(functional.padlist(
-                ["George", "Costanza", "NYC"], 4, default="Earth"),
-                ["George", "Costanza", "NYC", "Earth"])
-
-    def test_firstmethod_AttributeError(self):
-        self.assertIsNone(functional.firstmethod("foo")([object()]))
-
-    def test_firstmethod_promises(self):
-
-        class A(object):
-
-            def __init__(self, value=None):
-                self.value = value
-
-            def m(self):
-                return self.value
-
-        self.assertEqual("four", functional.firstmethod("m")([
-            A(), A(), A(), A("four"), A("five")]))
-        self.assertEqual("four", functional.firstmethod("m")([
-            A(), A(), A(), promise(lambda: A("four")), A("five")]))
-
-    def test_first(self):
-        iterations = [0]
-
-        def predicate(value):
-            iterations[0] += 1
-            if value == 5:
-                return True
-            return False
-
-        self.assertEqual(5, functional.first(predicate, xrange(10)))
-        self.assertEqual(iterations[0], 6)
-
-        iterations[0] = 0
-        self.assertIsNone(functional.first(predicate, xrange(10, 20)))
-        self.assertEqual(iterations[0], 10)
-
-    def test_truncate_text(self):
-        self.assertEqual(text.truncate("ABCDEFGHI", 3), "ABC...")
-        self.assertEqual(text.truncate("ABCDEFGHI", 10), "ABCDEFGHI")
-
-    def test_abbr(self):
-        self.assertEqual(text.abbr(None, 3), "???")
-        self.assertEqual(text.abbr("ABCDEFGHI", 6), "ABC...")
-        self.assertEqual(text.abbr("ABCDEFGHI", 20), "ABCDEFGHI")
-        self.assertEqual(text.abbr("ABCDEFGHI", 6, None), "ABCDEF")
-
-    def test_abbrtask(self):
-        self.assertEqual(text.abbrtask(None, 3), "???")
-        self.assertEqual(text.abbrtask("feeds.tasks.refresh", 10),
-                                        "[.]refresh")
-        self.assertEqual(text.abbrtask("feeds.tasks.refresh", 30),
-                                        "feeds.tasks.refresh")
-
-    def test_cached_property(self):
-
-        def fun(obj):
-            return fun.value
-
-        x = utils.cached_property(fun)
-        self.assertIs(x.__get__(None), x)
-        self.assertIs(x.__set__(None, None), x)
-        self.assertIs(x.__delete__(None), x)
-
-
-class test_mpromise(Case):
-
-    def test_is_memoized(self):
-
-        it = iter(xrange(20, 30))
-        p = mpromise(it.next)
-        self.assertEqual(p(), 20)
-        self.assertTrue(p.evaluated)
-        self.assertEqual(p(), 20)
-        self.assertEqual(repr(p), "20")

+ 135 - 0
celery/tests/utilities/test_utils.py

@@ -0,0 +1,135 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+from kombu.utils.functional import promise
+
+from celery import utils
+from celery.utils import text
+from celery.utils import functional
+from celery.utils.functional import mpromise
+from celery.utils.threads import bgThread
+from celery.tests.utils import Case
+
+
+def double(x):
+    return x * 2
+
+
+class test_bgThread_interface(Case):
+
+    def test_body(self):
+        x = bgThread()
+        with self.assertRaises(NotImplementedError):
+            x.body()
+
+
+class test_chunks(Case):
+
+    def test_chunks(self):
+
+        # n == 2
+        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
+        self.assertListEqual(list(x),
+            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]])
+
+        # n == 3
+        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
+        self.assertListEqual(list(x),
+            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]])
+
+        # n == 2 (exact)
+        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 2)
+        self.assertListEqual(list(x),
+            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
+
+
+class test_utils(Case):
+
+    def test_is_iterable(self):
+        for a in "f", ["f"], ("f", ), {"f": "f"}:
+            self.assertTrue(utils.is_iterable(a))
+        for b in object(), 1:
+            self.assertFalse(utils.is_iterable(b))
+
+    def test_padlist(self):
+        self.assertListEqual(functional.padlist(
+                ["George", "Costanza", "NYC"], 3),
+                ["George", "Costanza", "NYC"])
+        self.assertListEqual(functional.padlist(["George", "Costanza"], 3),
+                ["George", "Costanza", None])
+        self.assertListEqual(functional.padlist(
+                ["George", "Costanza", "NYC"], 4, default="Earth"),
+                ["George", "Costanza", "NYC", "Earth"])
+
+    def test_firstmethod_AttributeError(self):
+        self.assertIsNone(functional.firstmethod("foo")([object()]))
+
+    def test_firstmethod_promises(self):
+
+        class A(object):
+
+            def __init__(self, value=None):
+                self.value = value
+
+            def m(self):
+                return self.value
+
+        self.assertEqual("four", functional.firstmethod("m")([
+            A(), A(), A(), A("four"), A("five")]))
+        self.assertEqual("four", functional.firstmethod("m")([
+            A(), A(), A(), promise(lambda: A("four")), A("five")]))
+
+    def test_first(self):
+        iterations = [0]
+
+        def predicate(value):
+            iterations[0] += 1
+            if value == 5:
+                return True
+            return False
+
+        self.assertEqual(5, functional.first(predicate, xrange(10)))
+        self.assertEqual(iterations[0], 6)
+
+        iterations[0] = 0
+        self.assertIsNone(functional.first(predicate, xrange(10, 20)))
+        self.assertEqual(iterations[0], 10)
+
+    def test_truncate_text(self):
+        self.assertEqual(text.truncate("ABCDEFGHI", 3), "ABC...")
+        self.assertEqual(text.truncate("ABCDEFGHI", 10), "ABCDEFGHI")
+
+    def test_abbr(self):
+        self.assertEqual(text.abbr(None, 3), "???")
+        self.assertEqual(text.abbr("ABCDEFGHI", 6), "ABC...")
+        self.assertEqual(text.abbr("ABCDEFGHI", 20), "ABCDEFGHI")
+        self.assertEqual(text.abbr("ABCDEFGHI", 6, None), "ABCDEF")
+
+    def test_abbrtask(self):
+        self.assertEqual(text.abbrtask(None, 3), "???")
+        self.assertEqual(text.abbrtask("feeds.tasks.refresh", 10),
+                                        "[.]refresh")
+        self.assertEqual(text.abbrtask("feeds.tasks.refresh", 30),
+                                        "feeds.tasks.refresh")
+
+    def test_cached_property(self):
+
+        def fun(obj):
+            return fun.value
+
+        x = utils.cached_property(fun)
+        self.assertIs(x.__get__(None), x)
+        self.assertIs(x.__set__(None, None), x)
+        self.assertIs(x.__delete__(None), x)
+
+
+class test_mpromise(Case):
+
+    def test_is_memoized(self):
+
+        it = iter(xrange(20, 30))
+        p = mpromise(it.next)
+        self.assertEqual(p(), 20)
+        self.assertTrue(p.evaluated)
+        self.assertEqual(p(), 20)
+        self.assertEqual(repr(p), "20")

+ 0 - 956
celery/tests/worker/__init__.py

@@ -1,956 +0,0 @@
-from __future__ import absolute_import
-from __future__ import with_statement
-
-import socket
-import sys
-
-from collections import deque
-from datetime import datetime, timedelta
-from Queue import Empty
-
-from kombu.transport.base import Message
-from kombu.connection import BrokerConnection
-from mock import Mock, patch
-from nose import SkipTest
-
-from celery import current_app
-from celery.app.defaults import DEFAULTS
-from celery.concurrency.base import BasePool
-from celery.datastructures import AttributeDict
-from celery.exceptions import SystemTerminate
-from celery.task import task as task_dec
-from celery.task import periodic_task as periodic_task_dec
-from celery.utils import uuid
-from celery.worker import WorkController
-from celery.worker.buckets import FastQueue
-from celery.worker.job import Request
-from celery.worker.consumer import Consumer as MainConsumer
-from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
-from celery.utils.serialization import pickle
-from celery.utils.timer2 import Timer
-
-from celery.tests.utils import AppCase, Case
-
-
-class PlaceHolder(object):
-        pass
-
-
-class MyKombuConsumer(MainConsumer):
-    broadcast_consumer = Mock()
-    task_consumer = Mock()
-
-    def __init__(self, *args, **kwargs):
-        kwargs.setdefault("pool", BasePool(2))
-        super(MyKombuConsumer, self).__init__(*args, **kwargs)
-
-    def restart_heartbeat(self):
-        self.heart = None
-
-
-class MockNode(object):
-    commands = []
-
-    def handle_message(self, body, message):
-        self.commands.append(body.pop("command", None))
-
-
-class MockEventDispatcher(object):
-    sent = []
-    closed = False
-    flushed = False
-    _outbound_buffer = []
-
-    def send(self, event, *args, **kwargs):
-        self.sent.append(event)
-
-    def close(self):
-        self.closed = True
-
-    def flush(self):
-        self.flushed = True
-
-
-class MockHeart(object):
-    closed = False
-
-    def stop(self):
-        self.closed = True
-
-
-@task_dec()
-def foo_task(x, y, z, **kwargs):
-    return x * y * z
-
-
-@periodic_task_dec(run_every=60)
-def foo_periodic_task():
-    return "foo"
-
-
-def create_message(channel, **data):
-    data.setdefault("id", uuid())
-    channel.no_ack_consumers = set()
-    return Message(channel, body=pickle.dumps(dict(**data)),
-                   content_type="application/x-python-serialize",
-                   content_encoding="binary",
-                   delivery_info={"consumer_tag": "mock"})
-
-
-class test_QoS(Case):
-
-    class _QoS(QoS):
-        def __init__(self, value):
-            self.value = value
-            QoS.__init__(self, None, value)
-
-        def set(self, value):
-            return value
-
-    def test_qos_increment_decrement(self):
-        qos = self._QoS(10)
-        self.assertEqual(qos.increment(), 11)
-        self.assertEqual(qos.increment(3), 14)
-        self.assertEqual(qos.increment(-30), 14)
-        self.assertEqual(qos.decrement(7), 7)
-        self.assertEqual(qos.decrement(), 6)
-        with self.assertRaises(AssertionError):
-            qos.decrement(10)
-
-    def test_qos_disabled_increment_decrement(self):
-        qos = self._QoS(0)
-        self.assertEqual(qos.increment(), 0)
-        self.assertEqual(qos.increment(3), 0)
-        self.assertEqual(qos.increment(-30), 0)
-        self.assertEqual(qos.decrement(7), 0)
-        self.assertEqual(qos.decrement(), 0)
-        self.assertEqual(qos.decrement(10), 0)
-
-    def test_qos_thread_safe(self):
-        qos = self._QoS(10)
-
-        def add():
-            for i in xrange(1000):
-                qos.increment()
-
-        def sub():
-            for i in xrange(1000):
-                qos.decrement_eventually()
-
-        def threaded(funs):
-            from threading import Thread
-            threads = [Thread(target=fun) for fun in funs]
-            for thread in threads:
-                thread.start()
-            for thread in threads:
-                thread.join()
-
-        threaded([add, add])
-        self.assertEqual(qos.value, 2010)
-
-        qos.value = 1000
-        threaded([add, sub])  # n = 2
-        self.assertEqual(qos.value, 1000)
-
-    def test_exceeds_short(self):
-        qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
-        qos.update()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
-        qos.increment()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
-        qos.increment()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
-        qos.decrement()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
-        qos.decrement()
-        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
-
-    def test_consumer_increment_decrement(self):
-        consumer = Mock()
-        qos = QoS(consumer, 10)
-        qos.update()
-        self.assertEqual(qos.value, 10)
-        self.assertIn({"prefetch_count": 10}, consumer.qos.call_args)
-        qos.decrement()
-        self.assertEqual(qos.value, 9)
-        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 8)
-        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
-
-        # Does not decrement 0 value
-        qos.value = 0
-        qos.decrement()
-        self.assertEqual(qos.value, 0)
-        qos.increment()
-        self.assertEqual(qos.value, 0)
-
-    def test_consumer_decrement_eventually(self):
-        consumer = Mock()
-        qos = QoS(consumer, 10)
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 9)
-        qos.value = 0
-        qos.decrement_eventually()
-        self.assertEqual(qos.value, 0)
-
-    def test_set(self):
-        consumer = Mock()
-        qos = QoS(consumer, 10)
-        qos.set(12)
-        self.assertEqual(qos.prev, 12)
-        qos.set(qos.prev)
-
-
-class test_Consumer(Case):
-
-    def setUp(self):
-        self.ready_queue = FastQueue()
-        self.eta_schedule = Timer()
-
-    def tearDown(self):
-        self.eta_schedule.stop()
-
-    def test_info(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                           send_events=False)
-        l.qos = QoS(l.task_consumer, 10)
-        info = l.info
-        self.assertEqual(info["prefetch_count"], 10)
-        self.assertFalse(info["broker"])
-
-        l.connection = current_app.broker_connection()
-        info = l.info
-        self.assertTrue(info["broker"])
-
-    def test_start_when_closed(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                            send_events=False)
-        l._state = CLOSE
-        l.start()
-
-    def test_connection(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                           send_events=False)
-
-        l.reset_connection()
-        self.assertIsInstance(l.connection, BrokerConnection)
-
-        l._state = RUN
-        l.event_dispatcher = None
-        l.stop_consumers(close_connection=False)
-        self.assertTrue(l.connection)
-
-        l._state = RUN
-        l.stop_consumers()
-        self.assertIsNone(l.connection)
-        self.assertIsNone(l.task_consumer)
-
-        l.reset_connection()
-        self.assertIsInstance(l.connection, BrokerConnection)
-        l.stop_consumers()
-
-        l.stop()
-        l.close_connection()
-        self.assertIsNone(l.connection)
-        self.assertIsNone(l.task_consumer)
-
-    def test_close_connection(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                           send_events=False)
-        l._state = RUN
-        l.close_connection()
-
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                           send_events=False)
-        eventer = l.event_dispatcher = Mock()
-        eventer.enabled = True
-        heart = l.heart = MockHeart()
-        l._state = RUN
-        l.stop_consumers()
-        self.assertTrue(eventer.close.call_count)
-        self.assertTrue(heart.closed)
-
-    @patch("celery.worker.consumer.warn")
-    def test_receive_message_unknown(self, warn):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                           send_events=False)
-        backend = Mock()
-        m = create_message(backend, unknown={"baz": "!!!"})
-        l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
-
-        l.receive_message(m.decode(), m)
-        self.assertTrue(warn.call_count)
-
-    @patch("celery.utils.timer2.to_timestamp")
-    def test_receive_message_eta_OverflowError(self, to_timestamp):
-        to_timestamp.side_effect = OverflowError()
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                             send_events=False)
-        m = create_message(Mock(), task=foo_task.name,
-                                   args=("2, 2"),
-                                   kwargs={},
-                                   eta=datetime.now().isoformat())
-        l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
-        l.update_strategies()
-
-        l.receive_message(m.decode(), m)
-        self.assertTrue(m.acknowledged)
-        self.assertTrue(to_timestamp.call_count)
-
-    @patch("celery.worker.consumer.error")
-    def test_receive_message_InvalidTaskError(self, error):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                           send_events=False)
-        m = create_message(Mock(), task=foo_task.name,
-                           args=(1, 2), kwargs="foobarbaz", id=1)
-        l.update_strategies()
-        l.event_dispatcher = Mock()
-        l.pidbox_node = MockNode()
-
-        l.receive_message(m.decode(), m)
-        self.assertIn("Received invalid task message", error.call_args[0][0])
-
-    @patch("celery.worker.consumer.crit")
-    def test_on_decode_error(self, crit):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                           send_events=False)
-
-        class MockMessage(Mock):
-            content_type = "application/x-msgpack"
-            content_encoding = "binary"
-            body = "foobarbaz"
-
-        message = MockMessage()
-        l.on_decode_error(message, KeyError("foo"))
-        self.assertTrue(message.ack.call_count)
-        self.assertIn("Can't decode message body", crit.call_args[0][0])
-
-    def test_receieve_message(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                           send_events=False)
-        m = create_message(Mock(), task=foo_task.name,
-                           args=[2, 4, 8], kwargs={})
-        l.update_strategies()
-
-        l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
-
-        in_bucket = self.ready_queue.get_nowait()
-        self.assertIsInstance(in_bucket, Request)
-        self.assertEqual(in_bucket.name, foo_task.name)
-        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
-        self.assertTrue(self.eta_schedule.empty())
-
-    def test_start_connection_error(self):
-
-        class MockConsumer(MainConsumer):
-            iterations = 0
-
-            def consume_messages(self):
-                if not self.iterations:
-                    self.iterations = 1
-                    raise KeyError("foo")
-                raise SyntaxError("bar")
-
-        l = MockConsumer(self.ready_queue, self.eta_schedule,
-                             send_events=False, pool=BasePool())
-        l.connection_errors = (KeyError, )
-        with self.assertRaises(SyntaxError):
-            l.start()
-        l.heart.stop()
-        l.priority_timer.stop()
-
-    def test_start_channel_error(self):
-        # Regression test for AMQPChannelExceptions that can occur within the
-        # consumer. (i.e. 404 errors)
-
-        class MockConsumer(MainConsumer):
-            iterations = 0
-
-            def consume_messages(self):
-                if not self.iterations:
-                    self.iterations = 1
-                    raise KeyError("foo")
-                raise SyntaxError("bar")
-
-        l = MockConsumer(self.ready_queue, self.eta_schedule,
-                             send_events=False, pool=BasePool())
-
-        l.channel_errors = (KeyError, )
-        self.assertRaises(SyntaxError, l.start)
-        l.heart.stop()
-        l.priority_timer.stop()
-
-    def test_consume_messages_ignores_socket_timeout(self):
-
-        class Connection(current_app.broker_connection().__class__):
-            obj = None
-
-            def drain_events(self, **kwargs):
-                self.obj.connection = None
-                raise socket.timeout(10)
-
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                            send_events=False)
-        l.connection = Connection()
-        l.task_consumer = Mock()
-        l.connection.obj = l
-        l.qos = QoS(l.task_consumer, 10)
-        l.consume_messages()
-
-    def test_consume_messages_when_socket_error(self):
-
-        class Connection(current_app.broker_connection().__class__):
-            obj = None
-
-            def drain_events(self, **kwargs):
-                self.obj.connection = None
-                raise socket.error("foo")
-
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                            send_events=False)
-        l._state = RUN
-        c = l.connection = Connection()
-        l.connection.obj = l
-        l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, 10)
-        with self.assertRaises(socket.error):
-            l.consume_messages()
-
-        l._state = CLOSE
-        l.connection = c
-        l.consume_messages()
-
-    def test_consume_messages(self):
-
-        class Connection(current_app.broker_connection().__class__):
-            obj = None
-
-            def drain_events(self, **kwargs):
-                self.obj.connection = None
-
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                             send_events=False)
-        l.connection = Connection()
-        l.connection.obj = l
-        l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, 10)
-
-        l.consume_messages()
-        l.consume_messages()
-        self.assertTrue(l.task_consumer.consume.call_count)
-        l.task_consumer.qos.assert_called_with(prefetch_count=10)
-        l.qos.decrement()
-        l.consume_messages()
-        l.task_consumer.qos.assert_called_with(prefetch_count=9)
-
-    def test_maybe_conn_error(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                             send_events=False)
-        l.connection_errors = (KeyError, )
-        l.channel_errors = (SyntaxError, )
-        l.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
-        l.maybe_conn_error(Mock(side_effect=KeyError("foo")))
-        l.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
-        with self.assertRaises(IndexError):
-            l.maybe_conn_error(Mock(side_effect=IndexError("foo")))
-
-    def test_apply_eta_task(self):
-        from celery.worker import state
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                             send_events=False)
-        l.qos = QoS(None, 10)
-
-        task = object()
-        qos = l.qos.value
-        l.apply_eta_task(task)
-        self.assertIn(task, state.reserved_requests)
-        self.assertEqual(l.qos.value, qos - 1)
-        self.assertIs(self.ready_queue.get_nowait(), task)
-
-    def test_receieve_message_eta_isoformat(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                             send_events=False)
-        m = create_message(Mock(), task=foo_task.name,
-                           eta=datetime.now().isoformat(),
-                           args=[2, 4, 8], kwargs={})
-
-        l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, l.initial_prefetch_count)
-        l.event_dispatcher = Mock()
-        l.enabled = False
-        l.update_strategies()
-        l.receive_message(m.decode(), m)
-        l.eta_schedule.stop()
-
-        items = [entry[2] for entry in self.eta_schedule.queue]
-        found = 0
-        for item in items:
-            if item.args[0].name == foo_task.name:
-                found = True
-        self.assertTrue(found)
-        self.assertTrue(l.task_consumer.qos.call_count)
-        l.eta_schedule.stop()
-
-    def test_on_control(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                             send_events=False)
-        l.pidbox_node = Mock()
-        l.reset_pidbox_node = Mock()
-
-        l.on_control("foo", "bar")
-        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
-
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = KeyError("foo")
-        l.on_control("foo", "bar")
-        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
-
-        l.pidbox_node = Mock()
-        l.pidbox_node.handle_message.side_effect = ValueError("foo")
-        l.on_control("foo", "bar")
-        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
-        l.reset_pidbox_node.assert_called_with()
-
-    def test_revoke(self):
-        ready_queue = FastQueue()
-        l = MyKombuConsumer(ready_queue, self.eta_schedule,
-                           send_events=False)
-        backend = Mock()
-        id = uuid()
-        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
-                           kwargs={}, id=id)
-        from celery.worker.state import revoked
-        revoked.add(id)
-
-        l.receive_message(t.decode(), t)
-        self.assertTrue(ready_queue.empty())
-
-    def test_receieve_message_not_registered(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                          send_events=False)
-        backend = Mock()
-        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
-
-        l.event_dispatcher = Mock()
-        self.assertFalse(l.receive_message(m.decode(), m))
-        with self.assertRaises(Empty):
-            self.ready_queue.get_nowait()
-        self.assertTrue(self.eta_schedule.empty())
-
-    @patch("celery.worker.consumer.warn")
-    @patch("celery.worker.consumer.logger")
-    def test_receieve_message_ack_raises(self, logger, warn):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                          send_events=False)
-        backend = Mock()
-        m = create_message(backend, args=[2, 4, 8], kwargs={})
-
-        l.event_dispatcher = Mock()
-        l.connection_errors = (socket.error, )
-        m.reject = Mock()
-        m.reject.side_effect = socket.error("foo")
-        self.assertFalse(l.receive_message(m.decode(), m))
-        self.assertTrue(warn.call_count)
-        with self.assertRaises(Empty):
-            self.ready_queue.get_nowait()
-        self.assertTrue(self.eta_schedule.empty())
-        m.reject.assert_called_with()
-        self.assertTrue(logger.critical.call_count)
-
-    def test_receieve_message_eta(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                            send_events=False)
-        l.event_dispatcher = Mock()
-        l.event_dispatcher._outbound_buffer = deque()
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name,
-                           args=[2, 4, 8], kwargs={},
-                           eta=(datetime.now() +
-                               timedelta(days=1)).isoformat())
-
-        l.reset_connection()
-        p = l.app.conf.BROKER_CONNECTION_RETRY
-        l.app.conf.BROKER_CONNECTION_RETRY = False
-        try:
-            l.reset_connection()
-        finally:
-            l.app.conf.BROKER_CONNECTION_RETRY = p
-        l.stop_consumers()
-        l.event_dispatcher = Mock()
-        l.receive_message(m.decode(), m)
-        l.eta_schedule.stop()
-        in_hold = self.eta_schedule.queue[0]
-        self.assertEqual(len(in_hold), 3)
-        eta, priority, entry = in_hold
-        task = entry.args[0]
-        self.assertIsInstance(task, Request)
-        self.assertEqual(task.name, foo_task.name)
-        self.assertEqual(task.execute(), 2 * 4 * 8)
-        with self.assertRaises(Empty):
-            self.ready_queue.get_nowait()
-
-    def test_reset_pidbox_node(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                          send_events=False)
-        l.pidbox_node = Mock()
-        chan = l.pidbox_node.channel = Mock()
-        l.connection = Mock()
-        chan.close.side_effect = socket.error("foo")
-        l.connection_errors = (socket.error, )
-        l.reset_pidbox_node()
-        chan.close.assert_called_with()
-
-    def test_reset_pidbox_node_green(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                          send_events=False)
-        l.pool = Mock()
-        l.pool.is_green = True
-        l.reset_pidbox_node()
-        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)
-
-    def test__green_pidbox_node(self):
-        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
-                          send_events=False)
-        l.pidbox_node = Mock()
-
-        class BConsumer(Mock):
-
-            def __enter__(self):
-                self.consume()
-                return self
-
-            def __exit__(self, *exc_info):
-                self.cancel()
-
-        l.pidbox_node.listen = BConsumer()
-        connections = []
-
-        class Connection(object):
-
-            def __init__(self, obj):
-                connections.append(self)
-                self.obj = obj
-                self.default_channel = self.channel()
-                self.closed = False
-
-            def __enter__(self):
-                return self
-
-            def __exit__(self, *exc_info):
-                self.close()
-
-            def channel(self):
-                return Mock()
-
-            def drain_events(self, **kwargs):
-                self.obj.connection = None
-                self.obj._pidbox_node_shutdown.set()
-
-            def close(self):
-                self.closed = True
-
-        l.connection = Mock()
-        l._open_connection = lambda: Connection(obj=l)
-        l._green_pidbox_node()
-
-        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
-        self.assertTrue(l.broadcast_consumer)
-        l.broadcast_consumer.consume.assert_called_with()
-
-        self.assertIsNone(l.connection)
-        self.assertTrue(connections[0].closed)
-
-    def test_start__consume_messages(self):
-
-        class _QoS(object):
-            prev = 3
-            value = 4
-
-            def update(self):
-                self.prev = self.value
-
-        class _Consumer(MyKombuConsumer):
-            iterations = 0
-
-            def reset_connection(self):
-                if self.iterations >= 1:
-                    raise KeyError("foo")
-
-        init_callback = Mock()
-        l = _Consumer(self.ready_queue, self.eta_schedule,
-                      send_events=False, init_callback=init_callback)
-        l.task_consumer = Mock()
-        l.broadcast_consumer = Mock()
-        l.qos = _QoS()
-        l.connection = BrokerConnection()
-        l.iterations = 0
-
-        def raises_KeyError(limit=None):
-            l.iterations += 1
-            if l.qos.prev != l.qos.value:
-                l.qos.update()
-            if l.iterations >= 2:
-                raise KeyError("foo")
-
-        l.consume_messages = raises_KeyError
-        with self.assertRaises(KeyError):
-            l.start()
-        self.assertTrue(init_callback.call_count)
-        self.assertEqual(l.iterations, 1)
-        self.assertEqual(l.qos.prev, l.qos.value)
-
-        init_callback.reset_mock()
-        l = _Consumer(self.ready_queue, self.eta_schedule,
-                      send_events=False, init_callback=init_callback)
-        l.qos = _QoS()
-        l.task_consumer = Mock()
-        l.broadcast_consumer = Mock()
-        l.connection = BrokerConnection()
-        l.consume_messages = Mock(side_effect=socket.error("foo"))
-        with self.assertRaises(socket.error):
-            l.start()
-        self.assertTrue(init_callback.call_count)
-        self.assertTrue(l.consume_messages.call_count)
-
-    def test_reset_connection_with_no_node(self):
-        l = MainConsumer(self.ready_queue, self.eta_schedule,
-                         send_events=False)
-        self.assertEqual(None, l.pool)
-        l.reset_connection()
-
-
-class test_WorkController(AppCase):
-
-    def setup(self):
-        self.worker = self.create_worker()
-        from celery import worker
-        self._logger = worker.logger
-        self.logger = worker.logger = Mock()
-
-    def teardown(self):
-        from celery import worker
-        worker.logger = self._logger
-
-    def create_worker(self, **kw):
-        worker = WorkController(concurrency=1, loglevel=0, **kw)
-        worker._shutdown_complete.set()
-        return worker
-
-    @patch("celery.platforms.signals")
-    @patch("celery.platforms.set_mp_process_title")
-    def test_process_initializer(self, set_mp_process_title, _signals):
-        from celery import Celery
-        from celery import signals
-        from celery.app.state import _tls
-        from celery.concurrency.processes import process_initializer
-        from celery.concurrency.processes import (WORKER_SIGRESET,
-                                                  WORKER_SIGIGNORE)
-
-        def on_worker_process_init(**kwargs):
-            on_worker_process_init.called = True
-        on_worker_process_init.called = False
-        signals.worker_process_init.connect(on_worker_process_init)
-
-        loader = Mock()
-        app = Celery(loader=loader, set_as_current=False)
-        app.conf = AttributeDict(DEFAULTS)
-        process_initializer(app, "awesome.worker.com")
-        self.assertIn((tuple(WORKER_SIGIGNORE), {}),
-                      _signals.ignore.call_args_list)
-        self.assertIn((tuple(WORKER_SIGRESET), {}),
-                      _signals.reset.call_args_list)
-        self.assertTrue(app.loader.init_worker.call_count)
-        self.assertTrue(on_worker_process_init.called)
-        self.assertIs(_tls.current_app, app)
-        set_mp_process_title.assert_called_with("celeryd",
-                        hostname="awesome.worker.com")
-
-    def test_with_rate_limits_disabled(self):
-        worker = WorkController(concurrency=1, loglevel=0,
-                                disable_rate_limits=True)
-        self.assertTrue(hasattr(worker.ready_queue, "put"))
-
-    def test_attrs(self):
-        worker = self.worker
-        self.assertIsInstance(worker.scheduler, Timer)
-        self.assertTrue(worker.scheduler)
-        self.assertTrue(worker.pool)
-        self.assertTrue(worker.consumer)
-        self.assertTrue(worker.mediator)
-        self.assertTrue(worker.components)
-
-    def test_with_embedded_celerybeat(self):
-        worker = WorkController(concurrency=1, loglevel=0,
-                                embed_clockservice=True)
-        self.assertTrue(worker.beat)
-        self.assertIn(worker.beat, worker.components)
-
-    def test_with_autoscaler(self):
-        worker = self.create_worker(autoscale=[10, 3], send_events=False,
-                                eta_scheduler_cls="celery.utils.timer2.Timer")
-        self.assertTrue(worker.autoscaler)
-
-    def test_dont_stop_or_terminate(self):
-        worker = WorkController(concurrency=1, loglevel=0)
-        worker.stop()
-        self.assertNotEqual(worker._state, worker.CLOSE)
-        worker.terminate()
-        self.assertNotEqual(worker._state, worker.CLOSE)
-
-        sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
-        try:
-            worker._state = worker.RUN
-            worker.stop(in_sighandler=True)
-            self.assertNotEqual(worker._state, worker.CLOSE)
-            worker.terminate(in_sighandler=True)
-            self.assertNotEqual(worker._state, worker.CLOSE)
-        finally:
-            worker.pool.signal_safe = sigsafe
-
-    def test_on_timer_error(self):
-        worker = WorkController(concurrency=1, loglevel=0)
-
-        try:
-            raise KeyError("foo")
-        except KeyError:
-            exc_info = sys.exc_info()
-
-        worker.on_timer_error(exc_info)
-        msg, args = self.logger.error.call_args[0]
-        self.assertIn("KeyError", msg % args)
-
-    def test_on_timer_tick(self):
-        worker = WorkController(concurrency=1, loglevel=10)
-
-        worker.on_timer_tick(30.0)
-        xargs = self.logger.debug.call_args[0]
-        fmt, arg = xargs[0], xargs[1]
-        self.assertEqual(30.0, arg)
-        self.assertIn("Next eta %s secs", fmt)
-
-    def test_process_task(self):
-        worker = self.worker
-        worker.pool = Mock()
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = Request.from_message(m, m.decode())
-        worker.process_task(task)
-        self.assertEqual(worker.pool.apply_async.call_count, 1)
-        worker.pool.stop()
-
-    def test_process_task_raise_base(self):
-        worker = self.worker
-        worker.pool = Mock()
-        worker.pool.apply_async.side_effect = KeyboardInterrupt("Ctrl+C")
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = Request.from_message(m, m.decode())
-        worker.components = []
-        worker._state = worker.RUN
-        with self.assertRaises(KeyboardInterrupt):
-            worker.process_task(task)
-        self.assertEqual(worker._state, worker.TERMINATE)
-
-    def test_process_task_raise_SystemTerminate(self):
-        worker = self.worker
-        worker.pool = Mock()
-        worker.pool.apply_async.side_effect = SystemTerminate()
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = Request.from_message(m, m.decode())
-        worker.components = []
-        worker._state = worker.RUN
-        with self.assertRaises(SystemExit):
-            worker.process_task(task)
-        self.assertEqual(worker._state, worker.TERMINATE)
-
-    def test_process_task_raise_regular(self):
-        worker = self.worker
-        worker.pool = Mock()
-        worker.pool.apply_async.side_effect = KeyError("some exception")
-        backend = Mock()
-        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
-                           kwargs={})
-        task = Request.from_message(m, m.decode())
-        worker.process_task(task)
-        worker.pool.stop()
-
-    def test_start_catches_base_exceptions(self):
-        worker1 = self.create_worker()
-        stc = Mock()
-        stc.start.side_effect = SystemTerminate()
-        worker1.components = [stc]
-        worker1.start()
-        self.assertTrue(stc.terminate.call_count)
-
-        worker2 = self.create_worker()
-        sec = Mock()
-        sec.start.side_effect = SystemExit()
-        sec.terminate = None
-        worker2.components = [sec]
-        worker2.start()
-        self.assertTrue(sec.stop.call_count)
-
-    def test_state_db(self):
-        from celery.worker import state
-        Persistent = state.Persistent
-
-        state.Persistent = Mock()
-        try:
-            worker = self.create_worker(state_db="statefilename")
-            self.assertTrue(worker._persistence)
-        finally:
-            state.Persistent = Persistent
-
-    def test_disable_rate_limits_solo(self):
-        worker = self.create_worker(disable_rate_limits=True,
-                                    pool_cls="solo")
-        self.assertIsInstance(worker.ready_queue, FastQueue)
-        self.assertIsNone(worker.mediator)
-        self.assertEqual(worker.ready_queue.put, worker.process_task)
-
-    def test_disable_rate_limits_processes(self):
-        try:
-            worker = self.create_worker(disable_rate_limits=True,
-                                        pool_cls="processes")
-        except ImportError:
-            raise SkipTest("multiprocessing not supported")
-        self.assertIsInstance(worker.ready_queue, FastQueue)
-        self.assertTrue(worker.mediator)
-        self.assertNotEqual(worker.ready_queue.put, worker.process_task)
-
-    def test_start__stop(self):
-        worker = self.worker
-        worker._shutdown_complete.set()
-        worker.components = [Mock(), Mock(), Mock(), Mock()]
-
-        worker.start()
-        for w in worker.components:
-            self.assertTrue(w.start.call_count)
-        worker.stop()
-        for component in worker.components:
-            self.assertTrue(w.stop.call_count)
-
-    def test_start__terminate(self):
-        worker = self.worker
-        worker._shutdown_complete.set()
-        worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
-        for component in worker.components[:3]:
-            component.terminate = None
-
-        worker.start()
-        for w in worker.components[:3]:
-            self.assertTrue(w.start.call_count)
-        self.assertTrue(worker._running, len(worker.components))
-        self.assertEqual(worker._state, RUN)
-        worker.terminate()
-        for component in worker.components[:3]:
-            self.assertTrue(component.stop.call_count)
-        self.assertTrue(worker.components[4].terminate.call_count)

+ 10 - 0
celery/tests/worker/test_autoscale.py

@@ -125,6 +125,16 @@ class test_Autoscaler(Case):
         self.assertEqual(x.processes, 2)
         x.update(3, None)
         self.assertEqual(x.processes, 3)
+        x.force_scale_down(1000)
+        self.assertEqual(x.min_concurrency, 0)
+        self.assertEqual(x.processes, 0)
+        x.force_scale_up(1000)
+        x.min_concurrency = 1
+        x.force_scale_down(1)
+
+        x.update(max=300, min=10)
+        x.update(max=300, min=2)
+        x.update(max=None, min=None)
 
     def test_info(self):
         x = autoscale.Autoscaler(self.pool, 10, 3)

+ 31 - 0
celery/tests/worker/test_control.py

@@ -147,9 +147,16 @@ class test_ControlPanel(Case):
             self.assertDictContainsSubset({"total": 100,
                                            "consumer": {"xyz": "XYZ"}},
                                           self.panel.handle("stats"))
+            self.panel.state.consumer = Mock()
+            self.panel.handle("stats")
+            self.assertTrue(
+                self.panel.state.consumer.controller.autoscaler.info.called)
         finally:
             state.total_count = prev_count
 
+    def test_report(self):
+        self.panel.handle("report")
+
     def test_active(self):
         from celery.worker.job import TaskRequest
 
@@ -182,6 +189,14 @@ class test_ControlPanel(Case):
         panel.handle("pool_shrink")
         self.assertEqual(consumer.pool.size, 1)
 
+        panel.state.consumer = Mock()
+        panel.state.consumer.controller = Mock()
+        sc = panel.state.consumer.controller.autoscaler = Mock()
+        panel.handle("pool_grow")
+        self.assertTrue(sc.force_scale_up.called)
+        panel.handle("pool_shrink")
+        self.assertTrue(sc.force_scale_down.called)
+
     def test_add__cancel_consumer(self):
 
         class MockConsumer(object):
@@ -208,6 +223,7 @@ class test_ControlPanel(Case):
         panel.handle("add_consumer", {"queue": "MyQueue"})
         self.assertIn("MyQueue", consumer.task_consumer.queues)
         self.assertTrue(consumer.task_consumer.consuming)
+        panel.handle("add_consumer", {"queue": "MyQueue"})
         panel.handle("cancel_consumer", {"queue": "MyQueue"})
         self.assertIn("MyQueue", consumer.task_consumer.cancelled)
 
@@ -346,6 +362,21 @@ class test_ControlPanel(Case):
         finally:
             state.active_requests.discard(request)
 
+    def test_autoscale(self):
+        self.panel.state.consumer = Mock()
+        self.panel.state.consumer.controller = Mock()
+        sc = self.panel.state.consumer.controller.autoscaler = Mock()
+        sc.update.return_value = 10, 2
+        m = {"method": "autoscale",
+             "destination": hostname,
+             "arguments": {"max": "10", "min": "2"}}
+        r = self.panel.dispatch_from_message(m)
+        self.assertIn("ok", r)
+
+        self.panel.state.consumer.controller.autoscaler = None
+        r = self.panel.dispatch_from_message(m)
+        self.assertIn("error", r)
+
     def test_ping(self):
         m = {"method": "ping",
              "destination": hostname}

+ 5 - 0
celery/tests/worker/test_mediator.py

@@ -56,6 +56,11 @@ class test_Mediator(Case):
 
         self.assertEqual(got["value"], "George Costanza")
 
+        ready_queue.put(MockTask("Jerry Seinfeld"))
+        m._does_debug = False
+        m.body()
+        self.assertEqual(got["value"], "Jerry Seinfeld")
+
     @patch("os._exit")
     def test_mediator_crash(self, _exit):
         ms = [None]

+ 67 - 5
celery/tests/worker/test_request.py

@@ -26,6 +26,7 @@ from celery.result import AsyncResult
 from celery.task import task as task_dec
 from celery.task.base import Task
 from celery.utils import uuid
+from celery.worker import job as module
 from celery.worker.job import Request, TaskRequest, execute_and_trace
 from celery.worker.state import revoked
 
@@ -74,7 +75,7 @@ def on_ack(*args, **kwargs):
     scratch["ACK"] = True
 
 
-@task_dec(accept_magic_kwargs=True)
+@task_dec(accept_magic_kwargs=False)
 def mytask(i, **kwargs):
     return i ** i
 
@@ -97,8 +98,8 @@ def mytask_some_kwargs(i, logfile):
     return i ** i
 
 
-@task_dec(accept_magic_kwargs=True)
-def mytask_raising(i, **kwargs):
+@task_dec(accept_magic_kwargs=False)
+def mytask_raising(i):
     raise KeyError(i)
 
 
@@ -233,6 +234,16 @@ class test_TaskRequest(Case):
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
         self.assertTrue(repr(tw))
 
+    @patch("celery.worker.job.kwdict")
+    def test_kwdict(self, kwdict):
+
+        prev, module.NEEDS_KWDICT = module.NEEDS_KWDICT, True
+        try:
+            TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
+            self.assertTrue(kwdict.called)
+        finally:
+            module.NEEDS_KWDICT = prev
+
     def test_sets_store_errors(self):
         mytask.ignore_result = True
         try:
@@ -260,6 +271,19 @@ class test_TaskRequest(Case):
             einfo = ExceptionInfo(sys.exc_info())
             tw.on_failure(einfo)
             self.assertIn("task-retried", tw.eventer.sent)
+            tw._does_info = False
+            tw.on_failure(einfo)
+            einfo.internal = True
+            tw.on_failure(einfo)
+
+    def test_compat_properties(self):
+        tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
+        self.assertEqual(tw.task_id, tw.id)
+        self.assertEqual(tw.task_name, tw.name)
+        tw.task_id = "ID"
+        self.assertEqual(tw.id, "ID")
+        tw.task_name = "NAME"
+        self.assertEqual(tw.name, "NAME")
 
     def test_terminate__task_started(self):
         pool = Mock()
@@ -375,10 +399,12 @@ class test_TaskRequest(Case):
 
     def test_execute_acks_late(self):
         mytask_raising.acks_late = True
-        tw = TaskRequest(mytask_raising.name, uuid(), [1], {"f": "x"})
+        tw = TaskRequest(mytask_raising.name, uuid(), [1])
         try:
             tw.execute()
             self.assertTrue(tw.acknowledged)
+            tw.task.accept_magic_kwargs = False
+            tw.execute()
         finally:
             mytask_raising.acks_late = False
 
@@ -391,6 +417,8 @@ class test_TaskRequest(Case):
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
         tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
         self.assertTrue(tw.acknowledged)
+        tw._does_debug = False
+        tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
 
     def test_on_accepted_acks_late(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
@@ -413,8 +441,39 @@ class test_TaskRequest(Case):
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
         tw.time_start = 1
         tw.on_success(42)
+        tw._does_info = False
+        tw.on_success(42)
         self.assertFalse(tw.acknowledged)
 
+    def test_on_success_BaseException(self):
+        tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
+        tw.time_start = 1
+        with self.assertRaises(SystemExit):
+            try:
+                raise SystemExit()
+            except SystemExit:
+                tw.on_success(ExceptionInfo(sys.exc_info()))
+            else:
+                assert False
+
+    def test_on_success_eventer(self):
+        tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
+        tw.time_start = 1
+        tw.eventer = Mock()
+        tw.send_event = Mock()
+        tw.on_success(42)
+        self.assertTrue(tw.send_event.called)
+
+    def test_on_success_when_failure(self):
+        tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
+        tw.time_start = 1
+        tw.on_failure = Mock()
+        try:
+            raise KeyError("foo")
+        except Exception:
+            tw.on_success(ExceptionInfo(sys.exc_info()))
+            self.assertTrue(tw.on_failure.called)
+
     def test_on_success_acks_late(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
         tw.time_start = 1
@@ -638,7 +697,7 @@ class test_TaskRequest(Case):
 
     def test_execute_fail(self):
         tid = uuid()
-        tw = TaskRequest(mytask_raising.name, tid, [4], {"f": "x"})
+        tw = TaskRequest(mytask_raising.name, tid, [4])
         self.assertIsInstance(tw.execute(), ExceptionInfo)
         meta = mytask_raising.backend.get_task_meta(tid)
         self.assertEqual(meta["status"], states.FAILURE)
@@ -671,6 +730,9 @@ class test_TaskRequest(Case):
         self.assertIn("f", p.args[3])
         self.assertIn([4], p.args)
 
+        tw.task.accept_magic_kwargs = False
+        tw.execute_using_pool(p)
+
     def test_default_kwargs(self):
         tid = uuid()
         tw = TaskRequest(mytask.name, tid, [4], {"f": "x"})

+ 4 - 0
celery/tests/worker/test_state.py

@@ -52,6 +52,10 @@ class test_Persistent(StateResetCase):
     def on_setup(self):
         self.p = MyPersistent(filename="celery-state")
 
+    def test_close_twice(self):
+        self.p._is_open = False
+        self.p.close()
+
     def test_constructor(self):
         self.assertDictEqual(self.p.db, {})
         self.assertEqual(self.p.db.filename, self.p.filename)

+ 1026 - 0
celery/tests/worker/test_worker.py

@@ -0,0 +1,1026 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+import socket
+import sys
+
+from collections import deque
+from datetime import datetime, timedelta
+from Queue import Empty
+
+from kombu.exceptions import StdChannelError
+from kombu.transport.base import Message
+from kombu.connection import BrokerConnection
+from mock import Mock, patch
+from nose import SkipTest
+
+from celery import current_app
+from celery.app.defaults import DEFAULTS
+from celery.concurrency.base import BasePool
+from celery.datastructures import AttributeDict
+from celery.exceptions import SystemTerminate
+from celery.task import task as task_dec
+from celery.task import periodic_task as periodic_task_dec
+from celery.utils import uuid
+from celery.worker import WorkController, Queues
+from celery.worker.buckets import FastQueue
+from celery.worker.job import Request
+from celery.worker.consumer import Consumer as MainConsumer
+from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
+from celery.utils.serialization import pickle
+from celery.utils.timer2 import Timer
+from celery.utils.threads import Event
+
+from celery.tests.utils import AppCase, Case
+
+
+class PlaceHolder(object):
+        pass
+
+
+class MyKombuConsumer(MainConsumer):
+    broadcast_consumer = Mock()
+    task_consumer = Mock()
+
+    def __init__(self, *args, **kwargs):
+        kwargs.setdefault("pool", BasePool(2))
+        super(MyKombuConsumer, self).__init__(*args, **kwargs)
+
+    def restart_heartbeat(self):
+        self.heart = None
+
+
+class MockNode(object):
+    commands = []
+
+    def handle_message(self, body, message):
+        self.commands.append(body.pop("command", None))
+
+
+class MockEventDispatcher(object):
+    sent = []
+    closed = False
+    flushed = False
+    _outbound_buffer = []
+
+    def send(self, event, *args, **kwargs):
+        self.sent.append(event)
+
+    def close(self):
+        self.closed = True
+
+    def flush(self):
+        self.flushed = True
+
+
+class MockHeart(object):
+    closed = False
+
+    def stop(self):
+        self.closed = True
+
+
+@task_dec()
+def foo_task(x, y, z, **kwargs):
+    return x * y * z
+
+
+@periodic_task_dec(run_every=60)
+def foo_periodic_task():
+    return "foo"
+
+
+def create_message(channel, **data):
+    data.setdefault("id", uuid())
+    channel.no_ack_consumers = set()
+    return Message(channel, body=pickle.dumps(dict(**data)),
+                   content_type="application/x-python-serialize",
+                   content_encoding="binary",
+                   delivery_info={"consumer_tag": "mock"})
+
+
+class test_QoS(Case):
+
+    class _QoS(QoS):
+        def __init__(self, value):
+            self.value = value
+            QoS.__init__(self, None, value)
+
+        def set(self, value):
+            return value
+
+    def test_qos_increment_decrement(self):
+        qos = self._QoS(10)
+        self.assertEqual(qos.increment(), 11)
+        self.assertEqual(qos.increment(3), 14)
+        self.assertEqual(qos.increment(-30), 14)
+        self.assertEqual(qos.decrement(7), 7)
+        self.assertEqual(qos.decrement(), 6)
+        with self.assertRaises(AssertionError):
+            qos.decrement(10)
+
+    def test_qos_disabled_increment_decrement(self):
+        qos = self._QoS(0)
+        self.assertEqual(qos.increment(), 0)
+        self.assertEqual(qos.increment(3), 0)
+        self.assertEqual(qos.increment(-30), 0)
+        self.assertEqual(qos.decrement(7), 0)
+        self.assertEqual(qos.decrement(), 0)
+        self.assertEqual(qos.decrement(10), 0)
+
+    def test_qos_thread_safe(self):
+        qos = self._QoS(10)
+
+        def add():
+            for i in xrange(1000):
+                qos.increment()
+
+        def sub():
+            for i in xrange(1000):
+                qos.decrement_eventually()
+
+        def threaded(funs):
+            from threading import Thread
+            threads = [Thread(target=fun) for fun in funs]
+            for thread in threads:
+                thread.start()
+            for thread in threads:
+                thread.join()
+
+        threaded([add, add])
+        self.assertEqual(qos.value, 2010)
+
+        qos.value = 1000
+        threaded([add, sub])  # n = 2
+        self.assertEqual(qos.value, 1000)
+
+    def test_exceeds_short(self):
+        qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
+        qos.update()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
+        qos.increment()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
+        qos.increment()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
+        qos.decrement()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
+        qos.decrement()
+        self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
+
+    def test_consumer_increment_decrement(self):
+        consumer = Mock()
+        qos = QoS(consumer, 10)
+        qos.update()
+        self.assertEqual(qos.value, 10)
+        self.assertIn({"prefetch_count": 10}, consumer.qos.call_args)
+        qos.decrement()
+        self.assertEqual(qos.value, 9)
+        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
+        qos.decrement_eventually()
+        self.assertEqual(qos.value, 8)
+        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
+
+        # Does not decrement 0 value
+        qos.value = 0
+        qos.decrement()
+        self.assertEqual(qos.value, 0)
+        qos.increment()
+        self.assertEqual(qos.value, 0)
+
+    def test_consumer_decrement_eventually(self):
+        consumer = Mock()
+        qos = QoS(consumer, 10)
+        qos.decrement_eventually()
+        self.assertEqual(qos.value, 9)
+        qos.value = 0
+        qos.decrement_eventually()
+        self.assertEqual(qos.value, 0)
+
+    def test_set(self):
+        consumer = Mock()
+        qos = QoS(consumer, 10)
+        qos.set(12)
+        self.assertEqual(qos.prev, 12)
+        qos.set(qos.prev)
+
+
+class test_Consumer(Case):
+
+    def setUp(self):
+        self.ready_queue = FastQueue()
+        self.eta_schedule = Timer()
+
+    def tearDown(self):
+        self.eta_schedule.stop()
+
+    def test_info(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                           send_events=False)
+        l.qos = QoS(l.task_consumer, 10)
+        info = l.info
+        self.assertEqual(info["prefetch_count"], 10)
+        self.assertFalse(info["broker"])
+
+        l.connection = current_app.broker_connection()
+        info = l.info
+        self.assertTrue(info["broker"])
+
+    def test_start_when_closed(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                            send_events=False)
+        l._state = CLOSE
+        l.start()
+
+    def test_connection(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                           send_events=False)
+
+        l.reset_connection()
+        self.assertIsInstance(l.connection, BrokerConnection)
+
+        l._state = RUN
+        l.event_dispatcher = None
+        l.stop_consumers(close_connection=False)
+        self.assertTrue(l.connection)
+
+        l._state = RUN
+        l.stop_consumers()
+        self.assertIsNone(l.connection)
+        self.assertIsNone(l.task_consumer)
+
+        l.reset_connection()
+        self.assertIsInstance(l.connection, BrokerConnection)
+        l.stop_consumers()
+
+        l.stop()
+        l.close_connection()
+        self.assertIsNone(l.connection)
+        self.assertIsNone(l.task_consumer)
+
+    def test_close_connection(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                           send_events=False)
+        l._state = RUN
+        l.close_connection()
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                           send_events=False)
+        eventer = l.event_dispatcher = Mock()
+        eventer.enabled = True
+        heart = l.heart = MockHeart()
+        l._state = RUN
+        l.stop_consumers()
+        self.assertTrue(eventer.close.call_count)
+        self.assertTrue(heart.closed)
+
+    @patch("celery.worker.consumer.warn")
+    def test_receive_message_unknown(self, warn):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                           send_events=False)
+        backend = Mock()
+        m = create_message(backend, unknown={"baz": "!!!"})
+        l.event_dispatcher = Mock()
+        l.pidbox_node = MockNode()
+
+        l.receive_message(m.decode(), m)
+        self.assertTrue(warn.call_count)
+
+    @patch("celery.utils.timer2.to_timestamp")
+    def test_receive_message_eta_OverflowError(self, to_timestamp):
+        to_timestamp.side_effect = OverflowError()
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                             send_events=False)
+        m = create_message(Mock(), task=foo_task.name,
+                                   args=("2, 2"),
+                                   kwargs={},
+                                   eta=datetime.now().isoformat())
+        l.event_dispatcher = Mock()
+        l.pidbox_node = MockNode()
+        l.update_strategies()
+
+        l.receive_message(m.decode(), m)
+        self.assertTrue(m.acknowledged)
+        self.assertTrue(to_timestamp.call_count)
+
+    @patch("celery.worker.consumer.error")
+    def test_receive_message_InvalidTaskError(self, error):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                           send_events=False)
+        m = create_message(Mock(), task=foo_task.name,
+                           args=(1, 2), kwargs="foobarbaz", id=1)
+        l.update_strategies()
+        l.event_dispatcher = Mock()
+        l.pidbox_node = MockNode()
+
+        l.receive_message(m.decode(), m)
+        self.assertIn("Received invalid task message", error.call_args[0][0])
+
+    @patch("celery.worker.consumer.crit")
+    def test_on_decode_error(self, crit):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                           send_events=False)
+
+        class MockMessage(Mock):
+            content_type = "application/x-msgpack"
+            content_encoding = "binary"
+            body = "foobarbaz"
+
+        message = MockMessage()
+        l.on_decode_error(message, KeyError("foo"))
+        self.assertTrue(message.ack.call_count)
+        self.assertIn("Can't decode message body", crit.call_args[0][0])
+
+    def test_receieve_message(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                           send_events=False)
+        m = create_message(Mock(), task=foo_task.name,
+                           args=[2, 4, 8], kwargs={})
+        l.update_strategies()
+
+        l.event_dispatcher = Mock()
+        l.receive_message(m.decode(), m)
+
+        in_bucket = self.ready_queue.get_nowait()
+        self.assertIsInstance(in_bucket, Request)
+        self.assertEqual(in_bucket.name, foo_task.name)
+        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
+        self.assertTrue(self.eta_schedule.empty())
+
+    def test_start_connection_error(self):
+
+        class MockConsumer(MainConsumer):
+            iterations = 0
+
+            def consume_messages(self):
+                if not self.iterations:
+                    self.iterations = 1
+                    raise KeyError("foo")
+                raise SyntaxError("bar")
+
+        l = MockConsumer(self.ready_queue, self.eta_schedule,
+                             send_events=False, pool=BasePool())
+        l.connection_errors = (KeyError, )
+        with self.assertRaises(SyntaxError):
+            l.start()
+        l.heart.stop()
+        l.priority_timer.stop()
+
+    def test_start_channel_error(self):
+        # Regression test for AMQPChannelExceptions that can occur within the
+        # consumer. (i.e. 404 errors)
+
+        class MockConsumer(MainConsumer):
+            iterations = 0
+
+            def consume_messages(self):
+                if not self.iterations:
+                    self.iterations = 1
+                    raise KeyError("foo")
+                raise SyntaxError("bar")
+
+        l = MockConsumer(self.ready_queue, self.eta_schedule,
+                             send_events=False, pool=BasePool())
+
+        l.channel_errors = (KeyError, )
+        self.assertRaises(SyntaxError, l.start)
+        l.heart.stop()
+        l.priority_timer.stop()
+
+    def test_consume_messages_ignores_socket_timeout(self):
+
+        class Connection(current_app.broker_connection().__class__):
+            obj = None
+
+            def drain_events(self, **kwargs):
+                self.obj.connection = None
+                raise socket.timeout(10)
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                            send_events=False)
+        l.connection = Connection()
+        l.task_consumer = Mock()
+        l.connection.obj = l
+        l.qos = QoS(l.task_consumer, 10)
+        l.consume_messages()
+
+    def test_consume_messages_when_socket_error(self):
+
+        class Connection(current_app.broker_connection().__class__):
+            obj = None
+
+            def drain_events(self, **kwargs):
+                self.obj.connection = None
+                raise socket.error("foo")
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                            send_events=False)
+        l._state = RUN
+        c = l.connection = Connection()
+        l.connection.obj = l
+        l.task_consumer = Mock()
+        l.qos = QoS(l.task_consumer, 10)
+        with self.assertRaises(socket.error):
+            l.consume_messages()
+
+        l._state = CLOSE
+        l.connection = c
+        l.consume_messages()
+
+    def test_consume_messages(self):
+
+        class Connection(current_app.broker_connection().__class__):
+            obj = None
+
+            def drain_events(self, **kwargs):
+                self.obj.connection = None
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                             send_events=False)
+        l.connection = Connection()
+        l.connection.obj = l
+        l.task_consumer = Mock()
+        l.qos = QoS(l.task_consumer, 10)
+
+        l.consume_messages()
+        l.consume_messages()
+        self.assertTrue(l.task_consumer.consume.call_count)
+        l.task_consumer.qos.assert_called_with(prefetch_count=10)
+        l.qos.decrement()
+        l.consume_messages()
+        l.task_consumer.qos.assert_called_with(prefetch_count=9)
+
+    def test_maybe_conn_error(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                             send_events=False)
+        l.connection_errors = (KeyError, )
+        l.channel_errors = (SyntaxError, )
+        l.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
+        l.maybe_conn_error(Mock(side_effect=KeyError("foo")))
+        l.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
+        with self.assertRaises(IndexError):
+            l.maybe_conn_error(Mock(side_effect=IndexError("foo")))
+
+    def test_apply_eta_task(self):
+        from celery.worker import state
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                             send_events=False)
+        l.qos = QoS(None, 10)
+
+        task = object()
+        qos = l.qos.value
+        l.apply_eta_task(task)
+        self.assertIn(task, state.reserved_requests)
+        self.assertEqual(l.qos.value, qos - 1)
+        self.assertIs(self.ready_queue.get_nowait(), task)
+
+    def test_receieve_message_eta_isoformat(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                             send_events=False)
+        m = create_message(Mock(), task=foo_task.name,
+                           eta=datetime.now().isoformat(),
+                           args=[2, 4, 8], kwargs={})
+
+        l.task_consumer = Mock()
+        l.qos = QoS(l.task_consumer, l.initial_prefetch_count)
+        l.event_dispatcher = Mock()
+        l.enabled = False
+        l.update_strategies()
+        l.receive_message(m.decode(), m)
+        l.eta_schedule.stop()
+
+        items = [entry[2] for entry in self.eta_schedule.queue]
+        found = 0
+        for item in items:
+            if item.args[0].name == foo_task.name:
+                found = True
+        self.assertTrue(found)
+        self.assertTrue(l.task_consumer.qos.call_count)
+        l.eta_schedule.stop()
+
+    def test_on_control(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                             send_events=False)
+        l.pidbox_node = Mock()
+        l.reset_pidbox_node = Mock()
+
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+
+        l.pidbox_node = Mock()
+        l.pidbox_node.handle_message.side_effect = KeyError("foo")
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+
+        l.pidbox_node = Mock()
+        l.pidbox_node.handle_message.side_effect = ValueError("foo")
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+        l.reset_pidbox_node.assert_called_with()
+
+    def test_revoke(self):
+        ready_queue = FastQueue()
+        l = MyKombuConsumer(ready_queue, self.eta_schedule,
+                           send_events=False)
+        backend = Mock()
+        id = uuid()
+        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
+                           kwargs={}, id=id)
+        from celery.worker.state import revoked
+        revoked.add(id)
+
+        l.receive_message(t.decode(), t)
+        self.assertTrue(ready_queue.empty())
+
+    def test_receieve_message_not_registered(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                          send_events=False)
+        backend = Mock()
+        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
+
+        l.event_dispatcher = Mock()
+        self.assertFalse(l.receive_message(m.decode(), m))
+        with self.assertRaises(Empty):
+            self.ready_queue.get_nowait()
+        self.assertTrue(self.eta_schedule.empty())
+
+    @patch("celery.worker.consumer.warn")
+    @patch("celery.worker.consumer.logger")
+    def test_receieve_message_ack_raises(self, logger, warn):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                          send_events=False)
+        backend = Mock()
+        m = create_message(backend, args=[2, 4, 8], kwargs={})
+
+        l.event_dispatcher = Mock()
+        l.connection_errors = (socket.error, )
+        m.reject = Mock()
+        m.reject.side_effect = socket.error("foo")
+        self.assertFalse(l.receive_message(m.decode(), m))
+        self.assertTrue(warn.call_count)
+        with self.assertRaises(Empty):
+            self.ready_queue.get_nowait()
+        self.assertTrue(self.eta_schedule.empty())
+        m.reject.assert_called_with()
+        self.assertTrue(logger.critical.call_count)
+
+    def test_receieve_message_eta(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                            send_events=False)
+        l.event_dispatcher = Mock()
+        l.event_dispatcher._outbound_buffer = deque()
+        backend = Mock()
+        m = create_message(backend, task=foo_task.name,
+                           args=[2, 4, 8], kwargs={},
+                           eta=(datetime.now() +
+                               timedelta(days=1)).isoformat())
+
+        l.reset_connection()
+        p = l.app.conf.BROKER_CONNECTION_RETRY
+        l.app.conf.BROKER_CONNECTION_RETRY = False
+        try:
+            l.reset_connection()
+        finally:
+            l.app.conf.BROKER_CONNECTION_RETRY = p
+        l.stop_consumers()
+        l.event_dispatcher = Mock()
+        l.receive_message(m.decode(), m)
+        l.eta_schedule.stop()
+        in_hold = self.eta_schedule.queue[0]
+        self.assertEqual(len(in_hold), 3)
+        eta, priority, entry = in_hold
+        task = entry.args[0]
+        self.assertIsInstance(task, Request)
+        self.assertEqual(task.name, foo_task.name)
+        self.assertEqual(task.execute(), 2 * 4 * 8)
+        with self.assertRaises(Empty):
+            self.ready_queue.get_nowait()
+
+    def test_reset_pidbox_node(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                          send_events=False)
+        l.pidbox_node = Mock()
+        chan = l.pidbox_node.channel = Mock()
+        l.connection = Mock()
+        chan.close.side_effect = socket.error("foo")
+        l.connection_errors = (socket.error, )
+        l.reset_pidbox_node()
+        chan.close.assert_called_with()
+
+    def test_reset_pidbox_node_green(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                          send_events=False)
+        l.pool = Mock()
+        l.pool.is_green = True
+        l.reset_pidbox_node()
+        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)
+
+    def test__green_pidbox_node(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                          send_events=False)
+        l.pidbox_node = Mock()
+
+        class BConsumer(Mock):
+
+            def __enter__(self):
+                self.consume()
+                return self
+
+            def __exit__(self, *exc_info):
+                self.cancel()
+
+        l.pidbox_node.listen = BConsumer()
+        connections = []
+
+        class Connection(object):
+            calls = 0
+
+            def __init__(self, obj):
+                connections.append(self)
+                self.obj = obj
+                self.default_channel = self.channel()
+                self.closed = False
+
+            def __enter__(self):
+                return self
+
+            def __exit__(self, *exc_info):
+                self.close()
+
+            def channel(self):
+                return Mock()
+
+            def drain_events(self, **kwargs):
+                if not self.calls:
+                    self.calls += 1
+                    raise socket.timeout()
+                self.obj.connection = None
+                self.obj._pidbox_node_shutdown.set()
+
+            def close(self):
+                self.closed = True
+
+        l.connection = Mock()
+        l._open_connection = lambda: Connection(obj=l)
+        l._green_pidbox_node()
+
+        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
+        self.assertTrue(l.broadcast_consumer)
+        l.broadcast_consumer.consume.assert_called_with()
+
+        self.assertIsNone(l.connection)
+        self.assertTrue(connections[0].closed)
+
+    @patch("kombu.connection.BrokerConnection._establish_connection")
+    @patch("kombu.utils.sleep")
+    def test_open_connection_errback(self, sleep, connect):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                      send_events=False)
+        calls = [0]
+        from kombu.transport.memory import Transport
+        Transport.connection_errors = (StdChannelError, )
+
+        def effect():
+            try:
+                if calls[0] > 1:
+                    return
+                raise StdChannelError()
+            finally:
+                calls[0] += 1
+        connect.side_effect = effect
+        l._open_connection()
+        connect.assert_called_with()
+
+    def test_stop_pidbox_node(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
+                      send_events=False)
+        l._pidbox_node_stopped = Event()
+        l._pidbox_node_shutdown = Event()
+        l._pidbox_node_stopped.set()
+        l.stop_pidbox_node()
+
+    def test_start__consume_messages(self):
+
+        class _QoS(object):
+            prev = 3
+            value = 4
+
+            def update(self):
+                self.prev = self.value
+
+        class _Consumer(MyKombuConsumer):
+            iterations = 0
+
+            def reset_connection(self):
+                if self.iterations >= 1:
+                    raise KeyError("foo")
+
+        init_callback = Mock()
+        l = _Consumer(self.ready_queue, self.eta_schedule,
+                      send_events=False, init_callback=init_callback)
+        l.task_consumer = Mock()
+        l.broadcast_consumer = Mock()
+        l.qos = _QoS()
+        l.connection = BrokerConnection()
+        l.iterations = 0
+
+        def raises_KeyError(limit=None):
+            l.iterations += 1
+            if l.qos.prev != l.qos.value:
+                l.qos.update()
+            if l.iterations >= 2:
+                raise KeyError("foo")
+
+        l.consume_messages = raises_KeyError
+        with self.assertRaises(KeyError):
+            l.start()
+        self.assertTrue(init_callback.call_count)
+        self.assertEqual(l.iterations, 1)
+        self.assertEqual(l.qos.prev, l.qos.value)
+
+        init_callback.reset_mock()
+        l = _Consumer(self.ready_queue, self.eta_schedule,
+                      send_events=False, init_callback=init_callback)
+        l.qos = _QoS()
+        l.task_consumer = Mock()
+        l.broadcast_consumer = Mock()
+        l.connection = BrokerConnection()
+        l.consume_messages = Mock(side_effect=socket.error("foo"))
+        with self.assertRaises(socket.error):
+            l.start()
+        self.assertTrue(init_callback.call_count)
+        self.assertTrue(l.consume_messages.call_count)
+
+    def test_reset_connection_with_no_node(self):
+        l = MainConsumer(self.ready_queue, self.eta_schedule,
+                         send_events=False)
+        self.assertEqual(None, l.pool)
+        l.reset_connection()
+
+    def test_on_task_revoked(self):
+        l = MainConsumer(self.ready_queue, self.eta_schedule,
+                         send_events=False)
+        task = Mock()
+        task.revoked.return_value = True
+        l.on_task(task)
+
+    def test_on_task_no_events(self):
+        l = MainConsumer(self.ready_queue, self.eta_schedule,
+                         send_events=False)
+        task = Mock()
+        task.revoked.return_value = False
+        l.event_dispatcher = Mock()
+        l.event_dispatcher.enabled = False
+        task.eta = None
+        l._does_info = False
+        l.on_task(task)
+
+
+class test_WorkController(AppCase):
+
+    def setup(self):
+        self.worker = self.create_worker()
+        from celery import worker
+        self._logger = worker.logger
+        self.logger = worker.logger = Mock()
+
+    def teardown(self):
+        from celery import worker
+        worker.logger = self._logger
+
+    def create_worker(self, **kw):
+        worker = WorkController(concurrency=1, loglevel=0, **kw)
+        worker._shutdown_complete.set()
+        return worker
+
+    @patch("celery.platforms.signals")
+    @patch("celery.platforms.set_mp_process_title")
+    def test_process_initializer(self, set_mp_process_title, _signals):
+        from celery import Celery
+        from celery import signals
+        from celery.app.state import _tls
+        from celery.concurrency.processes import process_initializer
+        from celery.concurrency.processes import (WORKER_SIGRESET,
+                                                  WORKER_SIGIGNORE)
+
+        def on_worker_process_init(**kwargs):
+            on_worker_process_init.called = True
+        on_worker_process_init.called = False
+        signals.worker_process_init.connect(on_worker_process_init)
+
+        loader = Mock()
+        app = Celery(loader=loader, set_as_current=False)
+        app.conf = AttributeDict(DEFAULTS)
+        process_initializer(app, "awesome.worker.com")
+        self.assertIn((tuple(WORKER_SIGIGNORE), {}),
+                      _signals.ignore.call_args_list)
+        self.assertIn((tuple(WORKER_SIGRESET), {}),
+                      _signals.reset.call_args_list)
+        self.assertTrue(app.loader.init_worker.call_count)
+        self.assertTrue(on_worker_process_init.called)
+        self.assertIs(_tls.current_app, app)
+        set_mp_process_title.assert_called_with("celeryd",
+                        hostname="awesome.worker.com")
+
+    def test_with_rate_limits_disabled(self):
+        worker = WorkController(concurrency=1, loglevel=0,
+                                disable_rate_limits=True)
+        self.assertTrue(hasattr(worker.ready_queue, "put"))
+
+    def test_attrs(self):
+        worker = self.worker
+        self.assertIsInstance(worker.scheduler, Timer)
+        self.assertTrue(worker.scheduler)
+        self.assertTrue(worker.pool)
+        self.assertTrue(worker.consumer)
+        self.assertTrue(worker.mediator)
+        self.assertTrue(worker.components)
+
+    def test_with_embedded_celerybeat(self):
+        worker = WorkController(concurrency=1, loglevel=0,
+                                embed_clockservice=True)
+        self.assertTrue(worker.beat)
+        self.assertIn(worker.beat, worker.components)
+
+    def test_with_autoscaler(self):
+        worker = self.create_worker(autoscale=[10, 3], send_events=False,
+                                eta_scheduler_cls="celery.utils.timer2.Timer")
+        self.assertTrue(worker.autoscaler)
+
+    def test_dont_stop_or_terminate(self):
+        worker = WorkController(concurrency=1, loglevel=0)
+        worker.stop()
+        self.assertNotEqual(worker._state, worker.CLOSE)
+        worker.terminate()
+        self.assertNotEqual(worker._state, worker.CLOSE)
+
+        sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
+        try:
+            worker._state = worker.RUN
+            worker.stop(in_sighandler=True)
+            self.assertNotEqual(worker._state, worker.CLOSE)
+            worker.terminate(in_sighandler=True)
+            self.assertNotEqual(worker._state, worker.CLOSE)
+        finally:
+            worker.pool.signal_safe = sigsafe
+
+    def test_on_timer_error(self):
+        worker = WorkController(concurrency=1, loglevel=0)
+
+        try:
+            raise KeyError("foo")
+        except KeyError:
+            exc_info = sys.exc_info()
+
+        worker.on_timer_error(exc_info)
+        msg, args = self.logger.error.call_args[0]
+        self.assertIn("KeyError", msg % args)
+
+    def test_on_timer_tick(self):
+        worker = WorkController(concurrency=1, loglevel=10)
+
+        worker.on_timer_tick(30.0)
+        xargs = self.logger.debug.call_args[0]
+        fmt, arg = xargs[0], xargs[1]
+        self.assertEqual(30.0, arg)
+        self.assertIn("Next eta %s secs", fmt)
+
+    def test_process_task(self):
+        worker = self.worker
+        worker.pool = Mock()
+        backend = Mock()
+        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+                           kwargs={})
+        task = Request.from_message(m, m.decode())
+        worker.process_task(task)
+        self.assertEqual(worker.pool.apply_async.call_count, 1)
+        worker.pool.stop()
+
+    def test_process_task_raise_base(self):
+        worker = self.worker
+        worker.pool = Mock()
+        worker.pool.apply_async.side_effect = KeyboardInterrupt("Ctrl+C")
+        backend = Mock()
+        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+                           kwargs={})
+        task = Request.from_message(m, m.decode())
+        worker.components = []
+        worker._state = worker.RUN
+        with self.assertRaises(KeyboardInterrupt):
+            worker.process_task(task)
+        self.assertEqual(worker._state, worker.TERMINATE)
+
+    def test_process_task_raise_SystemTerminate(self):
+        worker = self.worker
+        worker.pool = Mock()
+        worker.pool.apply_async.side_effect = SystemTerminate()
+        backend = Mock()
+        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+                           kwargs={})
+        task = Request.from_message(m, m.decode())
+        worker.components = []
+        worker._state = worker.RUN
+        with self.assertRaises(SystemExit):
+            worker.process_task(task)
+        self.assertEqual(worker._state, worker.TERMINATE)
+
+    def test_process_task_raise_regular(self):
+        worker = self.worker
+        worker.pool = Mock()
+        worker.pool.apply_async.side_effect = KeyError("some exception")
+        backend = Mock()
+        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+                           kwargs={})
+        task = Request.from_message(m, m.decode())
+        worker.process_task(task)
+        worker.pool.stop()
+
+    def test_start_catches_base_exceptions(self):
+        worker1 = self.create_worker()
+        stc = Mock()
+        stc.start.side_effect = SystemTerminate()
+        worker1.components = [stc]
+        worker1.start()
+        self.assertTrue(stc.terminate.call_count)
+
+        worker2 = self.create_worker()
+        sec = Mock()
+        sec.start.side_effect = SystemExit()
+        sec.terminate = None
+        worker2.components = [sec]
+        worker2.start()
+        self.assertTrue(sec.stop.call_count)
+
+    def test_state_db(self):
+        from celery.worker import state
+        Persistent = state.Persistent
+
+        state.Persistent = Mock()
+        try:
+            worker = self.create_worker(state_db="statefilename")
+            self.assertTrue(worker._persistence)
+        finally:
+            state.Persistent = Persistent
+
+    def test_disable_rate_limits_solo(self):
+        worker = self.create_worker(disable_rate_limits=True,
+                                    pool_cls="solo")
+        self.assertIsInstance(worker.ready_queue, FastQueue)
+        self.assertIsNone(worker.mediator)
+        self.assertEqual(worker.ready_queue.put, worker.process_task)
+
+    def test_disable_rate_limits_processes(self):
+        try:
+            worker = self.create_worker(disable_rate_limits=True,
+                                        pool_cls="processes")
+        except ImportError:
+            raise SkipTest("multiprocessing not supported")
+        self.assertIsInstance(worker.ready_queue, FastQueue)
+        self.assertTrue(worker.mediator)
+        self.assertNotEqual(worker.ready_queue.put, worker.process_task)
+
+    def test_start__stop(self):
+        worker = self.worker
+        worker._shutdown_complete.set()
+        worker.components = [Mock(), Mock(), Mock(), Mock()]
+
+        worker.start()
+        for w in worker.components:
+            self.assertTrue(w.start.call_count)
+        worker.stop()
+        for component in worker.components:
+            self.assertTrue(w.stop.call_count)
+
+    def test_component_raises(self):
+        worker = self.worker
+        comp = Mock()
+        worker.components = [comp]
+        comp.start.side_effect = TypeError()
+        worker.stop = Mock()
+        worker.start()
+        worker.stop.assert_called_with()
+
+    def test_state(self):
+        self.assertTrue(self.worker.state)
+
+    def test_start__terminate(self):
+        worker = self.worker
+        worker._shutdown_complete.set()
+        worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
+        for component in worker.components[:3]:
+            component.terminate = None
+
+        worker.start()
+        for w in worker.components[:3]:
+            self.assertTrue(w.start.call_count)
+        self.assertTrue(worker._running, len(worker.components))
+        self.assertEqual(worker._state, RUN)
+        worker.terminate()
+        for component in worker.components[:3]:
+            self.assertTrue(component.stop.call_count)
+        self.assertTrue(worker.components[4].terminate.call_count)
+
+    def test_Queues_pool_not_rlimit_safe(self):
+        w = Mock()
+        w.pool_cls.rlimit_safe = False
+        Queues(w).create(w)
+        self.assertTrue(w.disable_rate_limits)

+ 2 - 2
celery/worker/autoreload.py

@@ -24,7 +24,7 @@ from celery.utils.threads import bgThread, Event
 
 from .abstract import StartStopComponent
 
-try:
+try:                        # pragma: no cover
     import pyinotify
     _ProcessEvent = pyinotify.ProcessEvent
 except ImportError:         # pragma: no cover
@@ -137,7 +137,7 @@ class KQueueMonitor(BaseMonitor):
     def stop(self):
         self._kq.close()
         for fd in filter(None, self.filemap.values()):
-            with ignore_EBADF():
+            with ignore_EBADF():  # pragma: no cover
                 os.close(fd)
             self.filemap[fd] = None
         self.filemap.clear()

+ 2 - 2
celery/worker/autoscale.py

@@ -94,8 +94,8 @@ class Autoscaler(bgThread):
         with self.mutex:
             new = self.processes - n
             if new < self.min_concurrency:
-                self.min_concurrency = new
-            self._shrink(n)
+                self.min_concurrency = max(new, 0)
+            self._shrink(min(n, self.processes))
 
     def scale_up(self, n):
         self._last_action = time()

+ 2 - 2
celery/worker/consumer.py

@@ -358,14 +358,14 @@ class Consumer(object):
         debug("Ready to accept tasks!")
 
         while self._state != CLOSE and self.connection:
-            if self.qos.prev != self.qos.value:
+            if self.qos.prev != self.qos.value:     # pragma: no cover
                 self.qos.update()
             try:
                 self.connection.drain_events(timeout=1)
             except socket.timeout:
                 pass
             except socket.error:
-                if self._state != CLOSE:
+                if self._state != CLOSE:            # pragma: no cover
                     raise
 
     def on_task(self, task):