Преглед изворни кода

99% overall coverage :happy:

Ask Solem пре 13 година
родитељ
комит
53b61c638b
58 измењених фајлова са 1071 додато и 195 уклоњено
  1. 2 1
      celery/app/task.py
  2. 9 7
      celery/apps/worker.py
  3. 1 1
      celery/backends/base.py
  4. 1 0
      celery/backends/redis.py
  5. 30 35
      celery/beat.py
  6. 1 1
      celery/bin/base.py
  7. 1 1
      celery/bin/celeryd.py
  8. 1 1
      celery/events/state.py
  9. 1 1
      celery/result.py
  10. 1 1
      celery/task/http.py
  11. 6 1
      celery/tests/__init__.py
  12. 31 0
      celery/tests/app/test_amqp.py
  13. 5 1
      celery/tests/app/test_app.py
  14. 129 18
      celery/tests/app/test_beat.py
  15. 15 0
      celery/tests/app/test_control.py
  16. 12 0
      celery/tests/app/test_defaults.py
  17. 15 3
      celery/tests/app/test_loaders.py
  18. 39 6
      celery/tests/app/test_log.py
  19. 4 0
      celery/tests/backends/test_amqp.py
  20. 24 0
      celery/tests/backends/test_base.py
  21. 7 1
      celery/tests/backends/test_database.py
  22. 2 0
      celery/tests/backends/test_pyredis_compat.py
  23. 33 0
      celery/tests/backends/test_redis.py
  24. 31 1
      celery/tests/bin/test_base.py
  25. 6 0
      celery/tests/bin/test_celerybeat.py
  26. 19 1
      celery/tests/bin/test_celeryd.py
  27. 2 4
      celery/tests/bin/test_celeryd_multi.py
  28. 21 0
      celery/tests/bin/test_celeryev.py
  29. 16 0
      celery/tests/concurrency/test_concurrency.py
  30. 19 8
      celery/tests/concurrency/test_processes.py
  31. 36 15
      celery/tests/events/test_events.py
  32. 18 3
      celery/tests/events/test_snapshot.py
  33. 5 0
      celery/tests/events/test_state.py
  34. 6 10
      celery/tests/slow/test_buckets.py
  35. 8 0
      celery/tests/tasks/test_registry.py
  36. 136 1
      celery/tests/tasks/test_result.py
  37. 34 1
      celery/tests/tasks/test_tasks.py
  38. 139 0
      celery/tests/utilities/test_dispatcher.py
  39. 11 0
      celery/tests/utilities/test_imports.py
  40. 79 0
      celery/tests/utilities/test_saferef.py
  41. 29 1
      celery/tests/utilities/test_timeutils.py
  42. 16 1
      celery/tests/utilities/test_utils.py
  43. 19 0
      celery/tests/utils.py
  44. 2 4
      celery/tests/worker/test_autoreload.py
  45. 8 13
      celery/tests/worker/test_worker.py
  46. 19 19
      celery/utils/compat.py
  47. 4 4
      celery/utils/dispatch/saferef.py
  48. 2 7
      celery/utils/dispatch/signal.py
  49. 4 4
      celery/utils/functional.py
  50. 2 2
      celery/utils/log.py
  51. 2 4
      celery/utils/mail.py
  52. 1 1
      celery/utils/serialization.py
  53. 0 1
      celery/utils/timeutils.py
  54. 2 1
      contrib/release/doc4allmods
  55. 3 6
      contrib/release/py3k-run-tests
  56. 1 1
      requirements/test.txt
  57. 0 2
      setup.cfg
  58. 1 1
      setup.py

+ 2 - 1
celery/app/task.py

@@ -817,7 +817,8 @@ class BaseTask(object):
 
     def annotate(self):
         for d in resolve_all_annotations(self.app.annotations, self):
-            self.__dict__.update(d)
+            for key, value in d.iteritems():
+                setattr(self, key, value)
 
     def __repr__(self):
         """`repr(task)`"""

+ 9 - 7
celery/apps/worker.py

@@ -25,7 +25,7 @@ from celery.worker import WorkController
 try:
     from greenlet import GreenletExit
     IGNORE_ERRORS = (GreenletExit, )
-except ImportError:
+except ImportError:  # pragma: no cover
     IGNORE_ERRORS = ()
 
 logger = get_logger(__name__)
@@ -302,15 +302,17 @@ def install_cry_handler():
     # Jython/PyPy does not have sys._current_frames
     is_jython = sys.platform.startswith("java")
     is_pypy = hasattr(sys, "pypy_version_info")
-    if not (is_jython or is_pypy):
+    if is_jython or is_pypy:  # pragma: no cover
+        return
 
-        def cry_handler(signum, frame):
-            """Signal handler logging the stacktrace of all active threads."""
-            logger.error("\n" + cry())
-        platforms.signals["SIGUSR1"] = cry_handler
+    def cry_handler(signum, frame):
+        """Signal handler logging the stacktrace of all active threads."""
+        logger.error("\n" + cry())
+    platforms.signals["SIGUSR1"] = cry_handler
 
 
-def install_rdb_handler(envvar="CELERY_RDBSIG", sig="SIGUSR2"):
+def install_rdb_handler(envvar="CELERY_RDBSIG",
+                        sig="SIGUSR2"):  # pragma: no cover
 
     def rdb_handler(signum, frame):
         """Signal handler setting a rdb breakpoint at the current frame."""

+ 1 - 1
celery/backends/base.py

@@ -120,7 +120,7 @@ class BaseBackend(object):
         if self.serializer in EXCEPTION_ABLE_CODECS:
             return get_pickled_exception(exc)
         return create_exception_cls(from_utf8(exc["exc_type"]),
-                                    sys.modules[__name__])
+                                    sys.modules[__name__])(exc["exc_message"])
 
     def prepare_value(self, result):
         """Prepare value for storage."""

+ 1 - 0
celery/backends/redis.py

@@ -62,6 +62,7 @@ class RedisBackend(KeyValueStoreBackend):
         uhost = uport = upass = udb = None
         if url:
             _, uhost, uport, _, upass, udb, _ = _parse_url(url)
+            udb = udb.strip("/")
         self.host = uhost or host or _get("HOST") or self.host
         self.port = int(uport or port or _get("PORT") or self.port)
         self.db = udb or db or _get("DB") or self.db

+ 30 - 35
celery/beat.py

@@ -16,10 +16,9 @@ import os
 import time
 import shelve
 import sys
-import threading
 import traceback
 
-from billiard import Process
+from billiard import Process, ensure_multiprocessing
 from kombu.utils import reprcall
 from kombu.utils.functional import maybe_promise
 
@@ -31,6 +30,7 @@ from .app import app_or_default
 from .schedules import maybe_schedule, crontab
 from .utils import cached_property
 from .utils.imports import instantiate
+from .utils.threads import Event, Thread
 from .utils.timeutils import humanize_seconds
 from .utils.log import get_logger
 
@@ -229,12 +229,12 @@ class Scheduler(object):
             raise SchedulingError, SchedulingError(
                 "Couldn't apply scheduled task %s: %s" % (
                     entry.name, exc)), sys.exc_info()[2]
-
-        if self.should_sync():
-            self._do_sync()
+        finally:
+            if self.should_sync():
+                self._do_sync()
         return result
 
-    def send_task(self, *args, **kwargs):               # pragma: no cover
+    def send_task(self, *args, **kwargs):
         return self.app.send_task(*args, **kwargs)
 
     def setup_schedule(self):
@@ -283,12 +283,6 @@ class Scheduler(object):
             else:
                 schedule[key] = entry
 
-    def get_schedule(self):
-        return self.data
-
-    def set_schedule(self, schedule):
-        self.data = schedule
-
     def _ensure_connected(self):
         # callback called for each retry while the connection
         # can't be established.
@@ -299,6 +293,13 @@ class Scheduler(object):
         return self.connection.ensure_connection(_error_handler,
                     self.app.conf.BROKER_CONNECTION_MAX_RETRIES)
 
+    def get_schedule(self):
+        return self.data
+
+    def set_schedule(self, schedule):
+        self.data = schedule
+    schedule = property(get_schedule, set_schedule)
+
     @cached_property
     def connection(self):
         return self.app.broker_connection()
@@ -307,10 +308,6 @@ class Scheduler(object):
     def publisher(self):
         return self.Publisher(connection=self._ensure_connected())
 
-    @property
-    def schedule(self):
-        return self.get_schedule()
-
     @property
     def info(self):
         return ""
@@ -318,6 +315,7 @@ class Scheduler(object):
 
 class PersistentScheduler(Scheduler):
     persistence = shelve
+    known_suffixes = ("", ".db", ".dat", ".bak", ".dir")
 
     _store = None
 
@@ -326,7 +324,7 @@ class PersistentScheduler(Scheduler):
         Scheduler.__init__(self, *args, **kwargs)
 
     def _remove_db(self):
-        for suffix in "", ".db", ".dat", ".bak", ".dir":
+        for suffix in self.known_suffixes:
             try:
                 os.remove(self.schedule_filename + suffix)
             except OSError, exc:
@@ -358,6 +356,10 @@ class PersistentScheduler(Scheduler):
     def get_schedule(self):
         return self._store["entries"]
 
+    def set_schedule(self, schedule):
+        self._store["entries"] = schedule
+    schedule = property(get_schedule, set_schedule)
+
     def sync(self):
         if self._store is not None:
             self._store.sync()
@@ -383,8 +385,8 @@ class Service(object):
         self.schedule_filename = schedule_filename or \
                                     app.conf.CELERYBEAT_SCHEDULE_FILENAME
 
-        self._is_shutdown = threading.Event()
-        self._is_stopped = threading.Event()
+        self._is_shutdown = Event()
+        self._is_stopped = Event()
 
     def start(self, embedded_process=False):
         info("Celerybeat: Starting...")
@@ -397,7 +399,7 @@ class Service(object):
             platforms.set_process_title("celerybeat")
 
         try:
-            while not self._is_shutdown.isSet():
+            while not self._is_shutdown.is_set():
                 interval = self.scheduler.tick()
                 debug("Celerybeat: Waking up %s.",
                       humanize_seconds(interval, prefix="in "))
@@ -430,14 +432,14 @@ class Service(object):
         return self.get_scheduler()
 
 
-class _Threaded(threading.Thread):
+class _Threaded(Thread):
     """Embedded task scheduler using threading."""
 
     def __init__(self, *args, **kwargs):
         super(_Threaded, self).__init__()
         self.service = Service(*args, **kwargs)
-        self.setDaemon(True)
-        self.setName("Beat")
+        self.daemon = True
+        self.name = "Beat"
 
     def run(self):
         self.service.start()
@@ -446,16 +448,12 @@ class _Threaded(threading.Thread):
         self.service.stop(wait=True)
 
 
-supports_fork = True
 try:
-    from billiard._ext import _billiard
-    supports_fork = True if _billiard else False
-except ImportError:
-    supports_fork = False
-
-if supports_fork:
-    class _Process(Process):
-        """Embedded task scheduler using multiprocessing."""
+    ensure_multiprocessing()
+except NotImplementedError:     # pragma: no cover
+    _Process = None
+else:
+    class _Process(Process):    # noqa
 
         def __init__(self, *args, **kwargs):
             super(_Process, self).__init__()
@@ -469,8 +467,6 @@ if supports_fork:
         def stop(self):
             self.service.stop()
             self.terminate()
-else:
-    _Process = None
 
 
 def EmbeddedService(*args, **kwargs):
@@ -485,5 +481,4 @@ def EmbeddedService(*args, **kwargs):
         # in reasonable time.
         kwargs.setdefault("max_interval", 1)
         return _Threaded(*args, **kwargs)
-
     return _Process(*args, **kwargs)

+ 1 - 1
celery/bin/base.py

@@ -139,7 +139,7 @@ class Command(object):
         # Don't want to load configuration to just print the version,
         # so we handle --version manually here.
         if "--version" in arguments:
-            print(self.version)
+            sys.stdout.write("%s\n" % self.version)
             sys.exit(0)
         parser = self.create_parser(prog_name)
         return parser.parse_args(arguments)

+ 1 - 1
celery/bin/celeryd.py

@@ -188,7 +188,7 @@ def main():
     # Fix for setuptools generated scripts, so that it will
     # work with multiprocessing fork emulation.
     # (see multiprocessing.forking.get_preparation_data())
-    if __name__ != "__main__":
+    if __name__ != "__main__":  # pragma: no cover
         sys.modules["__main__"] = sys.modules[__name__]
     freeze_support()
     worker = WorkerCommand()

+ 1 - 1
celery/events/state.py

@@ -299,7 +299,7 @@ class State(object):
     def itertasks(self, limit=None):
         for index, row in enumerate(self.tasks.iteritems()):
             yield row
-            if limit and index >= limit:
+            if limit and index + 1 >= limit:
                 break
 
     def tasks_by_timestamp(self, limit=None):

+ 1 - 1
celery/result.py

@@ -632,7 +632,7 @@ class TaskSetResult(ResultSet):
         return self.id
 
     def _set_taskset_id(self, id):
-        self.taskset_id = id
+        self.id = id
     taskset_id = property(_get_taskset_id, _set_taskset_id)
 
 

+ 1 - 1
celery/task/http.py

@@ -47,7 +47,7 @@ def maybe_utf8(value):
     return value
 
 
-if sys.version_info >= (3, 0):
+if sys.version_info[0] == 3:  # pragma: no cover
 
     def utf8dict(tup):
         if not isinstance(tup, dict):

+ 6 - 1
celery/tests/__init__.py

@@ -1,8 +1,10 @@
 from __future__ import absolute_import
+from __future__ import with_statement
 
 import logging
 import os
 import sys
+import warnings
 
 from importlib import import_module
 
@@ -77,4 +79,7 @@ def import_all_modules(name=__name__, file=__file__,
 
 
 if os.environ.get("COVER_ALL_MODULES") or "--with-coverage3" in sys.argv:
-    import_all_modules()
+    from celery.tests.utils import catch_warnings
+    with catch_warnings(record=True):
+        import_all_modules()
+    warnings.resetwarnings()

+ 31 - 0
celery/tests/app/test_amqp.py

@@ -29,6 +29,21 @@ class test_TaskPublisher(AppCase):
             pass
         publisher.release.assert_called_with()
 
+    def test_declare(self):
+        publisher = self.app.amqp.TaskPublisher(self.app.broker_connection())
+        publisher.exchange.name = "foo"
+        publisher.declare()
+        publisher.exchange.name = None
+        publisher.declare()
+
+    def test_exit_AttributeError(self):
+        publisher = self.app.amqp.TaskPublisher(self.app.broker_connection())
+        publisher.close = Mock()
+        publisher.release = Mock()
+        publisher.release.side_effect = AttributeError()
+        publisher.__exit__()
+        publisher.close.assert_called_with()
+
     def test_ensure_declare_queue(self, q="x1242112"):
         publisher = self.app.amqp.TaskPublisher(Mock())
         self.app.amqp.queues.add(q, q, q)
@@ -103,3 +118,19 @@ class test_PublisherPool(AppCase):
             r2.release()
         finally:
             self.app.conf.BROKER_POOL_LIMIT = L
+
+
+class test_Queues(AppCase):
+
+    def test_queues_format(self):
+        prev, self.app.amqp.queues._consume_from = \
+                self.app.amqp.queues._consume_from, {}
+        try:
+            self.assertEqual(self.app.amqp.queues.format(), "")
+        finally:
+            self.app.amqp.queues._consume_from = prev
+
+    def test_with_defaults(self):
+        self.assertEqual(
+            self.app.amqp.queues.with_defaults(None, "celery", "direct"),
+            {})

+ 5 - 1
celery/tests/app/test_app.py

@@ -164,7 +164,11 @@ class test_App(Case):
         self.assertEqual(self.app.conf.BROKER_TRANSPORT, "set_by_us")
 
     def test_WorkController(self):
-        x = self.app.Worker()
+        x = self.app.WorkController
+        self.assertIs(x.app, self.app)
+
+    def test_Worker(self):
+        x = self.app.Worker
         self.assertIs(x.app, self.app)
 
     def test_AsyncResult(self):

+ 129 - 18
celery/tests/app/test_beat.py

@@ -1,15 +1,19 @@
 from __future__ import absolute_import
+from __future__ import with_statement
+
+import errno
 
 from datetime import datetime, timedelta
-from mock import patch
+from mock import Mock, call, patch
 from nose import SkipTest
 
 from celery import beat
+from celery import task
 from celery.result import AsyncResult
 from celery.schedules import schedule
 from celery.task.base import Task
 from celery.utils import uuid
-from celery.tests.utils import Case
+from celery.tests.utils import Case, patch_settings
 
 
 class Object(object):
@@ -159,10 +163,69 @@ class test_Scheduler(Case):
         scheduler.apply_async(scheduler.Entry(task=MockTask.name))
         self.assertTrue(through_task[0])
 
+    def test_apply_async_should_not_sync(self):
+
+        @task
+        def not_sync():
+            pass
+        not_sync.apply_async = Mock()
+
+        s = mScheduler()
+        s._do_sync = Mock()
+        s.should_sync = Mock()
+        s.should_sync.return_value = True
+        s.apply_async(s.Entry(task=not_sync.name))
+        s._do_sync.assert_called_with()
+
+        s._do_sync = Mock()
+        s.should_sync.return_value = False
+        s.apply_async(s.Entry(task=not_sync.name))
+        self.assertFalse(s._do_sync.called)
+
+    @patch("celery.app.base.Celery.send_task")
+    def test_send_task(self, send_task):
+        b = beat.Scheduler()
+        b.send_task("tasks.add", countdown=10)
+        send_task.assert_called_with("tasks.add", countdown=10)
+
     def test_info(self):
         scheduler = mScheduler()
         self.assertIsInstance(scheduler.info, basestring)
 
+    def test_maybe_entry(self):
+        s = mScheduler()
+        entry = s.Entry(name="add every", task="tasks.add")
+        self.assertIs(s._maybe_entry(entry.name, entry), entry)
+        self.assertTrue(s._maybe_entry("add every", {
+            "task": "tasks.add",
+        }))
+
+    def test_set_schedule(self):
+        s = mScheduler()
+        s.schedule = {"foo": "bar"}
+        self.assertEqual(s.data, {"foo": "bar"})
+
+    @patch("kombu.connection.Connection.ensure_connection")
+    def test_ensure_connection_error_handler(self, ensure):
+        s = mScheduler()
+        self.assertTrue(s._ensure_connected())
+        self.assertTrue(ensure.called)
+        callback = ensure.call_args[0][0]
+
+        callback(KeyError(), 5)
+
+    def test_install_default_entries(self):
+        with patch_settings(CELERY_TASK_RESULT_EXPIRES=None,
+                            CELERYBEAT_SCHEDULE={}):
+            s = mScheduler()
+            s.install_default_entries({})
+            self.assertNotIn("celery.backend_cleanup", s.data)
+        with patch_settings(CELERY_TASK_RESULT_EXPIRES=30,
+                            CELERYBEAT_SCHEDULE={}):
+            s = mScheduler()
+            s.install_default_entries({})
+            self.assertIn("celery.backend_cleanup", s.data)
+
     def test_due_tick(self):
         scheduler = mScheduler()
         scheduler.add(name="test_due_tick",
@@ -233,25 +296,73 @@ class test_Scheduler(Case):
         self.assertEqual(a.schedule["bar"].schedule._next_run_at, 40)
 
 
+def create_persistent_scheduler(shelv=None):
+    if shelv is None:
+        shelv = MockShelve()
+
+    class MockPersistentScheduler(beat.PersistentScheduler):
+        sh = shelv
+        persistence = Object()
+        persistence.open = lambda *a, **kw: shelv
+        tick_raises_exit = False
+        shutdown_service = None
+
+        def tick(self):
+            if self.tick_raises_exit:
+                raise SystemExit()
+            if self.shutdown_service:
+                self.shutdown_service._is_shutdown.set()
+            return 0.0
+
+    return MockPersistentScheduler, shelv
+
+
+class test_PersistentScheduler(Case):
+
+    @patch("os.remove")
+    def test_remove_db(self, remove):
+        s = create_persistent_scheduler()[0](schedule_filename="schedule")
+        s._remove_db()
+        remove.assert_has_calls(
+            [call("schedule" + suffix) for suffix in s.known_suffixes]
+        )
+        err = OSError()
+        err.errno = errno.ENOENT
+        remove.side_effect = err
+        s._remove_db()
+        err.errno = errno.EPERM
+        with self.assertRaises(OSError):
+            s._remove_db()
+
+    def test_setup_schedule(self):
+        s = create_persistent_scheduler()[0](schedule_filename="schedule")
+        opens = s.persistence.open = Mock()
+        s._remove_db = Mock()
+
+        def effect(*args, **kwargs):
+            if opens.call_count > 1:
+                return s.sh
+            raise OSError()
+        opens.side_effect = effect
+        s.setup_schedule()
+        s._remove_db.assert_called_with()
+
+        s._store = {"__version__": 1}
+        s.setup_schedule()
+
+    def test_get_schedule(self):
+        s = create_persistent_scheduler()[0](schedule_filename="schedule")
+        s._store = {"entries": {}}
+        s.schedule = {"foo": "bar"}
+        self.assertDictEqual(s.schedule, {"foo": "bar"})
+        self.assertDictEqual(s._store["entries"], s.schedule)
+
+
 class test_Service(Case):
 
     def get_service(self):
-        sh = MockShelve()
-
-        class PersistentScheduler(beat.PersistentScheduler):
-            persistence = Object()
-            persistence.open = lambda *a, **kw: sh
-            tick_raises_exit = False
-            shutdown_service = None
-
-            def tick(self):
-                if self.tick_raises_exit:
-                    raise SystemExit()
-                if self.shutdown_service:
-                    self.shutdown_service._is_shutdown.set()
-                return 0.0
-
-        return beat.Service(scheduler_cls=PersistentScheduler), sh
+        Scheduler, mock_shelve = create_persistent_scheduler()
+        return beat.Service(scheduler_cls=Scheduler), mock_shelve
 
     def test_start(self):
         s, sh = self.get_service()

+ 15 - 0
celery/tests/app/test_control.py

@@ -117,6 +117,16 @@ class test_inspect(Case):
         self.i.cancel_consumer("foo")
         self.assertIn("cancel_consumer", MockMailbox.sent)
 
+    @with_mock_broadcast
+    def test_active_queues(self):
+        self.i.active_queues()
+        self.assertIn("active_queues", MockMailbox.sent)
+
+    @with_mock_broadcast
+    def test_report(self):
+        self.i.report()
+        self.assertIn("report", MockMailbox.sent)
+
 
 class test_Broadcast(Case):
 
@@ -153,6 +163,11 @@ class test_Broadcast(Case):
         self.control.rate_limit(mytask.name, "100/m")
         self.assertIn("rate_limit", MockMailbox.sent)
 
+    @with_mock_broadcast
+    def test_time_limit(self):
+        self.control.time_limit(mytask.name, soft=10, hard=20)
+        self.assertIn("time_limit", MockMailbox.sent)
+
     @with_mock_broadcast
     def test_revoke(self):
         self.control.revoke("foozbaaz")

+ 12 - 0
celery/tests/app/test_defaults.py

@@ -4,6 +4,7 @@ from __future__ import with_statement
 import sys
 
 from importlib import import_module
+from mock import Mock, patch
 
 from celery.tests.utils import Case, pypy_version, sys_platform
 
@@ -17,6 +18,10 @@ class test_defaults(Case):
         if self._prev:
             sys.modules["celery.app.defaults"] = self._prev
 
+    def test_any(self):
+        val = object()
+        self.assertIs(self.defaults.Option.typemap["any"](val), val)
+
     def test_default_pool_pypy_14(self):
         with sys_platform("darwin"):
             with pypy_version((1, 4, 0)):
@@ -27,6 +32,13 @@ class test_defaults(Case):
             with pypy_version((1, 5, 0)):
                 self.assertEqual(self.defaults.DEFAULT_POOL, "processes")
 
+    def test_deprecated(self):
+        source = Mock()
+        source.BROKER_INSIST = True
+        with patch("celery.utils.warn_deprecated") as warn:
+            self.defaults.find_deprecated_settings(source)
+            self.assertTrue(warn.called)
+
     def test_default_pool_jython(self):
         with sys_platform("java 1.6.51"):
             self.assertEqual(self.defaults.DEFAULT_POOL, "threads")

+ 15 - 3
celery/tests/app/test_loaders.py

@@ -4,7 +4,7 @@ from __future__ import with_statement
 import os
 import sys
 
-from mock import patch
+from mock import Mock, patch
 
 from celery import loaders
 from celery.app import app_or_default
@@ -83,6 +83,17 @@ class test_LoaderBase(Case):
     def test_import_task_module(self):
         self.assertEqual(sys, self.loader.import_task_module("sys"))
 
+    def test_init_worker_process(self):
+        self.loader.on_worker_process_init()
+        m = self.loader.on_worker_process_init = Mock()
+        self.loader.init_worker_process()
+        m.assert_called_with()
+
+    def test_config_from_object_module(self):
+        self.loader.import_from_cwd = Mock()
+        self.loader.config_from_object("module_name")
+        self.loader.import_from_cwd.assert_called_with("module_name")
+
     def test_conf_property(self):
         self.assertEqual(self.loader.conf["foo"], "bar")
         self.assertEqual(self.loader._conf["foo"], "bar")
@@ -181,7 +192,7 @@ class test_DefaultLoader(Case):
         celeryconfig.CELERY_IMPORTS = ("os", "sys")
         configname = os.environ.get("CELERY_CONFIG_MODULE") or "celeryconfig"
 
-        prevconfig = sys.modules[configname]
+        prevconfig = sys.modules.get(configname)
         sys.modules[configname] = celeryconfig
         try:
             l = default.Loader()
@@ -191,7 +202,8 @@ class test_DefaultLoader(Case):
             self.assertTupleEqual(settings.CELERY_IMPORTS, ("os", "sys"))
             l.on_worker_init()
         finally:
-            sys.modules[configname] = prevconfig
+            if prevconfig:
+                sys.modules[configname] = prevconfig
 
     def test_import_from_cwd(self):
         l = default.Loader()

+ 39 - 6
celery/tests/app/test_log.py

@@ -8,17 +8,39 @@ from tempfile import mktemp
 from mock import patch, Mock
 
 from celery import current_app
-from celery.app.log import Logging
+from celery import signals
+from celery.app.log import Logging, TaskFormatter
 from celery.utils.log import LoggingProxy
 from celery.utils import uuid
-from celery.utils.log import get_logger, ColorFormatter, logger as base_logger
+from celery.utils.log import (
+    get_logger,
+    ColorFormatter,
+    logger as base_logger,
+)
 from celery.tests.utils import (
-    Case, override_stdouts, wrap_logger, get_handlers,
+    AppCase, Case, override_stdouts, wrap_logger, get_handlers,
 )
 
 log = current_app.log
 
 
+class test_TaskFormatter(Case):
+
+    def test_no_task(self):
+        class Record(object):
+            msg = "hello world"
+            levelname = "info"
+            exc_text = exc_info = None
+
+            def getMessage(self):
+                return self.msg
+        record = Record()
+        x = TaskFormatter()
+        x.format(record)
+        self.assertEqual(record.task_name, "???")
+        self.assertEqual(record.task_id, "???")
+
+
 class test_ColorFormatter(Case):
 
     @patch("celery.utils.log.safe_str")
@@ -71,11 +93,12 @@ class test_ColorFormatter(Case):
         self.assertEqual(safe_str.call_count, 1)
 
 
-class test_default_logger(Case):
+class test_default_logger(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.setup_logger = log.setup_logger
         self.get_logger = lambda n=None: get_logger(n) if n else logging.root
+        signals.setup_logging.receivers[:] = []
         Logging._setup = False
 
     def test_get_logger_sets_parent(self):
@@ -86,6 +109,14 @@ class test_default_logger(Case):
         logger = get_logger(base_logger.name)
         self.assertIs(logger.parent, logging.root)
 
+    def test_setup_logging_subsystem_misc(self):
+        log.setup_logging_subsystem(loglevel=None)
+        self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = True
+        try:
+            log.setup_logging_subsystem()
+        finally:
+            self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = False
+
     def test_setup_logging_subsystem_colorize(self):
         log.setup_logging_subsystem(colorize=None)
         log.setup_logging_subsystem(colorize=True)
@@ -149,6 +180,8 @@ class test_default_logger(Case):
                 log.redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
                 logger.error("foo")
                 self.assertIn("foo", sio.getvalue())
+                log.redirect_stdouts_to_logger(logger, stdout=False,
+                        stderr=False)
         finally:
             sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
 
@@ -186,7 +219,7 @@ class test_default_logger(Case):
 
 class test_task_logger(test_default_logger):
 
-    def setUp(self):
+    def setup(self):
         logger = self.logger = get_logger("celery.task")
         logger.handlers = []
         logging.root.manager.loggerDict.pop(logger.name, None)

+ 4 - 0
celery/tests/backends/test_amqp.py

@@ -42,6 +42,10 @@ class test_AMQPBackend(Case):
         self.assertTrue(tb2._cache.get(tid))
         self.assertTrue(tb2.get_result(tid), 42)
 
+    def test_revive(self):
+        tb = self.create_backend()
+        tb.revive(None)
+
     def test_is_pickled(self):
         tb1 = self.create_backend()
         tb2 = self.create_backend()

+ 24 - 0
celery/tests/backends/test_base.py

@@ -58,6 +58,10 @@ class test_BaseBackend_interface(Case):
         with self.assertRaises(NotImplementedError):
             b.forget("SOMExx-N0Nex1stant-IDxx-")
 
+    def test_get_children(self):
+        with self.assertRaises(NotImplementedError):
+            b.get_children("SOMExx-N0Nex1stant-IDxx-")
+
     def test_store_result(self):
         with self.assertRaises(NotImplementedError):
             b.store_result("SOMExx-N0nex1stant-IDxx-", 42, states.SUCCESS)
@@ -98,6 +102,9 @@ class test_BaseBackend_interface(Case):
         with self.assertRaises(NotImplementedError):
             b.forget("SOMExx-N0nex1stant-IDxx-")
 
+    def test_on_chord_part_return(self):
+        b.on_chord_part_return(None)
+
     def test_on_chord_apply(self, unlock="celery.chord_unlock"):
         p, current_app.tasks[unlock] = current_app.tasks.get(unlock), Mock()
         try:
@@ -138,6 +145,7 @@ class test_prepare_exception(Case):
     def test_impossible(self):
         x = b.prepare_exception(Impossible())
         self.assertIsInstance(x, UnpickleableExceptionWrapper)
+        self.assertTrue(str(x))
         y = b.exception_to_python(x)
         self.assertEqual(y.__class__.__name__, "Impossible")
         if sys.version_info < (2, 5):
@@ -202,6 +210,14 @@ class test_BaseDictBackend(Case):
         self.b.delete_taskset("can-delete")
         self.assertNotIn("can-delete", self.b._data)
 
+    def test_prepare_exception_json(self):
+        x = DictBackend(serializer="json")
+        e = x.prepare_exception(KeyError("foo"))
+        self.assertIn("exc_type", e)
+        e = x.exception_to_python(e)
+        self.assertEqual(e.__class__.__name__, "KeyError")
+        self.assertEqual(str(e), "'foo'")
+
     def test_save_taskset(self):
         b = BaseDictBackend()
         b._save_taskset = Mock()
@@ -237,6 +253,10 @@ class test_KeyValueStoreBackend(Case):
     def setUp(self):
         self.b = KVBackend()
 
+    def test_on_chord_part_return(self):
+        assert not self.b.implements_incr
+        self.b.on_chord_part_return(None)
+
     def test_get_store_delete_result(self):
         tid = uuid()
         self.b.mark_as_done(tid, "Hello world")
@@ -290,6 +310,10 @@ class test_KeyValueStoreBackend_interface(Case):
         with self.assertRaises(NotImplementedError):
             KeyValueStoreBackend().set("a", 1)
 
+    def test_incr(self):
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().incr("a")
+
     def test_cleanup(self):
         self.assertFalse(KeyValueStoreBackend().cleanup())
 

+ 7 - 1
celery/tests/backends/test_database.py

@@ -6,6 +6,7 @@ import sys
 from datetime import datetime
 
 from nose import SkipTest
+from pickle import loads, dumps
 
 from celery import states
 from celery.app import app_or_default
@@ -151,7 +152,8 @@ class test_DatabaseBackend(Case):
         tb = DatabaseBackend(backend="memory://")
         tid = uuid()
         tb.mark_as_done(tid, {"foo": "bar"})
-        x = AsyncResult(tid)
+        tb.mark_as_done(tid, {"foo": "bar"})
+        x = AsyncResult(tid, backend=tb)
         x.forget()
         self.assertIsNone(x.result)
 
@@ -159,6 +161,10 @@ class test_DatabaseBackend(Case):
         tb = DatabaseBackend()
         tb.process_cleanup()
 
+    def test_reduce(self):
+        tb = DatabaseBackend()
+        self.assertTrue(loads(dumps(tb)))
+
     def test_save__restore__delete_taskset(self):
         tb = DatabaseBackend()
 

+ 2 - 0
celery/tests/backends/test_pyredis_compat.py

@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 
 from nose import SkipTest
+from pickle import loads, dumps
 
 from celery.exceptions import ImproperlyConfigured
 from celery.tests.utils import Case
@@ -19,3 +20,4 @@ class test_RedisBackend(Case):
         self.assertEqual(x.redis_port, 312)
         self.assertEqual(x.redis_db, 1)
         self.assertEqual(x.redis_password, "foo")
+        self.assertTrue(loads(dumps(x)))

+ 33 - 0
celery/tests/backends/test_redis.py

@@ -1,11 +1,16 @@
 from __future__ import absolute_import
+from __future__ import with_statement
 
 from datetime import timedelta
 
 from mock import Mock, patch
+from nose import SkipTest
+from pickle import loads, dumps
 
 from celery import current_app
 from celery import states
+from celery.datastructures import AttributeDict
+from celery.exceptions import ImproperlyConfigured
 from celery.result import AsyncResult
 from celery.task import subtask
 from celery.utils import cached_property, uuid
@@ -81,6 +86,34 @@ class test_RedisBackend(Case):
 
         self.MockBackend = MockBackend
 
+    def test_reduce(self):
+        try:
+            from celery.backends.redis import RedisBackend
+            x = RedisBackend()
+            self.assertTrue(loads(dumps(x)))
+        except ImportError:
+            raise SkipTest("redis not installed")
+
+    def test_no_redis(self):
+        self.MockBackend.redis = None
+        with self.assertRaises(ImproperlyConfigured):
+            self.MockBackend()
+
+    def test_url(self):
+        x = self.MockBackend("redis://foobar//1")
+        self.assertEqual(x.host, "foobar")
+        self.assertEqual(x.db, "1")
+
+    def test_conf_raises_KeyError(self):
+        conf = AttributeDict({"CELERY_RESULT_SERIALIZER": "json",
+                              "CELERY_MAX_CACHED_RESULTS": 1,
+                              "CELERY_TASK_RESULT_EXPIRES": None})
+        prev, current_app.conf = current_app.conf, conf
+        try:
+            self.MockBackend()
+        finally:
+            current_app.conf = prev
+
     def test_expires_defaults_to_config(self):
         conf = current_app.conf
         prev = conf.CELERY_TASK_RESULT_EXPIRES

+ 31 - 1
celery/tests/bin/test_base.py

@@ -3,7 +3,9 @@ from __future__ import with_statement
 
 import os
 
-from celery.bin.base import Command
+from mock import patch
+
+from celery.bin.base import Command, Option
 from celery.tests.utils import AppCase, override_stdouts
 
 
@@ -41,6 +43,13 @@ class test_Command(AppCase):
         with self.assertRaises(NotImplementedError):
             Command().run()
 
+    @patch("sys.stdout")
+    def test_parse_options_version_only(self, stdout):
+        cmd = Command()
+        with self.assertRaises(SystemExit):
+            cmd.parse_options("prog", ["--version"])
+        stdout.write.assert_called_with(cmd.version + "\n")
+
     def test_execute_from_commandline(self):
         cmd = MockCommand()
         args1, kwargs1 = cmd.execute_from_commandline()     # sys.argv
@@ -71,6 +80,21 @@ class test_Command(AppCase):
         finally:
             if prev:
                 os.environ["CELERY_CONFIG_MODULE"] = prev
+            else:
+                os.environ.pop("CELERY_CONFIG_MODULE", None)
+
+    def test_with_custom_broker(self):
+        prev = os.environ.pop("CELERY_BROKER_URL", None)
+        try:
+            cmd = MockCommand()
+            cmd.setup_app_from_commandline(["--broker=xyzza://"])
+            self.assertEqual(os.environ.get("CELERY_BROKER_URL"),
+                    "xyzza://")
+        finally:
+            if prev:
+                os.environ["CELERY_BROKER_URL"] = prev
+            else:
+                os.environ.pop("CELERY_BROKER_URL", None)
 
     def test_with_custom_app(self):
         cmd = MockCommand()
@@ -89,3 +113,9 @@ class test_Command(AppCase):
         self.assertEqual(cmd.app.conf.BROKER_HOST, "broker.example.com")
         self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
         self.assertListEqual(rest, ["--loglevel=INFO"])
+
+    def test_parse_preload_options_shortopt(self):
+        cmd = Command()
+        cmd.preload_options = (Option("-s", action="store", dest="silent"), )
+        acc = cmd.parse_preload_options(["-s", "yes"])
+        self.assertEqual(acc.get("silent"), "yes")

+ 6 - 0
celery/tests/bin/test_celerybeat.py

@@ -182,6 +182,12 @@ class test_div(AppCase):
         self.assertTrue(MockDaemonContext.opened)
         self.assertTrue(MockDaemonContext.closed)
 
+    @patch("os.chdir")
+    def test_prepare_preload_options(self, chdir):
+        cmd = celerybeat_bin.BeatCommand()
+        cmd.prepare_preload_options({"working_directory": "/opt/Project"})
+        chdir.assert_called_with("/opt/Project")
+
     def test_parse_options(self):
         cmd = celerybeat_bin.BeatCommand()
         cmd.app = app_or_default()

+ 19 - 1
celery/tests/bin/test_celeryd.py

@@ -7,7 +7,7 @@ import sys
 
 from functools import wraps
 
-from mock import patch
+from mock import Mock, patch
 from nose import SkipTest
 
 from billiard import current_process
@@ -64,6 +64,16 @@ class test_Worker(AppCase):
         self.assertEqual(worker.use_queues, ["foo", "bar", "baz"])
         self.assertTrue("foo" in celery.amqp.queues)
 
+    @disable_stdouts
+    def test_cpu_count(self):
+        celery = Celery(set_as_current=False)
+        with patch("celery.apps.worker.cpu_count") as cpu_count:
+            cpu_count.side_effect = NotImplementedError()
+            worker = celery.Worker(concurrency=None)
+            self.assertEqual(worker.concurrency, 2)
+        worker = celery.Worker(concurrency=5)
+        self.assertEqual(worker.concurrency, 5)
+
     @disable_stdouts
     def test_windows_B_option(self):
         celery = Celery(set_as_current=False)
@@ -139,6 +149,14 @@ class test_Worker(AppCase):
         worker.init_loader()
         worker.run()
 
+        prev, cd.IGNORE_ERRORS = cd.IGNORE_ERRORS, (KeyError, )
+        try:
+            worker.run_worker = Mock()
+            worker.run_worker.side_effect = KeyError()
+            worker.run()
+        finally:
+            cd.IGNORE_ERRORS = prev
+
     @disable_stdouts
     def test_purge_messages(self):
         self.Worker().purge_messages()

+ 2 - 4
celery/tests/bin/test_celeryd_multi.py

@@ -312,7 +312,7 @@ class test_MultiTool(Case):
         self.prepare_pidfile_for_getpids(PIDFile)
         self.assertIsNone(self.t.shutdown_nodes([]))
         self.t.signal_node = Mock()
-        self.t.node_alive = Mock()
+        node_alive = self.t.node_alive = Mock()
         self.t.node_alive.return_value = False
 
         callback = Mock()
@@ -324,11 +324,9 @@ class test_MultiTool(Case):
         self.t.signal_node.return_value = False
         self.assertTrue(callback.called)
         self.t.stop(["foo", "bar", "baz"], "celeryd", callback=None)
-        calls = [0]
 
         def on_node_alive(pid):
-            calls[0] += 1
-            if calls[0] > 3:
+            if node_alive.call_count > 4:
                 return True
             return False
         self.t.signal_node.return_value = True

+ 21 - 0
celery/tests/bin/test_celeryev.py

@@ -1,6 +1,8 @@
 from __future__ import absolute_import
+from __future__ import with_statement
 
 from nose import SkipTest
+from mock import patch as mpatch
 
 from celery.app import app_or_default
 from celery.bin import celeryev
@@ -32,6 +34,14 @@ class test_EvCommand(Case):
         self.assertEqual(self.ev.run(dump=True), "me dumper, you?")
         self.assertIn("celeryev:dump", proctitle.last[0])
 
+    @mpatch("os.chdir")
+    def test_prepare_preload_options(self, chdir):
+        self.ev.prepare_preload_options({"working_directory": "/opt/Project"})
+        chdir.assert_called_with("/opt/Project")
+        chdir.called = False
+        self.ev.prepare_preload_options({})
+        self.assertFalse(chdir.called)
+
     def test_run_top(self):
         try:
             import curses  # noqa
@@ -56,6 +66,17 @@ class test_EvCommand(Case):
         self.assertEqual(kw["logfile"], "logfile")
         self.assertIn("celeryev:cam", proctitle.last[0])
 
+    @mpatch("celery.events.snapshot.evcam")
+    @mpatch("celery.bin.celeryev.detached")
+    def test_run_cam_detached(self, detached, evcam):
+        self.ev.prog_name = "celeryev"
+        self.ev.run_evcam("myapp.Camera", detach=True)
+        self.assertTrue(detached.called)
+        self.assertTrue(evcam.called)
+
+    def test_get_options(self):
+        self.assertTrue(self.ev.get_options())
+
     @patch("celery.bin.celeryev", "EvCommand", MockCommand)
     def test_main(self):
         MockCommand.executed = []

+ 16 - 0
celery/tests/concurrency/test_concurrency.py

@@ -47,6 +47,14 @@ class test_BasePool(Case):
                               {"target": (3, (8, 16)),
                                "callback": (4, (42, ))})
 
+    def test_does_not_debug(self):
+        x = BasePool(10)
+        x._does_debug = False
+        x.apply_async(object)
+
+    def test_num_processes(self):
+        self.assertEqual(BasePool(7).num_processes, 7)
+
     def test_interface_on_start(self):
         BasePool(10).on_start()
 
@@ -69,3 +77,11 @@ class test_BasePool(Case):
         p = BasePool(10)
         with self.assertRaises(NotImplementedError):
             p.restart()
+
+    def test_interface_on_terminate(self):
+        p = BasePool(10)
+        p.on_terminate()
+
+    def test_interface_terminate_job(self):
+        with self.assertRaises(NotImplementedError):
+            BasePool(10).terminate_job(101)

+ 19 - 8
celery/tests/concurrency/test_processes.py

@@ -76,9 +76,9 @@ class MockPool(object):
     def __init__(self, *args, **kwargs):
         self.started = True
         self._state = mp.RUN
-        self.processes = kwargs.get("processes")
-        self._pool = [Object(pid=i) for i in range(self.processes)]
-        self._current_proc = cycle(xrange(self.processes)).next
+        self._processes = kwargs.get("processes")
+        self._pool = [Object(pid=i) for i in range(self._processes)]
+        self._current_proc = cycle(xrange(self._processes)).next
 
     def close(self):
         self.closed = True
@@ -91,10 +91,10 @@ class MockPool(object):
         self.terminated = True
 
     def grow(self, n=1):
-        self.processes += n
+        self._processes += n
 
     def shrink(self, n=1):
-        self.processes -= n
+        self._processes -= n
 
     def apply_async(self, *args, **kwargs):
         pass
@@ -179,11 +179,11 @@ class test_TaskPool(Case):
     def test_grow_shrink(self):
         pool = TaskPool(10)
         pool.start()
-        self.assertEqual(pool._pool.processes, 10)
+        self.assertEqual(pool._pool._processes, 10)
         pool.grow()
-        self.assertEqual(pool._pool.processes, 11)
+        self.assertEqual(pool._pool._processes, 11)
         pool.shrink(2)
-        self.assertEqual(pool._pool.processes, 9)
+        self.assertEqual(pool._pool._processes, 9)
 
     def test_info(self):
         pool = TaskPool(10)
@@ -197,6 +197,17 @@ class test_TaskPool(Case):
         self.assertIsNone(info["max-tasks-per-child"])
         self.assertEqual(info["timeouts"], (5, 10))
 
+    def test_num_processes(self):
+        pool = TaskPool(7)
+        pool.start()
+        self.assertEqual(pool.num_processes, 7)
+
+    def test_restart_pool(self):
+        pool = TaskPool()
+        pool._pool = Mock()
+        pool.restart()
+        pool._pool.restart.assert_called_with()
+
     def test_restart(self):
         raise SkipTest("functional test")
 

+ 36 - 15
celery/tests/events/test_events.py

@@ -3,9 +3,10 @@ from __future__ import with_statement
 
 import socket
 
+from mock import Mock
+
 from celery import events
-from celery.app import app_or_default
-from celery.tests.utils import Case
+from celery.tests.utils import AppCase
 
 
 class MockProducer(object):
@@ -29,7 +30,7 @@ class MockProducer(object):
         return False
 
 
-class test_Event(Case):
+class test_Event(AppCase):
 
     def test_constructor(self):
         event = events.Event("world war II")
@@ -37,10 +38,7 @@ class test_Event(Case):
         self.assertTrue(event["timestamp"])
 
 
-class test_EventDispatcher(Case):
-
-    def setUp(self):
-        self.app = app_or_default()
+class test_EventDispatcher(AppCase):
 
     def test_send(self):
         producer = MockProducer()
@@ -67,6 +65,30 @@ class test_EventDispatcher(Case):
         for ev in evs:
             self.assertTrue(producer.has_event(ev))
 
+        buf = eventer._outbound_buffer = Mock()
+        buf.popleft.side_effect = IndexError()
+        eventer.flush()
+
+    def test_enter_exit(self):
+        with self.app.broker_connection() as conn:
+            d = self.app.events.Dispatcher(conn)
+            d.close = Mock()
+            with d as _d:
+                self.assertTrue(_d)
+            d.close.assert_called_with()
+
+    def test_enable_disable_callbacks(self):
+        on_enable = Mock()
+        on_disable = Mock()
+        with self.app.broker_connection() as conn:
+            with self.app.events.Dispatcher(conn, enabled=False) as d:
+                d.on_enabled.add(on_enable)
+                d.on_disabled.add(on_disable)
+                d.enable()
+                on_enable.assert_called_with()
+                d.disable()
+                on_disable.assert_called_with()
+
     def test_enabled_disable(self):
         connection = self.app.broker_connection()
         channel = connection.channel()
@@ -99,10 +121,7 @@ class test_EventDispatcher(Case):
             connection.close()
 
 
-class test_EventReceiver(Case):
-
-    def setUp(self):
-        self.app = app_or_default()
+class test_EventReceiver(AppCase):
 
     def test_process(self):
 
@@ -181,11 +200,13 @@ class test_EventReceiver(Case):
             connection.close()
 
 
-class test_misc(Case):
-
-    def setUp(self):
-        self.app = app_or_default()
+class test_misc(AppCase):
 
     def test_State(self):
         state = self.app.events.State()
         self.assertDictEqual(dict(state.workers), {})
+
+    def test_default_dispatcher(self):
+        with self.app.events.default_dispatcher() as d:
+            self.assertTrue(d)
+            self.assertTrue(d.connection)

+ 18 - 3
celery/tests/events/test_snapshot.py

@@ -1,6 +1,8 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
+from mock import patch
+
 from celery.app import app_or_default
 from celery.events import Events
 from celery.events.snapshot import Polaroid, evcam
@@ -114,11 +116,24 @@ class test_evcam(Case):
 
     def setUp(self):
         self.app = app_or_default()
-        self.app.events = self.MockEvents()
+        self.prev, self.app.events = self.app.events, self.MockEvents()
+
+    def tearDown(self):
+        self.app.events = self.prev
 
     def test_evcam(self):
         evcam(Polaroid, timer=timer)
         evcam(Polaroid, timer=timer, loglevel="CRITICAL")
         self.MockReceiver.raise_keyboard_interrupt = True
-        with self.assertRaises(SystemExit):
-            evcam(Polaroid, timer=timer)
+        try:
+            with self.assertRaises(SystemExit):
+                evcam(Polaroid, timer=timer)
+        finally:
+            self.MockReceiver.raise_keyboard_interrupt = False
+
+    @patch("atexit.register")
+    @patch("celery.platforms.create_pidlock")
+    def test_evcam_pidfile(self, create_pidlock, atexit):
+        evcam(Polaroid, timer=timer, pidfile="/var/pid")
+        self.assertTrue(atexit.called)
+        create_pidlock.assert_called_with("/var/pid")

+ 5 - 0
celery/tests/events/test_state.py

@@ -172,6 +172,11 @@ class test_State(Case):
         self.assertFalse(r.state.alive_workers())
         self.assertFalse(r.state.workers["utest1"].alive)
 
+    def test_itertasks(self):
+        s = State()
+        s.tasks = {"a": "a", "b": "b", "c": "c", "d": "d"}
+        self.assertEqual(len(list(s.itertasks(limit=2))), 2)
+
     def test_worker_heartbeat_expire(self):
         r = ev_worker_heartbeats(State())
         r.next()

+ 6 - 10
celery/tests/slow/test_buckets.py

@@ -148,18 +148,14 @@ class test_TaskBucket(Case):
         x = buckets.TaskBucket(task_registry=self.registry)
         x.not_empty = Mock()
         get = x._get = Mock()
-        calls = [0]
         remaining = [0]
 
         def effect():
-            try:
-                if not calls[0]:
-                    raise Empty()
-                rem = remaining[0]
-                remaining[0] = 0
-                return rem, Mock()
-            finally:
-                calls[0] += 1
+            if get.call_count == 1:
+                raise Empty()
+            rem = remaining[0]
+            remaining[0] = 0
+            return rem, Mock()
         get.side_effect = effect
 
         with mock_context(Mock()) as context:
@@ -167,7 +163,7 @@ class test_TaskBucket(Case):
             x.wait = Mock()
             x.get(block=True)
 
-            calls[0] = 0
+            get.reset()
             remaining[0] = 1
             x.get(block=True)
 

+ 8 - 0
celery/tests/tasks/test_registry.py

@@ -23,6 +23,9 @@ class MockPeriodicTask(PeriodicTask):
 
 class test_TaskRegistry(Case):
 
+    def test_NotRegistered_str(self):
+        self.assertTrue(repr(TaskRegistry.NotRegistered("tasks.add")))
+
     def assertRegisterUnregisterCls(self, r, task):
         with self.assertRaises(r.NotRegistered):
             r.unregister(task)
@@ -64,3 +67,8 @@ class test_TaskRegistry(Case):
 
         self.assertTrue(MockTask().run())
         self.assertTrue(MockPeriodicTask().run())
+
+    def test_compat(self):
+        r = TaskRegistry()
+        r.regular()
+        r.periodic()

+ 136 - 1
celery/tests/tasks/test_result.py

@@ -1,11 +1,21 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
+from pickle import loads, dumps
+from mock import Mock
+
 from celery import states
 from celery.app import app_or_default
+from celery.exceptions import IncompleteStream
 from celery.utils import uuid
 from celery.utils.serialization import pickle
-from celery.result import AsyncResult, EagerResult, TaskSetResult, ResultSet
+from celery.result import (
+    AsyncResult,
+    EagerResult,
+    TaskSetResult,
+    ResultSet,
+    from_serializable,
+)
 from celery.exceptions import TimeoutError
 from celery.task import task
 from celery.task.base import Task
@@ -53,6 +63,71 @@ class test_AsyncResult(AppCase):
         for task in (self.task1, self.task2, self.task3, self.task4):
             save_result(task)
 
+    def test_compat_properties(self):
+        x = AsyncResult("1")
+        self.assertEqual(x.task_id, x.id)
+        x.task_id = "2"
+        self.assertEqual(x.id, "2")
+
+    def test_children(self):
+        x = AsyncResult("1")
+        children = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
+        x.backend = Mock()
+        x.backend.get_children.return_value = children
+        x.backend.READY_STATES = states.READY_STATES
+        self.assertTrue(x.children)
+        self.assertEqual(len(x.children), 3)
+
+    def test_build_graph_get_leaf_collect(self):
+        x = AsyncResult("1")
+        x.backend._cache["1"] = {"status": states.SUCCESS, "result": None}
+        c = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
+        x.iterdeps = Mock()
+        x.iterdeps.return_value = (
+            (None, x),
+            (x, c[0]),
+            (c[0], c[1]),
+            (c[1], c[2])
+        )
+        x.backend.READY_STATES = states.READY_STATES
+        self.assertTrue(x.graph)
+
+        self.assertIs(x.get_leaf(), 2)
+
+        it = x.collect()
+        self.assertListEqual(list(it), [
+            (x, None),
+            (c[0], 0),
+            (c[1], 1),
+            (c[2], 2),
+        ])
+
+    def test_iterdeps(self):
+        x = AsyncResult("1")
+        x.backend._cache["1"] = {"status": states.SUCCESS, "result": None}
+        c = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
+        for child in c:
+            child.backend = Mock()
+            child.backend.get_children.return_value = []
+        x.backend.get_children = Mock()
+        x.backend.get_children.return_value = c
+        it = x.iterdeps()
+        self.assertListEqual(list(it), [
+            (None, x),
+            (x, c[0]),
+            (x, c[1]),
+            (x, c[2]),
+        ])
+        x.backend._cache.pop("1")
+        x.ready = Mock()
+        x.ready.return_value = False
+        with self.assertRaises(IncompleteStream):
+            list(x.iterdeps())
+        list(x.iterdeps(intermediate=True))
+
+    def test_eq_not_implemented(self):
+        self.assertFalse(AsyncResult("1") == object())
+
     def test_reduce(self):
         a1 = AsyncResult("uuid", task_name=mytask.name)
         restored = pickle.loads(pickle.dumps(a1))
@@ -129,6 +204,7 @@ class test_AsyncResult(AppCase):
         self.assertEqual(ok2_res.get(), "quick")
         with self.assertRaises(KeyError):
             nok_res.get()
+        self.assertTrue(nok_res.get(propagate=False))
         self.assertIsInstance(nok2_res.result, KeyError)
         self.assertEqual(ok_res.info, "the")
 
@@ -159,6 +235,32 @@ class test_AsyncResult(AppCase):
 
 class test_ResultSet(AppCase):
 
+    def test_resultset_repr(self):
+        self.assertTrue(repr(ResultSet(map(AsyncResult, [1, 2, 3]))))
+
+    def test_eq_other(self):
+        self.assertFalse(ResultSet([1, 3, 3]) == 1)
+        self.assertTrue(ResultSet([1]) == ResultSet([1]))
+
+    def test_get(self):
+        x = ResultSet(map(AsyncResult, [1, 2, 3]))
+        b = x.results[0].backend = Mock()
+        b.supports_native_join = False
+        x.join_native = Mock()
+        x.join = Mock()
+        x.get()
+        self.assertTrue(x.join.called)
+        b.supports_native_join = True
+        x.get()
+        self.assertTrue(x.join_native.called)
+
+    def test_add(self):
+        x = ResultSet([1])
+        x.add(2)
+        self.assertEqual(len(x), 2)
+        x.add(2)
+        self.assertEqual(len(x), 2)
+
     def test_add_discard(self):
         x = ResultSet([])
         x.add(AsyncResult("1"))
@@ -231,6 +333,21 @@ class test_TaskSetResult(AppCase):
         self.assertEqual(len(self.ts), self.size)
         self.assertEqual(self.ts.total, self.size)
 
+    def test_compat_properties(self):
+        self.assertEqual(self.ts.taskset_id, self.ts.id)
+        self.ts.taskset_id = "foo"
+        self.assertEqual(self.ts.taskset_id, "foo")
+
+    def test_eq_other(self):
+        self.assertFalse(self.ts == 1)
+
+    def test_reduce(self):
+        self.assertTrue(loads(dumps(self.ts)))
+
+    def test_compat_subtasks_kwarg(self):
+        x = TaskSetResult(uuid(), subtasks=[1, 2, 3])
+        self.assertEqual(x.results, [1, 2, 3])
+
     def test_iterate_raises(self):
         ar = MockAsyncResultFailure(uuid())
         ts = TaskSetResult(uuid(), [ar])
@@ -432,6 +549,7 @@ class test_EagerResult(AppCase):
         res = RaisingTask.apply(args=[3, 3])
         with self.assertRaises(KeyError):
             res.wait()
+        self.assertTrue(res.wait(propagate=False))
 
     def test_wait(self):
         res = EagerResult("x", "x", states.RETRY)
@@ -439,6 +557,23 @@ class test_EagerResult(AppCase):
         self.assertEqual(res.state, states.RETRY)
         self.assertEqual(res.status, states.RETRY)
 
+    def test_forget(self):
+        res = EagerResult("x", "x", states.RETRY)
+        res.forget()
+
     def test_revoke(self):
         res = RaisingTask.apply(args=[3, 3])
         self.assertFalse(res.revoke())
+
+
+class test_serializable(AppCase):
+
+    def test_AsyncResult(self):
+        x = AsyncResult(uuid())
+        self.assertEqual(x, from_serializable(x.serializable()))
+        self.assertEqual(x, from_serializable(x))
+
+    def test_TaskSetResult(self):
+        x = TaskSetResult(uuid(), [AsyncResult(uuid()) for _ in range(10)])
+        self.assertEqual(x, from_serializable(x.serializable()))
+        self.assertEqual(x, from_serializable(x))

+ 34 - 1
celery/tests/tasks/test_tasks.py

@@ -3,9 +3,11 @@ from __future__ import with_statement
 
 from datetime import datetime, timedelta
 from functools import wraps
+from mock import patch
+from pickle import loads, dumps
 
 from celery import task
-from celery.task import current
+from celery.task import current, Task
 from celery.app import app_or_default
 from celery.task import task as task_dec
 from celery.exceptions import RetryTaskError
@@ -57,6 +59,7 @@ def retry_task(arg1, arg2, kwarg=1, max_retries=None, care=True):
     current.iterations += 1
     rmax = current.max_retries if max_retries is None else max_retries
 
+    assert repr(current.request)
     retries = current.request.retries
     if care and retries >= rmax:
         return arg1
@@ -301,6 +304,22 @@ class test_tasks(Case):
     def test_task_class_repr(self):
         task = self.createTask("c.unittest.t.repr")
         self.assertIn("class Task of", repr(task.app.Task))
+        prev, task.app.Task._app = task.app.Task._app, None
+        try:
+            self.assertIn("unbound", repr(task.app.Task, ))
+        finally:
+            task.app.Task._app = prev
+
+    def test_bind_no_magic_kwargs(self):
+        task = self.createTask("c.unittest.t.magic_kwargs")
+        task.__class__.accept_magic_kwargs = None
+        task.bind(task.app)
+
+    def test_annotate(self):
+        with patch("celery.app.task.resolve_all_annotations") as anno:
+            anno.return_value = [{"FOO": "BAR"}]
+            Task.annotate()
+            self.assertEqual(Task.FOO, "BAR")
 
     def test_after_return(self):
         task = self.createTask("c.unittest.t.after_return")
@@ -436,6 +455,13 @@ class test_apply_task(Case):
         with self.assertRaises(KeyError):
             raising.apply(throw=True)
 
+    def test_apply_no_magic_kwargs(self):
+        increment_counter.accept_magic_kwargs = False
+        try:
+            increment_counter.apply()
+        finally:
+            increment_counter.accept_magic_kwargs = True
+
     def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
         raising.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
         try:
@@ -551,6 +577,13 @@ def patch_crontab_nowfun(cls, retval):
 
 class test_crontab_parser(Case):
 
+    def test_crontab_reduce(self):
+        self.assertTrue(loads(dumps(crontab("*"))))
+
+    def test_range_steps_not_enough(self):
+        with self.assertRaises(crontab_parser.ParseException):
+            crontab_parser(24)._range_steps([1])
+
     def test_parse_star(self):
         self.assertEqual(crontab_parser(24).parse('*'), set(range(24)))
         self.assertEqual(crontab_parser(60).parse('*'), set(range(60)))

+ 139 - 0
celery/tests/utilities/test_dispatcher.py

@@ -0,0 +1,139 @@
+from __future__ import absolute_import
+
+
+import gc
+import sys
+import time
+
+from celery.utils.dispatch import Signal
+from celery.tests.utils import Case
+
+
+if sys.platform.startswith('java'):
+
+    def garbage_collect():
+        # Some JVM GCs will execute finalizers in a different thread, meaning
+        # we need to wait for that to complete before we go on looking for the
+        # effects of that.
+        gc.collect()
+        time.sleep(0.1)
+
+elif hasattr(sys, "pypy_version_info"):
+
+    def garbage_collect():  # noqa
+        # Collecting weakreferences can take two collections on PyPy.
+        gc.collect()
+        gc.collect()
+else:
+
+    def garbage_collect():  # noqa
+        gc.collect()
+
+
+def receiver_1_arg(val, **kwargs):
+    return val
+
+
+class Callable(object):
+
+    def __call__(self, val, **kwargs):
+        return val
+
+    def a(self, val, **kwargs):
+        return val
+
+a_signal = Signal(providing_args=["val"])
+
+
+class DispatcherTests(Case):
+    """Test suite for dispatcher (barely started)"""
+
+    def _testIsClean(self, signal):
+        """Assert that everything has been cleaned up automatically"""
+        self.assertEqual(signal.receivers, [])
+
+        # force cleanup just in case
+        signal.receivers = []
+
+    def testExact(self):
+        a_signal.connect(receiver_1_arg, sender=self)
+        expected = [(receiver_1_arg, "test")]
+        result = a_signal.send(sender=self, val="test")
+        self.assertEqual(result, expected)
+        a_signal.disconnect(receiver_1_arg, sender=self)
+        self._testIsClean(a_signal)
+
+    def testIgnoredSender(self):
+        a_signal.connect(receiver_1_arg)
+        expected = [(receiver_1_arg, "test")]
+        result = a_signal.send(sender=self, val="test")
+        self.assertEqual(result, expected)
+        a_signal.disconnect(receiver_1_arg)
+        self._testIsClean(a_signal)
+
+    def testGarbageCollected(self):
+        a = Callable()
+        a_signal.connect(a.a, sender=self)
+        expected = []
+        del a
+        garbage_collect()
+        result = a_signal.send(sender=self, val="test")
+        self.assertEqual(result, expected)
+        self._testIsClean(a_signal)
+
+    def testMultipleRegistration(self):
+        a = Callable()
+        a_signal.connect(a)
+        a_signal.connect(a)
+        a_signal.connect(a)
+        a_signal.connect(a)
+        a_signal.connect(a)
+        a_signal.connect(a)
+        result = a_signal.send(sender=self, val="test")
+        self.assertEqual(len(result), 1)
+        self.assertEqual(len(a_signal.receivers), 1)
+        del a
+        del result
+        garbage_collect()
+        self._testIsClean(a_signal)
+
+    def testUidRegistration(self):
+
+        def uid_based_receiver_1(**kwargs):
+            pass
+
+        def uid_based_receiver_2(**kwargs):
+            pass
+
+        a_signal.connect(uid_based_receiver_1, dispatch_uid="uid")
+        a_signal.connect(uid_based_receiver_2, dispatch_uid="uid")
+        self.assertEqual(len(a_signal.receivers), 1)
+        a_signal.disconnect(dispatch_uid="uid")
+        self._testIsClean(a_signal)
+
+    def testRobust(self):
+        """Test the sendRobust function"""
+
+        def fails(val, **kwargs):
+            raise ValueError('this')
+
+        a_signal.connect(fails)
+        result = a_signal.send_robust(sender=self, val="test")
+        err = result[0][1]
+        self.assertTrue(isinstance(err, ValueError))
+        self.assertEqual(err.args, ('this',))
+        a_signal.disconnect(fails)
+        self._testIsClean(a_signal)
+
+    def testDisconnection(self):
+        receiver_1 = Callable()
+        receiver_2 = Callable()
+        receiver_3 = Callable()
+        a_signal.connect(receiver_1)
+        a_signal.connect(receiver_2)
+        a_signal.connect(receiver_3)
+        a_signal.disconnect(receiver_1)
+        del receiver_2
+        garbage_collect()
+        a_signal.disconnect(receiver_3)
+        self._testIsClean(a_signal)

+ 11 - 0
celery/tests/utilities/test_imports.py

@@ -1,4 +1,5 @@
 from __future__ import absolute_import
+from __future__ import with_statement
 
 from mock import Mock, patch
 
@@ -7,6 +8,8 @@ from celery.utils.imports import (
     symbol_by_name,
     reload_from_cwd,
     module_file,
+    find_module,
+    NotAPackage,
 )
 
 from celery.tests.utils import Case
@@ -14,6 +17,13 @@ from celery.tests.utils import Case
 
 class test_import_utils(Case):
 
+    def test_find_module(self):
+        self.assertTrue(find_module("celery"))
+        imp = Mock()
+        imp.return_value = None
+        with self.assertRaises(NotAPackage):
+            find_module("foo.bar.baz", imp=imp)
+
     def test_qualname(self):
         Class = type("Fox", (object, ), {"__module__": "quick.brown"})
         self.assertEqual(qualname(Class), "quick.brown.Fox")
@@ -32,6 +42,7 @@ class test_import_utils(Case):
         from celery.worker import WorkController
         self.assertIs(symbol_by_name(".worker:WorkController",
                     package="celery"), WorkController)
+        self.assertTrue(symbol_by_name(":group", package="celery"))
 
     @patch("celery.utils.imports.reload")
     def test_reload_from_cwd(self, reload):

+ 79 - 0
celery/tests/utilities/test_saferef.py

@@ -0,0 +1,79 @@
+from __future__ import absolute_import
+
+from celery.utils.dispatch.saferef import safe_ref
+from celery.tests.utils import Case
+
+
+class Class1(object):
+
+    def x(self):
+        pass
+
+
+def fun(obj):
+    pass
+
+
+class Class2(object):
+
+    def __call__(self, obj):
+        pass
+
+
+class SaferefTests(Case):
+
+    def setUp(self):
+        ts = []
+        ss = []
+        for x in xrange(5000):
+            t = Class1()
+            ts.append(t)
+            s = safe_ref(t.x, self._closure)
+            ss.append(s)
+        ts.append(fun)
+        ss.append(safe_ref(fun, self._closure))
+        for x in xrange(30):
+            t = Class2()
+            ts.append(t)
+            s = safe_ref(t, self._closure)
+            ss.append(s)
+        self.ts = ts
+        self.ss = ss
+        self.closureCount = 0
+
+    def tearDown(self):
+        del self.ts
+        del self.ss
+
+    def testIn(self):
+        """Test the "in" operator for safe references (cmp)"""
+        for t in self.ts[:50]:
+            self.assertTrue(safe_ref(t.x) in self.ss)
+
+    def testValid(self):
+        """Test that the references are valid (return instance methods)"""
+        for s in self.ss:
+            self.assertTrue(s())
+
+    def testShortCircuit(self):
+        """Test that creation short-circuits to reuse existing references"""
+        sd = {}
+        for s in self.ss:
+            sd[s] = 1
+        for t in self.ts:
+            if hasattr(t, 'x'):
+                self.assertIn(safe_ref(t.x), sd)
+            else:
+                self.assertIn(safe_ref(t), sd)
+
+    def testRepresentation(self):
+        """Test that the reference object's representation works
+
+        XXX Doesn't currently check the results, just that no error
+            is raised
+        """
+        repr(self.ss[-1])
+
+    def _closure(self, ref):
+        """Dumb utility mechanism to increment deletion counter"""
+        self.closureCount += 1

+ 29 - 1
celery/tests/utilities/test_timeutils.py

@@ -1,8 +1,13 @@
 from __future__ import absolute_import
+from __future__ import with_statement
 
 from datetime import datetime, timedelta
 
+from mock import Mock
+
+from celery.exceptions import ImproperlyConfigured
 from celery.utils import timeutils
+from celery.utils.timeutils import timezone
 from celery.tests.utils import Case
 
 
@@ -54,10 +59,33 @@ class test_timeutils(Case):
         now = datetime.now()
         self.assertIs(timeutils.maybe_iso8601(now), now)
 
-    def test_maybe_timdelta(self):
+    def test_maybe_timedelta(self):
         D = timeutils.maybe_timedelta
 
         for i in (30, 30.6):
             self.assertEqual(D(i), timedelta(seconds=i))
 
         self.assertEqual(D(timedelta(days=2)), timedelta(days=2))
+
+    def test_remaining_relative(self):
+        timeutils.remaining(datetime.utcnow(), timedelta(hours=1),
+                relative=True)
+
+
+class test_timezone(Case):
+
+    def test_get_timezone_with_pytz(self):
+        prev, timeutils.pytz = timeutils.pytz, Mock()
+        try:
+            self.assertTrue(timezone.get_timezone("UTC"))
+        finally:
+            timeutils.pytz = prev
+
+    def test_get_timezone_without_pytz(self):
+        prev, timeutils.pytz = timeutils.pytz, None
+        try:
+            self.assertTrue(timezone.get_timezone("UTC"))
+            with self.assertRaises(ImproperlyConfigured):
+                timezone.get_timezone("Europe/Oslo")
+        finally:
+            timeutils.pytz = prev

+ 16 - 1
celery/tests/utilities/test_utils.py

@@ -3,10 +3,12 @@ from __future__ import with_statement
 
 from kombu.utils.functional import promise
 
+from mock import patch
+
 from celery import utils
 from celery.utils import text
 from celery.utils import functional
-from celery.utils.functional import mpromise
+from celery.utils.functional import mpromise, maybe_list
 from celery.utils.threads import bgThread
 from celery.tests.utils import Case
 
@@ -112,6 +114,9 @@ class test_utils(Case):
         self.assertEqual(text.abbrtask("feeds.tasks.refresh", 30),
                                         "feeds.tasks.refresh")
 
+    def test_pretty(self):
+        self.assertTrue(text.pretty(("a", "b", "c")))
+
     def test_cached_property(self):
 
         def fun(obj):
@@ -122,6 +127,16 @@ class test_utils(Case):
         self.assertIs(x.__set__(None, None), x)
         self.assertIs(x.__delete__(None), x)
 
+    def test_maybe_list(self):
+        self.assertEqual(maybe_list(1), [1])
+        self.assertEqual(maybe_list([1]), [1])
+        self.assertIsNone(maybe_list(None))
+
+    @patch("warnings.warn")
+    def test_warn_deprecated(self, warn):
+        utils.warn_deprecated("Foo")
+        self.assertTrue(warn.called)
+
 
 class test_mpromise(Case):
 

+ 19 - 0
celery/tests/utils.py

@@ -529,3 +529,22 @@ def mock_open(typ=WhateverIO, side_effect=None):
 
 def patch_many(*targets):
     return nested(*[mock.patch(target) for target in targets])
+
+
+@contextmanager
+def patch_settings(app=None, **config):
+    if app is None:
+        from celery import current_app
+        app = current_app
+    prev = {}
+    for key, value in config.iteritems():
+        try:
+            prev[key] = getattr(app.conf, key)
+        except AttributeError:
+            pass
+        setattr(app.conf, key, value)
+
+    yield app.conf
+
+    for key, value in prev.iteritems():
+        setattr(app.conf, key, value)

+ 2 - 4
celery/tests/worker/test_autoreload.py

@@ -72,18 +72,16 @@ class test_StatMonitor(Case):
             st_mtime = time()
         stat.return_value = st()
         x = StatMonitor(["a", "b"])
-        calls = [0]
 
         def on_is_set():
-            calls[0] += 1
-            if calls[0] > 2:
+            if x.shutdown_event.is_set.call_count > 3:
                 return True
             return False
         x.shutdown_event = Mock()
         x.shutdown_event.is_set.side_effect = on_is_set
 
         x.start()
-        calls[0] = 0
+        x.shutdown_event = Mock()
         stat.side_effect = OSError()
         x.start()
 

+ 8 - 13
celery/tests/worker/test_worker.py

@@ -172,12 +172,13 @@ class test_QoS(Case):
         qos = QoS(consumer, 10)
         qos.update()
         self.assertEqual(qos.value, 10)
-        self.assertIn({"prefetch_count": 10}, consumer.qos.call_args)
+        consumer.qos.assert_called_with(prefetch_count=10)
         qos.decrement()
         self.assertEqual(qos.value, 9)
-        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
+        consumer.qos.assert_called_with(prefetch_count=9)
         qos.decrement_eventually()
         self.assertEqual(qos.value, 8)
+        consumer.qos.assert_called_with(prefetch_count=9)
         self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
 
         # Does not decrement 0 value
@@ -675,17 +676,13 @@ class test_Consumer(Case):
     def test_open_connection_errback(self, sleep, connect):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                       send_events=False)
-        calls = [0]
         from kombu.transport.memory import Transport
         Transport.connection_errors = (StdChannelError, )
 
         def effect():
-            try:
-                if calls[0] > 1:
-                    return
-                raise StdChannelError()
-            finally:
-                calls[0] += 1
+            if connect.call_count > 1:
+                return
+            raise StdChannelError()
         connect.side_effect = effect
         l._open_connection()
         connect.assert_called_with()
@@ -811,10 +808,8 @@ class test_WorkController(AppCase):
         app = Celery(loader=loader, set_as_current=False)
         app.conf = AttributeDict(DEFAULTS)
         process_initializer(app, "awesome.worker.com")
-        self.assertIn((tuple(WORKER_SIGIGNORE), {}),
-                      _signals.ignore.call_args_list)
-        self.assertIn((tuple(WORKER_SIGRESET), {}),
-                      _signals.reset.call_args_list)
+        _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
+        _signals.reset.assert_any_call(*WORKER_SIGRESET)
         self.assertTrue(app.loader.init_worker.call_count)
         self.assertTrue(on_worker_process_init.called)
         self.assertIs(_tls.current_app, app)

+ 19 - 19
celery/utils/compat.py

@@ -17,21 +17,21 @@ import sys
 is_py3k = sys.version_info[0] == 3
 
 try:
-    reload = reload                     # noqa
-except NameError:
-    from imp import reload              # noqa
+    reload = reload                         # noqa
+except NameError:                           # pragma: no cover
+    from imp import reload                  # noqa
 
 try:
-    from UserList import UserList       # noqa
-except ImportError:
-    from collections import UserList    # noqa
+    from UserList import UserList           # noqa
+except ImportError:                         # pragma: no cover
+    from collections import UserList        # noqa
 
 try:
-    from UserDict import UserDict       # noqa
-except ImportError:
-    from collections import UserDict    # noqa
+    from UserDict import UserDict           # noqa
+except ImportError:                         # pragma: no cover
+    from collections import UserDict        # noqa
 
-if is_py3k:
+if is_py3k:                                 # pragma: no cover
     from io import StringIO, BytesIO
     from .encoding import bytes_to_str
 
@@ -40,21 +40,21 @@ if is_py3k:
         def write(self, data):
             StringIO.write(self, bytes_to_str(data))
 else:
-    from StringIO import StringIO       # noqa
-    BytesIO = WhateverIO = StringIO     # noqa
+    from StringIO import StringIO           # noqa
+    BytesIO = WhateverIO = StringIO         # noqa
 
 
 ############## collections.OrderedDict ######################################
 try:
     from collections import OrderedDict
-except ImportError:
-    from ordereddict import OrderedDict  # noqa
+except ImportError:                         # pragma: no cover
+    from ordereddict import OrderedDict     # noqa
 
 ############## itertools.zip_longest #######################################
 
 try:
     from itertools import izip_longest as zip_longest
-except ImportError:
+except ImportError:                         # pragma: no cover
     import itertools
 
     def zip_longest(*args, **kwds):  # noqa
@@ -77,14 +77,14 @@ except ImportError:
 from itertools import chain
 
 
-def _compat_chain_from_iterable(iterables):
+def _compat_chain_from_iterable(iterables):  # pragma: no cover
     for it in iterables:
         for element in it:
             yield element
 
 try:
     chain_from_iterable = getattr(chain, "from_iterable")
-except AttributeError:
+except AttributeError:   # pragma: no cover
     chain_from_iterable = _compat_chain_from_iterable
 
 
@@ -94,13 +94,13 @@ import os
 from stat import ST_DEV, ST_INO
 import platform as _platform
 
-if _platform.system() == "Windows":
+if _platform.system() == "Windows":  # pragma: no cover
     #since windows doesn't go with WatchedFileHandler use FileHandler instead
     WatchedFileHandler = logging.FileHandler
 else:
     try:
         from logging.handlers import WatchedFileHandler
-    except ImportError:
+    except ImportError:  # pragma: no cover
         class WatchedFileHandler(logging.FileHandler):  # noqa
             """
             A handler for logging to a file, which watches the file

+ 4 - 4
celery/utils/dispatch/saferef.py

@@ -11,7 +11,7 @@ import weakref
 import traceback
 
 
-def safe_ref(target, on_delete=None):
+def safe_ref(target, on_delete=None):  # pragma: no cover
     """Return a *safe* weak reference to a callable target
 
     :param target: the object to be weakly referenced, if it's a
@@ -37,7 +37,7 @@ def safe_ref(target, on_delete=None):
         return weakref.ref(target)
 
 
-class BoundMethodWeakref(object):
+class BoundMethodWeakref(object):  # pragma: no cover
     """'Safe' and reusable weak references to instance methods.
 
     BoundMethodWeakref objects provide a mechanism for
@@ -199,7 +199,7 @@ class BoundMethodWeakref(object):
                 return function.__get__(target)
 
 
-class BoundNonDescriptorMethodWeakref(BoundMethodWeakref):
+class BoundNonDescriptorMethodWeakref(BoundMethodWeakref):  # pragma: no cover
     """A specialized :class:`BoundMethodWeakref`, for platforms where
     instance methods are not descriptors.
 
@@ -269,7 +269,7 @@ class BoundNonDescriptorMethodWeakref(BoundMethodWeakref):
                 return getattr(target, function.__name__)
 
 
-def get_bound_method_weakref(target, on_delete):
+def get_bound_method_weakref(target, on_delete):  # pragma: no cover
     """Instantiates the appropiate :class:`BoundMethodWeakRef`, depending
     on the details of the underlying class method implementation."""
     if hasattr(target, '__get__'):

+ 2 - 7
celery/utils/dispatch/signal.py

@@ -3,23 +3,18 @@
 from __future__ import absolute_import
 
 import weakref
-try:
-    set
-except NameError:
-    from sets import Set as set                 # Python 2.3 fallback
-
 from . import saferef
 
 WEAKREF_TYPES = (weakref.ReferenceType, saferef.BoundMethodWeakref)
 
 
-def _make_id(target):
+def _make_id(target):  # pragma: no cover
     if hasattr(target, 'im_func'):
         return (id(target.im_self), id(target.im_func))
     return id(target)
 
 
-class Signal(object):
+class Signal(object):  # pragma: no cover
     """Base class for all signals
 
 

+ 4 - 4
celery/utils/functional.py

@@ -20,9 +20,9 @@ from threading import Lock, RLock
 
 try:
     from collections import Sequence
-except ImportError:
+except ImportError:             # pragma: no cover
     # <= Py2.5
-    Sequence = (list, tuple)  # noqa
+    Sequence = (list, tuple)    # noqa
 
 from kombu.utils.functional import promise, maybe_promise
 
@@ -76,7 +76,7 @@ class LRUCache(UserDict):
         for k in self:
             try:
                 yield (k, self.data[k])
-            except KeyError:
+            except KeyError:  # pragma: no cover
                 pass
     iteritems = _iterate_items
 
@@ -100,7 +100,7 @@ class LRUCache(UserDict):
 def maybe_list(l):
     if l is None:
         return l
-    elif isinstance(l, Sequence):
+    elif not isinstance(l, basestring) and isinstance(l, Sequence):
         return l
     return [l]
 

+ 2 - 2
celery/utils/log.py

@@ -45,8 +45,8 @@ class ColorFormatter(logging.Formatter):
     colors = {"DEBUG": COLORS["blue"], "WARNING": COLORS["yellow"],
               "ERROR": COLORS["red"], "CRITICAL": COLORS["magenta"]}
 
-    def __init__(self, msg, use_color=True):
-        logging.Formatter.__init__(self, msg)
+    def __init__(self, fmt=None, use_color=True):
+        logging.Formatter.__init__(self, fmt)
         self.use_color = use_color
 
     def formatException(self, ei):

+ 2 - 4
celery/utils/mail.py

@@ -18,6 +18,7 @@ import warnings
 
 from email.mime.text import MIMEText
 
+from .functional import maybe_list
 from .imports import symbol_by_name
 
 supports_timeout = sys.version_info >= (2, 6)
@@ -31,15 +32,12 @@ class Message(object):
 
     def __init__(self, to=None, sender=None, subject=None, body=None,
             charset="us-ascii"):
-        self.to = to
+        self.to = maybe_list(to)
         self.sender = sender
         self.subject = subject
         self.body = body
         self.charset = charset
 
-        if not isinstance(self.to, (list, tuple)):
-            self.to = [self.to]
-
     def __repr__(self):
         return "<Email: To:%r Subject:%r>" % (self.to, self.subject)
 

+ 1 - 1
celery/utils/serialization.py

@@ -74,7 +74,7 @@ def find_nearest_pickleable_exception(exc):
     getmro_ = getattr(cls, "mro", None)
 
     # old-style classes doesn't have mro()
-    if not getmro_:
+    if not getmro_:  # pragma: no cover
         # all Py2.4 exceptions has a baseclass.
         if not getattr(cls, "__bases__", ()):
             return

+ 0 - 1
celery/utils/timeutils.py

@@ -71,7 +71,6 @@ class _Zone(object):
     @cached_property
     def utc(self):
         return self.get_timezone("UTC")
-
 timezone = _Zone()
 
 

+ 2 - 1
contrib/release/doc4allmods

@@ -2,7 +2,8 @@
 
 PACKAGE="$1"
 SKIP_PACKAGES="$PACKAGE tests management urls"
-SKIP_FILES="celery.backends.pyredis.rst
+SKIP_FILES="celery.__compat__.rst
+            celery.backends.pyredis.rst
             celery.bin.rst
             celery.bin.celeryd_detach.rst
             celery.concurrency.processes._win.rst

+ 3 - 6
contrib/release/py3k-run-tests

@@ -8,9 +8,6 @@ nosetests -vd celery.tests                                      \
             --cover3-html                                       \
             --cover3-html-dir="$base/cover"                     \
             --cover3-package=celery                             \
-            --cover3-exclude="                                  \
-              celery.tests.*                                    \
-              celery.utils.compat                               \
-              celery.utils.dispatch*"                           \
-            --with-xunit                                        \
-              --xunit-file="$base/nosetests.xml"
+            --cover3-exclude="celery.tests.*"                   \
+          --with-xunit                                          \
+            --xunit-file="$base/nosetests.xml"

+ 1 - 1
requirements/test.txt

@@ -2,7 +2,7 @@ unittest2>=0.4.0
 nose
 nose-cover3
 coverage>=3.0
-mock>=0.7.0
+mock==dev
 redis
 pymongo
 SQLAlchemy

+ 0 - 2
setup.cfg

@@ -4,8 +4,6 @@ cover3-branch = 1
 cover3-html = 1
 cover3-package = celery
 cover3-exclude = celery.tests.*
-                 celery.utils.compat
-                 celery.utils.dispatch*
 
 [build_sphinx]
 source-dir = docs/

+ 1 - 1
setup.py

@@ -126,7 +126,7 @@ elif py_version[0:2] == (2, 5):
 
 # -*- Tests Requires -*-
 
-tests_require = ["nose", "nose-cover3", "sqlalchemy", "mock"]
+tests_require = ["nose", "nose-cover3", "sqlalchemy", "mock==dev"]
 if sys.version_info < (2, 7):
     tests_require.append("unittest2")
 elif sys.version_info <= (2, 5):