Przeglądaj źródła

99% overall coverage :happy:

Ask Solem 13 lat temu
rodzic
commit
53b61c638b
58 zmienionych plików z 1071 dodań i 195 usunięć
  1. 2 1
      celery/app/task.py
  2. 9 7
      celery/apps/worker.py
  3. 1 1
      celery/backends/base.py
  4. 1 0
      celery/backends/redis.py
  5. 30 35
      celery/beat.py
  6. 1 1
      celery/bin/base.py
  7. 1 1
      celery/bin/celeryd.py
  8. 1 1
      celery/events/state.py
  9. 1 1
      celery/result.py
  10. 1 1
      celery/task/http.py
  11. 6 1
      celery/tests/__init__.py
  12. 31 0
      celery/tests/app/test_amqp.py
  13. 5 1
      celery/tests/app/test_app.py
  14. 129 18
      celery/tests/app/test_beat.py
  15. 15 0
      celery/tests/app/test_control.py
  16. 12 0
      celery/tests/app/test_defaults.py
  17. 15 3
      celery/tests/app/test_loaders.py
  18. 39 6
      celery/tests/app/test_log.py
  19. 4 0
      celery/tests/backends/test_amqp.py
  20. 24 0
      celery/tests/backends/test_base.py
  21. 7 1
      celery/tests/backends/test_database.py
  22. 2 0
      celery/tests/backends/test_pyredis_compat.py
  23. 33 0
      celery/tests/backends/test_redis.py
  24. 31 1
      celery/tests/bin/test_base.py
  25. 6 0
      celery/tests/bin/test_celerybeat.py
  26. 19 1
      celery/tests/bin/test_celeryd.py
  27. 2 4
      celery/tests/bin/test_celeryd_multi.py
  28. 21 0
      celery/tests/bin/test_celeryev.py
  29. 16 0
      celery/tests/concurrency/test_concurrency.py
  30. 19 8
      celery/tests/concurrency/test_processes.py
  31. 36 15
      celery/tests/events/test_events.py
  32. 18 3
      celery/tests/events/test_snapshot.py
  33. 5 0
      celery/tests/events/test_state.py
  34. 6 10
      celery/tests/slow/test_buckets.py
  35. 8 0
      celery/tests/tasks/test_registry.py
  36. 136 1
      celery/tests/tasks/test_result.py
  37. 34 1
      celery/tests/tasks/test_tasks.py
  38. 139 0
      celery/tests/utilities/test_dispatcher.py
  39. 11 0
      celery/tests/utilities/test_imports.py
  40. 79 0
      celery/tests/utilities/test_saferef.py
  41. 29 1
      celery/tests/utilities/test_timeutils.py
  42. 16 1
      celery/tests/utilities/test_utils.py
  43. 19 0
      celery/tests/utils.py
  44. 2 4
      celery/tests/worker/test_autoreload.py
  45. 8 13
      celery/tests/worker/test_worker.py
  46. 19 19
      celery/utils/compat.py
  47. 4 4
      celery/utils/dispatch/saferef.py
  48. 2 7
      celery/utils/dispatch/signal.py
  49. 4 4
      celery/utils/functional.py
  50. 2 2
      celery/utils/log.py
  51. 2 4
      celery/utils/mail.py
  52. 1 1
      celery/utils/serialization.py
  53. 0 1
      celery/utils/timeutils.py
  54. 2 1
      contrib/release/doc4allmods
  55. 3 6
      contrib/release/py3k-run-tests
  56. 1 1
      requirements/test.txt
  57. 0 2
      setup.cfg
  58. 1 1
      setup.py

+ 2 - 1
celery/app/task.py

@@ -817,7 +817,8 @@ class BaseTask(object):
 
 
     def annotate(self):
     def annotate(self):
         for d in resolve_all_annotations(self.app.annotations, self):
         for d in resolve_all_annotations(self.app.annotations, self):
-            self.__dict__.update(d)
+            for key, value in d.iteritems():
+                setattr(self, key, value)
 
 
     def __repr__(self):
     def __repr__(self):
         """`repr(task)`"""
         """`repr(task)`"""

+ 9 - 7
celery/apps/worker.py

@@ -25,7 +25,7 @@ from celery.worker import WorkController
 try:
 try:
     from greenlet import GreenletExit
     from greenlet import GreenletExit
     IGNORE_ERRORS = (GreenletExit, )
     IGNORE_ERRORS = (GreenletExit, )
-except ImportError:
+except ImportError:  # pragma: no cover
     IGNORE_ERRORS = ()
     IGNORE_ERRORS = ()
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
@@ -302,15 +302,17 @@ def install_cry_handler():
     # Jython/PyPy does not have sys._current_frames
     # Jython/PyPy does not have sys._current_frames
     is_jython = sys.platform.startswith("java")
     is_jython = sys.platform.startswith("java")
     is_pypy = hasattr(sys, "pypy_version_info")
     is_pypy = hasattr(sys, "pypy_version_info")
-    if not (is_jython or is_pypy):
+    if is_jython or is_pypy:  # pragma: no cover
+        return
 
 
-        def cry_handler(signum, frame):
-            """Signal handler logging the stacktrace of all active threads."""
-            logger.error("\n" + cry())
-        platforms.signals["SIGUSR1"] = cry_handler
+    def cry_handler(signum, frame):
+        """Signal handler logging the stacktrace of all active threads."""
+        logger.error("\n" + cry())
+    platforms.signals["SIGUSR1"] = cry_handler
 
 
 
 
-def install_rdb_handler(envvar="CELERY_RDBSIG", sig="SIGUSR2"):
+def install_rdb_handler(envvar="CELERY_RDBSIG",
+                        sig="SIGUSR2"):  # pragma: no cover
 
 
     def rdb_handler(signum, frame):
     def rdb_handler(signum, frame):
         """Signal handler setting a rdb breakpoint at the current frame."""
         """Signal handler setting a rdb breakpoint at the current frame."""

+ 1 - 1
celery/backends/base.py

@@ -120,7 +120,7 @@ class BaseBackend(object):
         if self.serializer in EXCEPTION_ABLE_CODECS:
         if self.serializer in EXCEPTION_ABLE_CODECS:
             return get_pickled_exception(exc)
             return get_pickled_exception(exc)
         return create_exception_cls(from_utf8(exc["exc_type"]),
         return create_exception_cls(from_utf8(exc["exc_type"]),
-                                    sys.modules[__name__])
+                                    sys.modules[__name__])(exc["exc_message"])
 
 
     def prepare_value(self, result):
     def prepare_value(self, result):
         """Prepare value for storage."""
         """Prepare value for storage."""

+ 1 - 0
celery/backends/redis.py

@@ -62,6 +62,7 @@ class RedisBackend(KeyValueStoreBackend):
         uhost = uport = upass = udb = None
         uhost = uport = upass = udb = None
         if url:
         if url:
             _, uhost, uport, _, upass, udb, _ = _parse_url(url)
             _, uhost, uport, _, upass, udb, _ = _parse_url(url)
+            udb = udb.strip("/")
         self.host = uhost or host or _get("HOST") or self.host
         self.host = uhost or host or _get("HOST") or self.host
         self.port = int(uport or port or _get("PORT") or self.port)
         self.port = int(uport or port or _get("PORT") or self.port)
         self.db = udb or db or _get("DB") or self.db
         self.db = udb or db or _get("DB") or self.db

+ 30 - 35
celery/beat.py

@@ -16,10 +16,9 @@ import os
 import time
 import time
 import shelve
 import shelve
 import sys
 import sys
-import threading
 import traceback
 import traceback
 
 
-from billiard import Process
+from billiard import Process, ensure_multiprocessing
 from kombu.utils import reprcall
 from kombu.utils import reprcall
 from kombu.utils.functional import maybe_promise
 from kombu.utils.functional import maybe_promise
 
 
@@ -31,6 +30,7 @@ from .app import app_or_default
 from .schedules import maybe_schedule, crontab
 from .schedules import maybe_schedule, crontab
 from .utils import cached_property
 from .utils import cached_property
 from .utils.imports import instantiate
 from .utils.imports import instantiate
+from .utils.threads import Event, Thread
 from .utils.timeutils import humanize_seconds
 from .utils.timeutils import humanize_seconds
 from .utils.log import get_logger
 from .utils.log import get_logger
 
 
@@ -229,12 +229,12 @@ class Scheduler(object):
             raise SchedulingError, SchedulingError(
             raise SchedulingError, SchedulingError(
                 "Couldn't apply scheduled task %s: %s" % (
                 "Couldn't apply scheduled task %s: %s" % (
                     entry.name, exc)), sys.exc_info()[2]
                     entry.name, exc)), sys.exc_info()[2]
-
-        if self.should_sync():
-            self._do_sync()
+        finally:
+            if self.should_sync():
+                self._do_sync()
         return result
         return result
 
 
-    def send_task(self, *args, **kwargs):               # pragma: no cover
+    def send_task(self, *args, **kwargs):
         return self.app.send_task(*args, **kwargs)
         return self.app.send_task(*args, **kwargs)
 
 
     def setup_schedule(self):
     def setup_schedule(self):
@@ -283,12 +283,6 @@ class Scheduler(object):
             else:
             else:
                 schedule[key] = entry
                 schedule[key] = entry
 
 
-    def get_schedule(self):
-        return self.data
-
-    def set_schedule(self, schedule):
-        self.data = schedule
-
     def _ensure_connected(self):
     def _ensure_connected(self):
         # callback called for each retry while the connection
         # callback called for each retry while the connection
         # can't be established.
         # can't be established.
@@ -299,6 +293,13 @@ class Scheduler(object):
         return self.connection.ensure_connection(_error_handler,
         return self.connection.ensure_connection(_error_handler,
                     self.app.conf.BROKER_CONNECTION_MAX_RETRIES)
                     self.app.conf.BROKER_CONNECTION_MAX_RETRIES)
 
 
+    def get_schedule(self):
+        return self.data
+
+    def set_schedule(self, schedule):
+        self.data = schedule
+    schedule = property(get_schedule, set_schedule)
+
     @cached_property
     @cached_property
     def connection(self):
     def connection(self):
         return self.app.broker_connection()
         return self.app.broker_connection()
@@ -307,10 +308,6 @@ class Scheduler(object):
     def publisher(self):
     def publisher(self):
         return self.Publisher(connection=self._ensure_connected())
         return self.Publisher(connection=self._ensure_connected())
 
 
-    @property
-    def schedule(self):
-        return self.get_schedule()
-
     @property
     @property
     def info(self):
     def info(self):
         return ""
         return ""
@@ -318,6 +315,7 @@ class Scheduler(object):
 
 
 class PersistentScheduler(Scheduler):
 class PersistentScheduler(Scheduler):
     persistence = shelve
     persistence = shelve
+    known_suffixes = ("", ".db", ".dat", ".bak", ".dir")
 
 
     _store = None
     _store = None
 
 
@@ -326,7 +324,7 @@ class PersistentScheduler(Scheduler):
         Scheduler.__init__(self, *args, **kwargs)
         Scheduler.__init__(self, *args, **kwargs)
 
 
     def _remove_db(self):
     def _remove_db(self):
-        for suffix in "", ".db", ".dat", ".bak", ".dir":
+        for suffix in self.known_suffixes:
             try:
             try:
                 os.remove(self.schedule_filename + suffix)
                 os.remove(self.schedule_filename + suffix)
             except OSError, exc:
             except OSError, exc:
@@ -358,6 +356,10 @@ class PersistentScheduler(Scheduler):
     def get_schedule(self):
     def get_schedule(self):
         return self._store["entries"]
         return self._store["entries"]
 
 
+    def set_schedule(self, schedule):
+        self._store["entries"] = schedule
+    schedule = property(get_schedule, set_schedule)
+
     def sync(self):
     def sync(self):
         if self._store is not None:
         if self._store is not None:
             self._store.sync()
             self._store.sync()
@@ -383,8 +385,8 @@ class Service(object):
         self.schedule_filename = schedule_filename or \
         self.schedule_filename = schedule_filename or \
                                     app.conf.CELERYBEAT_SCHEDULE_FILENAME
                                     app.conf.CELERYBEAT_SCHEDULE_FILENAME
 
 
-        self._is_shutdown = threading.Event()
-        self._is_stopped = threading.Event()
+        self._is_shutdown = Event()
+        self._is_stopped = Event()
 
 
     def start(self, embedded_process=False):
     def start(self, embedded_process=False):
         info("Celerybeat: Starting...")
         info("Celerybeat: Starting...")
@@ -397,7 +399,7 @@ class Service(object):
             platforms.set_process_title("celerybeat")
             platforms.set_process_title("celerybeat")
 
 
         try:
         try:
-            while not self._is_shutdown.isSet():
+            while not self._is_shutdown.is_set():
                 interval = self.scheduler.tick()
                 interval = self.scheduler.tick()
                 debug("Celerybeat: Waking up %s.",
                 debug("Celerybeat: Waking up %s.",
                       humanize_seconds(interval, prefix="in "))
                       humanize_seconds(interval, prefix="in "))
@@ -430,14 +432,14 @@ class Service(object):
         return self.get_scheduler()
         return self.get_scheduler()
 
 
 
 
-class _Threaded(threading.Thread):
+class _Threaded(Thread):
     """Embedded task scheduler using threading."""
     """Embedded task scheduler using threading."""
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         super(_Threaded, self).__init__()
         super(_Threaded, self).__init__()
         self.service = Service(*args, **kwargs)
         self.service = Service(*args, **kwargs)
-        self.setDaemon(True)
-        self.setName("Beat")
+        self.daemon = True
+        self.name = "Beat"
 
 
     def run(self):
     def run(self):
         self.service.start()
         self.service.start()
@@ -446,16 +448,12 @@ class _Threaded(threading.Thread):
         self.service.stop(wait=True)
         self.service.stop(wait=True)
 
 
 
 
-supports_fork = True
 try:
 try:
-    from billiard._ext import _billiard
-    supports_fork = True if _billiard else False
-except ImportError:
-    supports_fork = False
-
-if supports_fork:
-    class _Process(Process):
-        """Embedded task scheduler using multiprocessing."""
+    ensure_multiprocessing()
+except NotImplementedError:     # pragma: no cover
+    _Process = None
+else:
+    class _Process(Process):    # noqa
 
 
         def __init__(self, *args, **kwargs):
         def __init__(self, *args, **kwargs):
             super(_Process, self).__init__()
             super(_Process, self).__init__()
@@ -469,8 +467,6 @@ if supports_fork:
         def stop(self):
         def stop(self):
             self.service.stop()
             self.service.stop()
             self.terminate()
             self.terminate()
-else:
-    _Process = None
 
 
 
 
 def EmbeddedService(*args, **kwargs):
 def EmbeddedService(*args, **kwargs):
@@ -485,5 +481,4 @@ def EmbeddedService(*args, **kwargs):
         # in reasonable time.
         # in reasonable time.
         kwargs.setdefault("max_interval", 1)
         kwargs.setdefault("max_interval", 1)
         return _Threaded(*args, **kwargs)
         return _Threaded(*args, **kwargs)
-
     return _Process(*args, **kwargs)
     return _Process(*args, **kwargs)

+ 1 - 1
celery/bin/base.py

@@ -139,7 +139,7 @@ class Command(object):
         # Don't want to load configuration to just print the version,
         # Don't want to load configuration to just print the version,
         # so we handle --version manually here.
         # so we handle --version manually here.
         if "--version" in arguments:
         if "--version" in arguments:
-            print(self.version)
+            sys.stdout.write("%s\n" % self.version)
             sys.exit(0)
             sys.exit(0)
         parser = self.create_parser(prog_name)
         parser = self.create_parser(prog_name)
         return parser.parse_args(arguments)
         return parser.parse_args(arguments)

+ 1 - 1
celery/bin/celeryd.py

@@ -188,7 +188,7 @@ def main():
     # Fix for setuptools generated scripts, so that it will
     # Fix for setuptools generated scripts, so that it will
     # work with multiprocessing fork emulation.
     # work with multiprocessing fork emulation.
     # (see multiprocessing.forking.get_preparation_data())
     # (see multiprocessing.forking.get_preparation_data())
-    if __name__ != "__main__":
+    if __name__ != "__main__":  # pragma: no cover
         sys.modules["__main__"] = sys.modules[__name__]
         sys.modules["__main__"] = sys.modules[__name__]
     freeze_support()
     freeze_support()
     worker = WorkerCommand()
     worker = WorkerCommand()

+ 1 - 1
celery/events/state.py

@@ -299,7 +299,7 @@ class State(object):
     def itertasks(self, limit=None):
     def itertasks(self, limit=None):
         for index, row in enumerate(self.tasks.iteritems()):
         for index, row in enumerate(self.tasks.iteritems()):
             yield row
             yield row
-            if limit and index >= limit:
+            if limit and index + 1 >= limit:
                 break
                 break
 
 
     def tasks_by_timestamp(self, limit=None):
     def tasks_by_timestamp(self, limit=None):

+ 1 - 1
celery/result.py

@@ -632,7 +632,7 @@ class TaskSetResult(ResultSet):
         return self.id
         return self.id
 
 
     def _set_taskset_id(self, id):
     def _set_taskset_id(self, id):
-        self.taskset_id = id
+        self.id = id
     taskset_id = property(_get_taskset_id, _set_taskset_id)
     taskset_id = property(_get_taskset_id, _set_taskset_id)
 
 
 
 

+ 1 - 1
celery/task/http.py

@@ -47,7 +47,7 @@ def maybe_utf8(value):
     return value
     return value
 
 
 
 
-if sys.version_info >= (3, 0):
+if sys.version_info[0] == 3:  # pragma: no cover
 
 
     def utf8dict(tup):
     def utf8dict(tup):
         if not isinstance(tup, dict):
         if not isinstance(tup, dict):

+ 6 - 1
celery/tests/__init__.py

@@ -1,8 +1,10 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
+from __future__ import with_statement
 
 
 import logging
 import logging
 import os
 import os
 import sys
 import sys
+import warnings
 
 
 from importlib import import_module
 from importlib import import_module
 
 
@@ -77,4 +79,7 @@ def import_all_modules(name=__name__, file=__file__,
 
 
 
 
 if os.environ.get("COVER_ALL_MODULES") or "--with-coverage3" in sys.argv:
 if os.environ.get("COVER_ALL_MODULES") or "--with-coverage3" in sys.argv:
-    import_all_modules()
+    from celery.tests.utils import catch_warnings
+    with catch_warnings(record=True):
+        import_all_modules()
+    warnings.resetwarnings()

+ 31 - 0
celery/tests/app/test_amqp.py

@@ -29,6 +29,21 @@ class test_TaskPublisher(AppCase):
             pass
             pass
         publisher.release.assert_called_with()
         publisher.release.assert_called_with()
 
 
+    def test_declare(self):
+        publisher = self.app.amqp.TaskPublisher(self.app.broker_connection())
+        publisher.exchange.name = "foo"
+        publisher.declare()
+        publisher.exchange.name = None
+        publisher.declare()
+
+    def test_exit_AttributeError(self):
+        publisher = self.app.amqp.TaskPublisher(self.app.broker_connection())
+        publisher.close = Mock()
+        publisher.release = Mock()
+        publisher.release.side_effect = AttributeError()
+        publisher.__exit__()
+        publisher.close.assert_called_with()
+
     def test_ensure_declare_queue(self, q="x1242112"):
     def test_ensure_declare_queue(self, q="x1242112"):
         publisher = self.app.amqp.TaskPublisher(Mock())
         publisher = self.app.amqp.TaskPublisher(Mock())
         self.app.amqp.queues.add(q, q, q)
         self.app.amqp.queues.add(q, q, q)
@@ -103,3 +118,19 @@ class test_PublisherPool(AppCase):
             r2.release()
             r2.release()
         finally:
         finally:
             self.app.conf.BROKER_POOL_LIMIT = L
             self.app.conf.BROKER_POOL_LIMIT = L
+
+
+class test_Queues(AppCase):
+
+    def test_queues_format(self):
+        prev, self.app.amqp.queues._consume_from = \
+                self.app.amqp.queues._consume_from, {}
+        try:
+            self.assertEqual(self.app.amqp.queues.format(), "")
+        finally:
+            self.app.amqp.queues._consume_from = prev
+
+    def test_with_defaults(self):
+        self.assertEqual(
+            self.app.amqp.queues.with_defaults(None, "celery", "direct"),
+            {})

+ 5 - 1
celery/tests/app/test_app.py

@@ -164,7 +164,11 @@ class test_App(Case):
         self.assertEqual(self.app.conf.BROKER_TRANSPORT, "set_by_us")
         self.assertEqual(self.app.conf.BROKER_TRANSPORT, "set_by_us")
 
 
     def test_WorkController(self):
     def test_WorkController(self):
-        x = self.app.Worker()
+        x = self.app.WorkController
+        self.assertIs(x.app, self.app)
+
+    def test_Worker(self):
+        x = self.app.Worker
         self.assertIs(x.app, self.app)
         self.assertIs(x.app, self.app)
 
 
     def test_AsyncResult(self):
     def test_AsyncResult(self):

+ 129 - 18
celery/tests/app/test_beat.py

@@ -1,15 +1,19 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
+from __future__ import with_statement
+
+import errno
 
 
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
-from mock import patch
+from mock import Mock, call, patch
 from nose import SkipTest
 from nose import SkipTest
 
 
 from celery import beat
 from celery import beat
+from celery import task
 from celery.result import AsyncResult
 from celery.result import AsyncResult
 from celery.schedules import schedule
 from celery.schedules import schedule
 from celery.task.base import Task
 from celery.task.base import Task
 from celery.utils import uuid
 from celery.utils import uuid
-from celery.tests.utils import Case
+from celery.tests.utils import Case, patch_settings
 
 
 
 
 class Object(object):
 class Object(object):
@@ -159,10 +163,69 @@ class test_Scheduler(Case):
         scheduler.apply_async(scheduler.Entry(task=MockTask.name))
         scheduler.apply_async(scheduler.Entry(task=MockTask.name))
         self.assertTrue(through_task[0])
         self.assertTrue(through_task[0])
 
 
+    def test_apply_async_should_not_sync(self):
+
+        @task
+        def not_sync():
+            pass
+        not_sync.apply_async = Mock()
+
+        s = mScheduler()
+        s._do_sync = Mock()
+        s.should_sync = Mock()
+        s.should_sync.return_value = True
+        s.apply_async(s.Entry(task=not_sync.name))
+        s._do_sync.assert_called_with()
+
+        s._do_sync = Mock()
+        s.should_sync.return_value = False
+        s.apply_async(s.Entry(task=not_sync.name))
+        self.assertFalse(s._do_sync.called)
+
+    @patch("celery.app.base.Celery.send_task")
+    def test_send_task(self, send_task):
+        b = beat.Scheduler()
+        b.send_task("tasks.add", countdown=10)
+        send_task.assert_called_with("tasks.add", countdown=10)
+
     def test_info(self):
     def test_info(self):
         scheduler = mScheduler()
         scheduler = mScheduler()
         self.assertIsInstance(scheduler.info, basestring)
         self.assertIsInstance(scheduler.info, basestring)
 
 
+    def test_maybe_entry(self):
+        s = mScheduler()
+        entry = s.Entry(name="add every", task="tasks.add")
+        self.assertIs(s._maybe_entry(entry.name, entry), entry)
+        self.assertTrue(s._maybe_entry("add every", {
+            "task": "tasks.add",
+        }))
+
+    def test_set_schedule(self):
+        s = mScheduler()
+        s.schedule = {"foo": "bar"}
+        self.assertEqual(s.data, {"foo": "bar"})
+
+    @patch("kombu.connection.Connection.ensure_connection")
+    def test_ensure_connection_error_handler(self, ensure):
+        s = mScheduler()
+        self.assertTrue(s._ensure_connected())
+        self.assertTrue(ensure.called)
+        callback = ensure.call_args[0][0]
+
+        callback(KeyError(), 5)
+
+    def test_install_default_entries(self):
+        with patch_settings(CELERY_TASK_RESULT_EXPIRES=None,
+                            CELERYBEAT_SCHEDULE={}):
+            s = mScheduler()
+            s.install_default_entries({})
+            self.assertNotIn("celery.backend_cleanup", s.data)
+        with patch_settings(CELERY_TASK_RESULT_EXPIRES=30,
+                            CELERYBEAT_SCHEDULE={}):
+            s = mScheduler()
+            s.install_default_entries({})
+            self.assertIn("celery.backend_cleanup", s.data)
+
     def test_due_tick(self):
     def test_due_tick(self):
         scheduler = mScheduler()
         scheduler = mScheduler()
         scheduler.add(name="test_due_tick",
         scheduler.add(name="test_due_tick",
@@ -233,25 +296,73 @@ class test_Scheduler(Case):
         self.assertEqual(a.schedule["bar"].schedule._next_run_at, 40)
         self.assertEqual(a.schedule["bar"].schedule._next_run_at, 40)
 
 
 
 
+def create_persistent_scheduler(shelv=None):
+    if shelv is None:
+        shelv = MockShelve()
+
+    class MockPersistentScheduler(beat.PersistentScheduler):
+        sh = shelv
+        persistence = Object()
+        persistence.open = lambda *a, **kw: shelv
+        tick_raises_exit = False
+        shutdown_service = None
+
+        def tick(self):
+            if self.tick_raises_exit:
+                raise SystemExit()
+            if self.shutdown_service:
+                self.shutdown_service._is_shutdown.set()
+            return 0.0
+
+    return MockPersistentScheduler, shelv
+
+
+class test_PersistentScheduler(Case):
+
+    @patch("os.remove")
+    def test_remove_db(self, remove):
+        s = create_persistent_scheduler()[0](schedule_filename="schedule")
+        s._remove_db()
+        remove.assert_has_calls(
+            [call("schedule" + suffix) for suffix in s.known_suffixes]
+        )
+        err = OSError()
+        err.errno = errno.ENOENT
+        remove.side_effect = err
+        s._remove_db()
+        err.errno = errno.EPERM
+        with self.assertRaises(OSError):
+            s._remove_db()
+
+    def test_setup_schedule(self):
+        s = create_persistent_scheduler()[0](schedule_filename="schedule")
+        opens = s.persistence.open = Mock()
+        s._remove_db = Mock()
+
+        def effect(*args, **kwargs):
+            if opens.call_count > 1:
+                return s.sh
+            raise OSError()
+        opens.side_effect = effect
+        s.setup_schedule()
+        s._remove_db.assert_called_with()
+
+        s._store = {"__version__": 1}
+        s.setup_schedule()
+
+    def test_get_schedule(self):
+        s = create_persistent_scheduler()[0](schedule_filename="schedule")
+        s._store = {"entries": {}}
+        s.schedule = {"foo": "bar"}
+        self.assertDictEqual(s.schedule, {"foo": "bar"})
+        self.assertDictEqual(s._store["entries"], s.schedule)
+
+
 class test_Service(Case):
 class test_Service(Case):
 
 
     def get_service(self):
     def get_service(self):
-        sh = MockShelve()
-
-        class PersistentScheduler(beat.PersistentScheduler):
-            persistence = Object()
-            persistence.open = lambda *a, **kw: sh
-            tick_raises_exit = False
-            shutdown_service = None
-
-            def tick(self):
-                if self.tick_raises_exit:
-                    raise SystemExit()
-                if self.shutdown_service:
-                    self.shutdown_service._is_shutdown.set()
-                return 0.0
-
-        return beat.Service(scheduler_cls=PersistentScheduler), sh
+        Scheduler, mock_shelve = create_persistent_scheduler()
+        return beat.Service(scheduler_cls=Scheduler), mock_shelve
 
 
     def test_start(self):
     def test_start(self):
         s, sh = self.get_service()
         s, sh = self.get_service()

+ 15 - 0
celery/tests/app/test_control.py

@@ -117,6 +117,16 @@ class test_inspect(Case):
         self.i.cancel_consumer("foo")
         self.i.cancel_consumer("foo")
         self.assertIn("cancel_consumer", MockMailbox.sent)
         self.assertIn("cancel_consumer", MockMailbox.sent)
 
 
+    @with_mock_broadcast
+    def test_active_queues(self):
+        self.i.active_queues()
+        self.assertIn("active_queues", MockMailbox.sent)
+
+    @with_mock_broadcast
+    def test_report(self):
+        self.i.report()
+        self.assertIn("report", MockMailbox.sent)
+
 
 
 class test_Broadcast(Case):
 class test_Broadcast(Case):
 
 
@@ -153,6 +163,11 @@ class test_Broadcast(Case):
         self.control.rate_limit(mytask.name, "100/m")
         self.control.rate_limit(mytask.name, "100/m")
         self.assertIn("rate_limit", MockMailbox.sent)
         self.assertIn("rate_limit", MockMailbox.sent)
 
 
+    @with_mock_broadcast
+    def test_time_limit(self):
+        self.control.time_limit(mytask.name, soft=10, hard=20)
+        self.assertIn("time_limit", MockMailbox.sent)
+
     @with_mock_broadcast
     @with_mock_broadcast
     def test_revoke(self):
     def test_revoke(self):
         self.control.revoke("foozbaaz")
         self.control.revoke("foozbaaz")

+ 12 - 0
celery/tests/app/test_defaults.py

@@ -4,6 +4,7 @@ from __future__ import with_statement
 import sys
 import sys
 
 
 from importlib import import_module
 from importlib import import_module
+from mock import Mock, patch
 
 
 from celery.tests.utils import Case, pypy_version, sys_platform
 from celery.tests.utils import Case, pypy_version, sys_platform
 
 
@@ -17,6 +18,10 @@ class test_defaults(Case):
         if self._prev:
         if self._prev:
             sys.modules["celery.app.defaults"] = self._prev
             sys.modules["celery.app.defaults"] = self._prev
 
 
+    def test_any(self):
+        val = object()
+        self.assertIs(self.defaults.Option.typemap["any"](val), val)
+
     def test_default_pool_pypy_14(self):
     def test_default_pool_pypy_14(self):
         with sys_platform("darwin"):
         with sys_platform("darwin"):
             with pypy_version((1, 4, 0)):
             with pypy_version((1, 4, 0)):
@@ -27,6 +32,13 @@ class test_defaults(Case):
             with pypy_version((1, 5, 0)):
             with pypy_version((1, 5, 0)):
                 self.assertEqual(self.defaults.DEFAULT_POOL, "processes")
                 self.assertEqual(self.defaults.DEFAULT_POOL, "processes")
 
 
+    def test_deprecated(self):
+        source = Mock()
+        source.BROKER_INSIST = True
+        with patch("celery.utils.warn_deprecated") as warn:
+            self.defaults.find_deprecated_settings(source)
+            self.assertTrue(warn.called)
+
     def test_default_pool_jython(self):
     def test_default_pool_jython(self):
         with sys_platform("java 1.6.51"):
         with sys_platform("java 1.6.51"):
             self.assertEqual(self.defaults.DEFAULT_POOL, "threads")
             self.assertEqual(self.defaults.DEFAULT_POOL, "threads")

+ 15 - 3
celery/tests/app/test_loaders.py

@@ -4,7 +4,7 @@ from __future__ import with_statement
 import os
 import os
 import sys
 import sys
 
 
-from mock import patch
+from mock import Mock, patch
 
 
 from celery import loaders
 from celery import loaders
 from celery.app import app_or_default
 from celery.app import app_or_default
@@ -83,6 +83,17 @@ class test_LoaderBase(Case):
     def test_import_task_module(self):
     def test_import_task_module(self):
         self.assertEqual(sys, self.loader.import_task_module("sys"))
         self.assertEqual(sys, self.loader.import_task_module("sys"))
 
 
+    def test_init_worker_process(self):
+        self.loader.on_worker_process_init()
+        m = self.loader.on_worker_process_init = Mock()
+        self.loader.init_worker_process()
+        m.assert_called_with()
+
+    def test_config_from_object_module(self):
+        self.loader.import_from_cwd = Mock()
+        self.loader.config_from_object("module_name")
+        self.loader.import_from_cwd.assert_called_with("module_name")
+
     def test_conf_property(self):
     def test_conf_property(self):
         self.assertEqual(self.loader.conf["foo"], "bar")
         self.assertEqual(self.loader.conf["foo"], "bar")
         self.assertEqual(self.loader._conf["foo"], "bar")
         self.assertEqual(self.loader._conf["foo"], "bar")
@@ -181,7 +192,7 @@ class test_DefaultLoader(Case):
         celeryconfig.CELERY_IMPORTS = ("os", "sys")
         celeryconfig.CELERY_IMPORTS = ("os", "sys")
         configname = os.environ.get("CELERY_CONFIG_MODULE") or "celeryconfig"
         configname = os.environ.get("CELERY_CONFIG_MODULE") or "celeryconfig"
 
 
-        prevconfig = sys.modules[configname]
+        prevconfig = sys.modules.get(configname)
         sys.modules[configname] = celeryconfig
         sys.modules[configname] = celeryconfig
         try:
         try:
             l = default.Loader()
             l = default.Loader()
@@ -191,7 +202,8 @@ class test_DefaultLoader(Case):
             self.assertTupleEqual(settings.CELERY_IMPORTS, ("os", "sys"))
             self.assertTupleEqual(settings.CELERY_IMPORTS, ("os", "sys"))
             l.on_worker_init()
             l.on_worker_init()
         finally:
         finally:
-            sys.modules[configname] = prevconfig
+            if prevconfig:
+                sys.modules[configname] = prevconfig
 
 
     def test_import_from_cwd(self):
     def test_import_from_cwd(self):
         l = default.Loader()
         l = default.Loader()

+ 39 - 6
celery/tests/app/test_log.py

@@ -8,17 +8,39 @@ from tempfile import mktemp
 from mock import patch, Mock
 from mock import patch, Mock
 
 
 from celery import current_app
 from celery import current_app
-from celery.app.log import Logging
+from celery import signals
+from celery.app.log import Logging, TaskFormatter
 from celery.utils.log import LoggingProxy
 from celery.utils.log import LoggingProxy
 from celery.utils import uuid
 from celery.utils import uuid
-from celery.utils.log import get_logger, ColorFormatter, logger as base_logger
+from celery.utils.log import (
+    get_logger,
+    ColorFormatter,
+    logger as base_logger,
+)
 from celery.tests.utils import (
 from celery.tests.utils import (
-    Case, override_stdouts, wrap_logger, get_handlers,
+    AppCase, Case, override_stdouts, wrap_logger, get_handlers,
 )
 )
 
 
 log = current_app.log
 log = current_app.log
 
 
 
 
+class test_TaskFormatter(Case):
+
+    def test_no_task(self):
+        class Record(object):
+            msg = "hello world"
+            levelname = "info"
+            exc_text = exc_info = None
+
+            def getMessage(self):
+                return self.msg
+        record = Record()
+        x = TaskFormatter()
+        x.format(record)
+        self.assertEqual(record.task_name, "???")
+        self.assertEqual(record.task_id, "???")
+
+
 class test_ColorFormatter(Case):
 class test_ColorFormatter(Case):
 
 
     @patch("celery.utils.log.safe_str")
     @patch("celery.utils.log.safe_str")
@@ -71,11 +93,12 @@ class test_ColorFormatter(Case):
         self.assertEqual(safe_str.call_count, 1)
         self.assertEqual(safe_str.call_count, 1)
 
 
 
 
-class test_default_logger(Case):
+class test_default_logger(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         self.setup_logger = log.setup_logger
         self.setup_logger = log.setup_logger
         self.get_logger = lambda n=None: get_logger(n) if n else logging.root
         self.get_logger = lambda n=None: get_logger(n) if n else logging.root
+        signals.setup_logging.receivers[:] = []
         Logging._setup = False
         Logging._setup = False
 
 
     def test_get_logger_sets_parent(self):
     def test_get_logger_sets_parent(self):
@@ -86,6 +109,14 @@ class test_default_logger(Case):
         logger = get_logger(base_logger.name)
         logger = get_logger(base_logger.name)
         self.assertIs(logger.parent, logging.root)
         self.assertIs(logger.parent, logging.root)
 
 
+    def test_setup_logging_subsystem_misc(self):
+        log.setup_logging_subsystem(loglevel=None)
+        self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = True
+        try:
+            log.setup_logging_subsystem()
+        finally:
+            self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = False
+
     def test_setup_logging_subsystem_colorize(self):
     def test_setup_logging_subsystem_colorize(self):
         log.setup_logging_subsystem(colorize=None)
         log.setup_logging_subsystem(colorize=None)
         log.setup_logging_subsystem(colorize=True)
         log.setup_logging_subsystem(colorize=True)
@@ -149,6 +180,8 @@ class test_default_logger(Case):
                 log.redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
                 log.redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
                 logger.error("foo")
                 logger.error("foo")
                 self.assertIn("foo", sio.getvalue())
                 self.assertIn("foo", sio.getvalue())
+                log.redirect_stdouts_to_logger(logger, stdout=False,
+                        stderr=False)
         finally:
         finally:
             sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
             sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
 
 
@@ -186,7 +219,7 @@ class test_default_logger(Case):
 
 
 class test_task_logger(test_default_logger):
 class test_task_logger(test_default_logger):
 
 
-    def setUp(self):
+    def setup(self):
         logger = self.logger = get_logger("celery.task")
         logger = self.logger = get_logger("celery.task")
         logger.handlers = []
         logger.handlers = []
         logging.root.manager.loggerDict.pop(logger.name, None)
         logging.root.manager.loggerDict.pop(logger.name, None)

+ 4 - 0
celery/tests/backends/test_amqp.py

@@ -42,6 +42,10 @@ class test_AMQPBackend(Case):
         self.assertTrue(tb2._cache.get(tid))
         self.assertTrue(tb2._cache.get(tid))
         self.assertTrue(tb2.get_result(tid), 42)
         self.assertTrue(tb2.get_result(tid), 42)
 
 
+    def test_revive(self):
+        tb = self.create_backend()
+        tb.revive(None)
+
     def test_is_pickled(self):
     def test_is_pickled(self):
         tb1 = self.create_backend()
         tb1 = self.create_backend()
         tb2 = self.create_backend()
         tb2 = self.create_backend()

+ 24 - 0
celery/tests/backends/test_base.py

@@ -58,6 +58,10 @@ class test_BaseBackend_interface(Case):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
             b.forget("SOMExx-N0Nex1stant-IDxx-")
             b.forget("SOMExx-N0Nex1stant-IDxx-")
 
 
+    def test_get_children(self):
+        with self.assertRaises(NotImplementedError):
+            b.get_children("SOMExx-N0Nex1stant-IDxx-")
+
     def test_store_result(self):
     def test_store_result(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
             b.store_result("SOMExx-N0nex1stant-IDxx-", 42, states.SUCCESS)
             b.store_result("SOMExx-N0nex1stant-IDxx-", 42, states.SUCCESS)
@@ -98,6 +102,9 @@ class test_BaseBackend_interface(Case):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
             b.forget("SOMExx-N0nex1stant-IDxx-")
             b.forget("SOMExx-N0nex1stant-IDxx-")
 
 
+    def test_on_chord_part_return(self):
+        b.on_chord_part_return(None)
+
     def test_on_chord_apply(self, unlock="celery.chord_unlock"):
     def test_on_chord_apply(self, unlock="celery.chord_unlock"):
         p, current_app.tasks[unlock] = current_app.tasks.get(unlock), Mock()
         p, current_app.tasks[unlock] = current_app.tasks.get(unlock), Mock()
         try:
         try:
@@ -138,6 +145,7 @@ class test_prepare_exception(Case):
     def test_impossible(self):
     def test_impossible(self):
         x = b.prepare_exception(Impossible())
         x = b.prepare_exception(Impossible())
         self.assertIsInstance(x, UnpickleableExceptionWrapper)
         self.assertIsInstance(x, UnpickleableExceptionWrapper)
+        self.assertTrue(str(x))
         y = b.exception_to_python(x)
         y = b.exception_to_python(x)
         self.assertEqual(y.__class__.__name__, "Impossible")
         self.assertEqual(y.__class__.__name__, "Impossible")
         if sys.version_info < (2, 5):
         if sys.version_info < (2, 5):
@@ -202,6 +210,14 @@ class test_BaseDictBackend(Case):
         self.b.delete_taskset("can-delete")
         self.b.delete_taskset("can-delete")
         self.assertNotIn("can-delete", self.b._data)
         self.assertNotIn("can-delete", self.b._data)
 
 
+    def test_prepare_exception_json(self):
+        x = DictBackend(serializer="json")
+        e = x.prepare_exception(KeyError("foo"))
+        self.assertIn("exc_type", e)
+        e = x.exception_to_python(e)
+        self.assertEqual(e.__class__.__name__, "KeyError")
+        self.assertEqual(str(e), "'foo'")
+
     def test_save_taskset(self):
     def test_save_taskset(self):
         b = BaseDictBackend()
         b = BaseDictBackend()
         b._save_taskset = Mock()
         b._save_taskset = Mock()
@@ -237,6 +253,10 @@ class test_KeyValueStoreBackend(Case):
     def setUp(self):
     def setUp(self):
         self.b = KVBackend()
         self.b = KVBackend()
 
 
+    def test_on_chord_part_return(self):
+        assert not self.b.implements_incr
+        self.b.on_chord_part_return(None)
+
     def test_get_store_delete_result(self):
     def test_get_store_delete_result(self):
         tid = uuid()
         tid = uuid()
         self.b.mark_as_done(tid, "Hello world")
         self.b.mark_as_done(tid, "Hello world")
@@ -290,6 +310,10 @@ class test_KeyValueStoreBackend_interface(Case):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
             KeyValueStoreBackend().set("a", 1)
             KeyValueStoreBackend().set("a", 1)
 
 
+    def test_incr(self):
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().incr("a")
+
     def test_cleanup(self):
     def test_cleanup(self):
         self.assertFalse(KeyValueStoreBackend().cleanup())
         self.assertFalse(KeyValueStoreBackend().cleanup())
 
 

+ 7 - 1
celery/tests/backends/test_database.py

@@ -6,6 +6,7 @@ import sys
 from datetime import datetime
 from datetime import datetime
 
 
 from nose import SkipTest
 from nose import SkipTest
+from pickle import loads, dumps
 
 
 from celery import states
 from celery import states
 from celery.app import app_or_default
 from celery.app import app_or_default
@@ -151,7 +152,8 @@ class test_DatabaseBackend(Case):
         tb = DatabaseBackend(backend="memory://")
         tb = DatabaseBackend(backend="memory://")
         tid = uuid()
         tid = uuid()
         tb.mark_as_done(tid, {"foo": "bar"})
         tb.mark_as_done(tid, {"foo": "bar"})
-        x = AsyncResult(tid)
+        tb.mark_as_done(tid, {"foo": "bar"})
+        x = AsyncResult(tid, backend=tb)
         x.forget()
         x.forget()
         self.assertIsNone(x.result)
         self.assertIsNone(x.result)
 
 
@@ -159,6 +161,10 @@ class test_DatabaseBackend(Case):
         tb = DatabaseBackend()
         tb = DatabaseBackend()
         tb.process_cleanup()
         tb.process_cleanup()
 
 
+    def test_reduce(self):
+        tb = DatabaseBackend()
+        self.assertTrue(loads(dumps(tb)))
+
     def test_save__restore__delete_taskset(self):
     def test_save__restore__delete_taskset(self):
         tb = DatabaseBackend()
         tb = DatabaseBackend()
 
 

+ 2 - 0
celery/tests/backends/test_pyredis_compat.py

@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 from nose import SkipTest
 from nose import SkipTest
+from pickle import loads, dumps
 
 
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.tests.utils import Case
 from celery.tests.utils import Case
@@ -19,3 +20,4 @@ class test_RedisBackend(Case):
         self.assertEqual(x.redis_port, 312)
         self.assertEqual(x.redis_port, 312)
         self.assertEqual(x.redis_db, 1)
         self.assertEqual(x.redis_db, 1)
         self.assertEqual(x.redis_password, "foo")
         self.assertEqual(x.redis_password, "foo")
+        self.assertTrue(loads(dumps(x)))

+ 33 - 0
celery/tests/backends/test_redis.py

@@ -1,11 +1,16 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
+from __future__ import with_statement
 
 
 from datetime import timedelta
 from datetime import timedelta
 
 
 from mock import Mock, patch
 from mock import Mock, patch
+from nose import SkipTest
+from pickle import loads, dumps
 
 
 from celery import current_app
 from celery import current_app
 from celery import states
 from celery import states
+from celery.datastructures import AttributeDict
+from celery.exceptions import ImproperlyConfigured
 from celery.result import AsyncResult
 from celery.result import AsyncResult
 from celery.task import subtask
 from celery.task import subtask
 from celery.utils import cached_property, uuid
 from celery.utils import cached_property, uuid
@@ -81,6 +86,34 @@ class test_RedisBackend(Case):
 
 
         self.MockBackend = MockBackend
         self.MockBackend = MockBackend
 
 
+    def test_reduce(self):
+        try:
+            from celery.backends.redis import RedisBackend
+            x = RedisBackend()
+            self.assertTrue(loads(dumps(x)))
+        except ImportError:
+            raise SkipTest("redis not installed")
+
+    def test_no_redis(self):
+        self.MockBackend.redis = None
+        with self.assertRaises(ImproperlyConfigured):
+            self.MockBackend()
+
+    def test_url(self):
+        x = self.MockBackend("redis://foobar//1")
+        self.assertEqual(x.host, "foobar")
+        self.assertEqual(x.db, "1")
+
+    def test_conf_raises_KeyError(self):
+        conf = AttributeDict({"CELERY_RESULT_SERIALIZER": "json",
+                              "CELERY_MAX_CACHED_RESULTS": 1,
+                              "CELERY_TASK_RESULT_EXPIRES": None})
+        prev, current_app.conf = current_app.conf, conf
+        try:
+            self.MockBackend()
+        finally:
+            current_app.conf = prev
+
     def test_expires_defaults_to_config(self):
     def test_expires_defaults_to_config(self):
         conf = current_app.conf
         conf = current_app.conf
         prev = conf.CELERY_TASK_RESULT_EXPIRES
         prev = conf.CELERY_TASK_RESULT_EXPIRES

+ 31 - 1
celery/tests/bin/test_base.py

@@ -3,7 +3,9 @@ from __future__ import with_statement
 
 
 import os
 import os
 
 
-from celery.bin.base import Command
+from mock import patch
+
+from celery.bin.base import Command, Option
 from celery.tests.utils import AppCase, override_stdouts
 from celery.tests.utils import AppCase, override_stdouts
 
 
 
 
@@ -41,6 +43,13 @@ class test_Command(AppCase):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
             Command().run()
             Command().run()
 
 
+    @patch("sys.stdout")
+    def test_parse_options_version_only(self, stdout):
+        cmd = Command()
+        with self.assertRaises(SystemExit):
+            cmd.parse_options("prog", ["--version"])
+        stdout.write.assert_called_with(cmd.version + "\n")
+
     def test_execute_from_commandline(self):
     def test_execute_from_commandline(self):
         cmd = MockCommand()
         cmd = MockCommand()
         args1, kwargs1 = cmd.execute_from_commandline()     # sys.argv
         args1, kwargs1 = cmd.execute_from_commandline()     # sys.argv
@@ -71,6 +80,21 @@ class test_Command(AppCase):
         finally:
         finally:
             if prev:
             if prev:
                 os.environ["CELERY_CONFIG_MODULE"] = prev
                 os.environ["CELERY_CONFIG_MODULE"] = prev
+            else:
+                os.environ.pop("CELERY_CONFIG_MODULE", None)
+
+    def test_with_custom_broker(self):
+        prev = os.environ.pop("CELERY_BROKER_URL", None)
+        try:
+            cmd = MockCommand()
+            cmd.setup_app_from_commandline(["--broker=xyzza://"])
+            self.assertEqual(os.environ.get("CELERY_BROKER_URL"),
+                    "xyzza://")
+        finally:
+            if prev:
+                os.environ["CELERY_BROKER_URL"] = prev
+            else:
+                os.environ.pop("CELERY_BROKER_URL", None)
 
 
     def test_with_custom_app(self):
     def test_with_custom_app(self):
         cmd = MockCommand()
         cmd = MockCommand()
@@ -89,3 +113,9 @@ class test_Command(AppCase):
         self.assertEqual(cmd.app.conf.BROKER_HOST, "broker.example.com")
         self.assertEqual(cmd.app.conf.BROKER_HOST, "broker.example.com")
         self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
         self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
         self.assertListEqual(rest, ["--loglevel=INFO"])
         self.assertListEqual(rest, ["--loglevel=INFO"])
+
+    def test_parse_preload_options_shortopt(self):
+        cmd = Command()
+        cmd.preload_options = (Option("-s", action="store", dest="silent"), )
+        acc = cmd.parse_preload_options(["-s", "yes"])
+        self.assertEqual(acc.get("silent"), "yes")

+ 6 - 0
celery/tests/bin/test_celerybeat.py

@@ -182,6 +182,12 @@ class test_div(AppCase):
         self.assertTrue(MockDaemonContext.opened)
         self.assertTrue(MockDaemonContext.opened)
         self.assertTrue(MockDaemonContext.closed)
         self.assertTrue(MockDaemonContext.closed)
 
 
+    @patch("os.chdir")
+    def test_prepare_preload_options(self, chdir):
+        cmd = celerybeat_bin.BeatCommand()
+        cmd.prepare_preload_options({"working_directory": "/opt/Project"})
+        chdir.assert_called_with("/opt/Project")
+
     def test_parse_options(self):
     def test_parse_options(self):
         cmd = celerybeat_bin.BeatCommand()
         cmd = celerybeat_bin.BeatCommand()
         cmd.app = app_or_default()
         cmd.app = app_or_default()

+ 19 - 1
celery/tests/bin/test_celeryd.py

@@ -7,7 +7,7 @@ import sys
 
 
 from functools import wraps
 from functools import wraps
 
 
-from mock import patch
+from mock import Mock, patch
 from nose import SkipTest
 from nose import SkipTest
 
 
 from billiard import current_process
 from billiard import current_process
@@ -64,6 +64,16 @@ class test_Worker(AppCase):
         self.assertEqual(worker.use_queues, ["foo", "bar", "baz"])
         self.assertEqual(worker.use_queues, ["foo", "bar", "baz"])
         self.assertTrue("foo" in celery.amqp.queues)
         self.assertTrue("foo" in celery.amqp.queues)
 
 
+    @disable_stdouts
+    def test_cpu_count(self):
+        celery = Celery(set_as_current=False)
+        with patch("celery.apps.worker.cpu_count") as cpu_count:
+            cpu_count.side_effect = NotImplementedError()
+            worker = celery.Worker(concurrency=None)
+            self.assertEqual(worker.concurrency, 2)
+        worker = celery.Worker(concurrency=5)
+        self.assertEqual(worker.concurrency, 5)
+
     @disable_stdouts
     @disable_stdouts
     def test_windows_B_option(self):
     def test_windows_B_option(self):
         celery = Celery(set_as_current=False)
         celery = Celery(set_as_current=False)
@@ -139,6 +149,14 @@ class test_Worker(AppCase):
         worker.init_loader()
         worker.init_loader()
         worker.run()
         worker.run()
 
 
+        prev, cd.IGNORE_ERRORS = cd.IGNORE_ERRORS, (KeyError, )
+        try:
+            worker.run_worker = Mock()
+            worker.run_worker.side_effect = KeyError()
+            worker.run()
+        finally:
+            cd.IGNORE_ERRORS = prev
+
     @disable_stdouts
     @disable_stdouts
     def test_purge_messages(self):
     def test_purge_messages(self):
         self.Worker().purge_messages()
         self.Worker().purge_messages()

+ 2 - 4
celery/tests/bin/test_celeryd_multi.py

@@ -312,7 +312,7 @@ class test_MultiTool(Case):
         self.prepare_pidfile_for_getpids(PIDFile)
         self.prepare_pidfile_for_getpids(PIDFile)
         self.assertIsNone(self.t.shutdown_nodes([]))
         self.assertIsNone(self.t.shutdown_nodes([]))
         self.t.signal_node = Mock()
         self.t.signal_node = Mock()
-        self.t.node_alive = Mock()
+        node_alive = self.t.node_alive = Mock()
         self.t.node_alive.return_value = False
         self.t.node_alive.return_value = False
 
 
         callback = Mock()
         callback = Mock()
@@ -324,11 +324,9 @@ class test_MultiTool(Case):
         self.t.signal_node.return_value = False
         self.t.signal_node.return_value = False
         self.assertTrue(callback.called)
         self.assertTrue(callback.called)
         self.t.stop(["foo", "bar", "baz"], "celeryd", callback=None)
         self.t.stop(["foo", "bar", "baz"], "celeryd", callback=None)
-        calls = [0]
 
 
         def on_node_alive(pid):
         def on_node_alive(pid):
-            calls[0] += 1
-            if calls[0] > 3:
+            if node_alive.call_count > 4:
                 return True
                 return True
             return False
             return False
         self.t.signal_node.return_value = True
         self.t.signal_node.return_value = True

+ 21 - 0
celery/tests/bin/test_celeryev.py

@@ -1,6 +1,8 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
+from __future__ import with_statement
 
 
 from nose import SkipTest
 from nose import SkipTest
+from mock import patch as mpatch
 
 
 from celery.app import app_or_default
 from celery.app import app_or_default
 from celery.bin import celeryev
 from celery.bin import celeryev
@@ -32,6 +34,14 @@ class test_EvCommand(Case):
         self.assertEqual(self.ev.run(dump=True), "me dumper, you?")
         self.assertEqual(self.ev.run(dump=True), "me dumper, you?")
         self.assertIn("celeryev:dump", proctitle.last[0])
         self.assertIn("celeryev:dump", proctitle.last[0])
 
 
+    @mpatch("os.chdir")
+    def test_prepare_preload_options(self, chdir):
+        self.ev.prepare_preload_options({"working_directory": "/opt/Project"})
+        chdir.assert_called_with("/opt/Project")
+        chdir.called = False
+        self.ev.prepare_preload_options({})
+        self.assertFalse(chdir.called)
+
     def test_run_top(self):
     def test_run_top(self):
         try:
         try:
             import curses  # noqa
             import curses  # noqa
@@ -56,6 +66,17 @@ class test_EvCommand(Case):
         self.assertEqual(kw["logfile"], "logfile")
         self.assertEqual(kw["logfile"], "logfile")
         self.assertIn("celeryev:cam", proctitle.last[0])
         self.assertIn("celeryev:cam", proctitle.last[0])
 
 
+    @mpatch("celery.events.snapshot.evcam")
+    @mpatch("celery.bin.celeryev.detached")
+    def test_run_cam_detached(self, detached, evcam):
+        self.ev.prog_name = "celeryev"
+        self.ev.run_evcam("myapp.Camera", detach=True)
+        self.assertTrue(detached.called)
+        self.assertTrue(evcam.called)
+
+    def test_get_options(self):
+        self.assertTrue(self.ev.get_options())
+
     @patch("celery.bin.celeryev", "EvCommand", MockCommand)
     @patch("celery.bin.celeryev", "EvCommand", MockCommand)
     def test_main(self):
     def test_main(self):
         MockCommand.executed = []
         MockCommand.executed = []

+ 16 - 0
celery/tests/concurrency/test_concurrency.py

@@ -47,6 +47,14 @@ class test_BasePool(Case):
                               {"target": (3, (8, 16)),
                               {"target": (3, (8, 16)),
                                "callback": (4, (42, ))})
                                "callback": (4, (42, ))})
 
 
+    def test_does_not_debug(self):
+        x = BasePool(10)
+        x._does_debug = False
+        x.apply_async(object)
+
+    def test_num_processes(self):
+        self.assertEqual(BasePool(7).num_processes, 7)
+
     def test_interface_on_start(self):
     def test_interface_on_start(self):
         BasePool(10).on_start()
         BasePool(10).on_start()
 
 
@@ -69,3 +77,11 @@ class test_BasePool(Case):
         p = BasePool(10)
         p = BasePool(10)
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
             p.restart()
             p.restart()
+
+    def test_interface_on_terminate(self):
+        p = BasePool(10)
+        p.on_terminate()
+
+    def test_interface_terminate_job(self):
+        with self.assertRaises(NotImplementedError):
+            BasePool(10).terminate_job(101)

+ 19 - 8
celery/tests/concurrency/test_processes.py

@@ -76,9 +76,9 @@ class MockPool(object):
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         self.started = True
         self.started = True
         self._state = mp.RUN
         self._state = mp.RUN
-        self.processes = kwargs.get("processes")
-        self._pool = [Object(pid=i) for i in range(self.processes)]
-        self._current_proc = cycle(xrange(self.processes)).next
+        self._processes = kwargs.get("processes")
+        self._pool = [Object(pid=i) for i in range(self._processes)]
+        self._current_proc = cycle(xrange(self._processes)).next
 
 
     def close(self):
     def close(self):
         self.closed = True
         self.closed = True
@@ -91,10 +91,10 @@ class MockPool(object):
         self.terminated = True
         self.terminated = True
 
 
     def grow(self, n=1):
     def grow(self, n=1):
-        self.processes += n
+        self._processes += n
 
 
     def shrink(self, n=1):
     def shrink(self, n=1):
-        self.processes -= n
+        self._processes -= n
 
 
     def apply_async(self, *args, **kwargs):
     def apply_async(self, *args, **kwargs):
         pass
         pass
@@ -179,11 +179,11 @@ class test_TaskPool(Case):
     def test_grow_shrink(self):
     def test_grow_shrink(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
         pool.start()
         pool.start()
-        self.assertEqual(pool._pool.processes, 10)
+        self.assertEqual(pool._pool._processes, 10)
         pool.grow()
         pool.grow()
-        self.assertEqual(pool._pool.processes, 11)
+        self.assertEqual(pool._pool._processes, 11)
         pool.shrink(2)
         pool.shrink(2)
-        self.assertEqual(pool._pool.processes, 9)
+        self.assertEqual(pool._pool._processes, 9)
 
 
     def test_info(self):
     def test_info(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
@@ -197,6 +197,17 @@ class test_TaskPool(Case):
         self.assertIsNone(info["max-tasks-per-child"])
         self.assertIsNone(info["max-tasks-per-child"])
         self.assertEqual(info["timeouts"], (5, 10))
         self.assertEqual(info["timeouts"], (5, 10))
 
 
+    def test_num_processes(self):
+        pool = TaskPool(7)
+        pool.start()
+        self.assertEqual(pool.num_processes, 7)
+
+    def test_restart_pool(self):
+        pool = TaskPool()
+        pool._pool = Mock()
+        pool.restart()
+        pool._pool.restart.assert_called_with()
+
     def test_restart(self):
     def test_restart(self):
         raise SkipTest("functional test")
         raise SkipTest("functional test")
 
 

+ 36 - 15
celery/tests/events/test_events.py

@@ -3,9 +3,10 @@ from __future__ import with_statement
 
 
 import socket
 import socket
 
 
+from mock import Mock
+
 from celery import events
 from celery import events
-from celery.app import app_or_default
-from celery.tests.utils import Case
+from celery.tests.utils import AppCase
 
 
 
 
 class MockProducer(object):
 class MockProducer(object):
@@ -29,7 +30,7 @@ class MockProducer(object):
         return False
         return False
 
 
 
 
-class test_Event(Case):
+class test_Event(AppCase):
 
 
     def test_constructor(self):
     def test_constructor(self):
         event = events.Event("world war II")
         event = events.Event("world war II")
@@ -37,10 +38,7 @@ class test_Event(Case):
         self.assertTrue(event["timestamp"])
         self.assertTrue(event["timestamp"])
 
 
 
 
-class test_EventDispatcher(Case):
-
-    def setUp(self):
-        self.app = app_or_default()
+class test_EventDispatcher(AppCase):
 
 
     def test_send(self):
     def test_send(self):
         producer = MockProducer()
         producer = MockProducer()
@@ -67,6 +65,30 @@ class test_EventDispatcher(Case):
         for ev in evs:
         for ev in evs:
             self.assertTrue(producer.has_event(ev))
             self.assertTrue(producer.has_event(ev))
 
 
+        buf = eventer._outbound_buffer = Mock()
+        buf.popleft.side_effect = IndexError()
+        eventer.flush()
+
+    def test_enter_exit(self):
+        with self.app.broker_connection() as conn:
+            d = self.app.events.Dispatcher(conn)
+            d.close = Mock()
+            with d as _d:
+                self.assertTrue(_d)
+            d.close.assert_called_with()
+
+    def test_enable_disable_callbacks(self):
+        on_enable = Mock()
+        on_disable = Mock()
+        with self.app.broker_connection() as conn:
+            with self.app.events.Dispatcher(conn, enabled=False) as d:
+                d.on_enabled.add(on_enable)
+                d.on_disabled.add(on_disable)
+                d.enable()
+                on_enable.assert_called_with()
+                d.disable()
+                on_disable.assert_called_with()
+
     def test_enabled_disable(self):
     def test_enabled_disable(self):
         connection = self.app.broker_connection()
         connection = self.app.broker_connection()
         channel = connection.channel()
         channel = connection.channel()
@@ -99,10 +121,7 @@ class test_EventDispatcher(Case):
             connection.close()
             connection.close()
 
 
 
 
-class test_EventReceiver(Case):
-
-    def setUp(self):
-        self.app = app_or_default()
+class test_EventReceiver(AppCase):
 
 
     def test_process(self):
     def test_process(self):
 
 
@@ -181,11 +200,13 @@ class test_EventReceiver(Case):
             connection.close()
             connection.close()
 
 
 
 
-class test_misc(Case):
-
-    def setUp(self):
-        self.app = app_or_default()
+class test_misc(AppCase):
 
 
     def test_State(self):
     def test_State(self):
         state = self.app.events.State()
         state = self.app.events.State()
         self.assertDictEqual(dict(state.workers), {})
         self.assertDictEqual(dict(state.workers), {})
+
+    def test_default_dispatcher(self):
+        with self.app.events.default_dispatcher() as d:
+            self.assertTrue(d)
+            self.assertTrue(d.connection)

+ 18 - 3
celery/tests/events/test_snapshot.py

@@ -1,6 +1,8 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
+from mock import patch
+
 from celery.app import app_or_default
 from celery.app import app_or_default
 from celery.events import Events
 from celery.events import Events
 from celery.events.snapshot import Polaroid, evcam
 from celery.events.snapshot import Polaroid, evcam
@@ -114,11 +116,24 @@ class test_evcam(Case):
 
 
     def setUp(self):
     def setUp(self):
         self.app = app_or_default()
         self.app = app_or_default()
-        self.app.events = self.MockEvents()
+        self.prev, self.app.events = self.app.events, self.MockEvents()
+
+    def tearDown(self):
+        self.app.events = self.prev
 
 
     def test_evcam(self):
     def test_evcam(self):
         evcam(Polaroid, timer=timer)
         evcam(Polaroid, timer=timer)
         evcam(Polaroid, timer=timer, loglevel="CRITICAL")
         evcam(Polaroid, timer=timer, loglevel="CRITICAL")
         self.MockReceiver.raise_keyboard_interrupt = True
         self.MockReceiver.raise_keyboard_interrupt = True
-        with self.assertRaises(SystemExit):
-            evcam(Polaroid, timer=timer)
+        try:
+            with self.assertRaises(SystemExit):
+                evcam(Polaroid, timer=timer)
+        finally:
+            self.MockReceiver.raise_keyboard_interrupt = False
+
+    @patch("atexit.register")
+    @patch("celery.platforms.create_pidlock")
+    def test_evcam_pidfile(self, create_pidlock, atexit):
+        evcam(Polaroid, timer=timer, pidfile="/var/pid")
+        self.assertTrue(atexit.called)
+        create_pidlock.assert_called_with("/var/pid")

+ 5 - 0
celery/tests/events/test_state.py

@@ -172,6 +172,11 @@ class test_State(Case):
         self.assertFalse(r.state.alive_workers())
         self.assertFalse(r.state.alive_workers())
         self.assertFalse(r.state.workers["utest1"].alive)
         self.assertFalse(r.state.workers["utest1"].alive)
 
 
+    def test_itertasks(self):
+        s = State()
+        s.tasks = {"a": "a", "b": "b", "c": "c", "d": "d"}
+        self.assertEqual(len(list(s.itertasks(limit=2))), 2)
+
     def test_worker_heartbeat_expire(self):
     def test_worker_heartbeat_expire(self):
         r = ev_worker_heartbeats(State())
         r = ev_worker_heartbeats(State())
         r.next()
         r.next()

+ 6 - 10
celery/tests/slow/test_buckets.py

@@ -148,18 +148,14 @@ class test_TaskBucket(Case):
         x = buckets.TaskBucket(task_registry=self.registry)
         x = buckets.TaskBucket(task_registry=self.registry)
         x.not_empty = Mock()
         x.not_empty = Mock()
         get = x._get = Mock()
         get = x._get = Mock()
-        calls = [0]
         remaining = [0]
         remaining = [0]
 
 
         def effect():
         def effect():
-            try:
-                if not calls[0]:
-                    raise Empty()
-                rem = remaining[0]
-                remaining[0] = 0
-                return rem, Mock()
-            finally:
-                calls[0] += 1
+            if get.call_count == 1:
+                raise Empty()
+            rem = remaining[0]
+            remaining[0] = 0
+            return rem, Mock()
         get.side_effect = effect
         get.side_effect = effect
 
 
         with mock_context(Mock()) as context:
         with mock_context(Mock()) as context:
@@ -167,7 +163,7 @@ class test_TaskBucket(Case):
             x.wait = Mock()
             x.wait = Mock()
             x.get(block=True)
             x.get(block=True)
 
 
-            calls[0] = 0
+            get.reset()
             remaining[0] = 1
             remaining[0] = 1
             x.get(block=True)
             x.get(block=True)
 
 

+ 8 - 0
celery/tests/tasks/test_registry.py

@@ -23,6 +23,9 @@ class MockPeriodicTask(PeriodicTask):
 
 
 class test_TaskRegistry(Case):
 class test_TaskRegistry(Case):
 
 
+    def test_NotRegistered_str(self):
+        self.assertTrue(repr(TaskRegistry.NotRegistered("tasks.add")))
+
     def assertRegisterUnregisterCls(self, r, task):
     def assertRegisterUnregisterCls(self, r, task):
         with self.assertRaises(r.NotRegistered):
         with self.assertRaises(r.NotRegistered):
             r.unregister(task)
             r.unregister(task)
@@ -64,3 +67,8 @@ class test_TaskRegistry(Case):
 
 
         self.assertTrue(MockTask().run())
         self.assertTrue(MockTask().run())
         self.assertTrue(MockPeriodicTask().run())
         self.assertTrue(MockPeriodicTask().run())
+
+    def test_compat(self):
+        r = TaskRegistry()
+        r.regular()
+        r.periodic()

+ 136 - 1
celery/tests/tasks/test_result.py

@@ -1,11 +1,21 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
+from pickle import loads, dumps
+from mock import Mock
+
 from celery import states
 from celery import states
 from celery.app import app_or_default
 from celery.app import app_or_default
+from celery.exceptions import IncompleteStream
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
-from celery.result import AsyncResult, EagerResult, TaskSetResult, ResultSet
+from celery.result import (
+    AsyncResult,
+    EagerResult,
+    TaskSetResult,
+    ResultSet,
+    from_serializable,
+)
 from celery.exceptions import TimeoutError
 from celery.exceptions import TimeoutError
 from celery.task import task
 from celery.task import task
 from celery.task.base import Task
 from celery.task.base import Task
@@ -53,6 +63,71 @@ class test_AsyncResult(AppCase):
         for task in (self.task1, self.task2, self.task3, self.task4):
         for task in (self.task1, self.task2, self.task3, self.task4):
             save_result(task)
             save_result(task)
 
 
+    def test_compat_properties(self):
+        x = AsyncResult("1")
+        self.assertEqual(x.task_id, x.id)
+        x.task_id = "2"
+        self.assertEqual(x.id, "2")
+
+    def test_children(self):
+        x = AsyncResult("1")
+        children = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
+        x.backend = Mock()
+        x.backend.get_children.return_value = children
+        x.backend.READY_STATES = states.READY_STATES
+        self.assertTrue(x.children)
+        self.assertEqual(len(x.children), 3)
+
+    def test_build_graph_get_leaf_collect(self):
+        x = AsyncResult("1")
+        x.backend._cache["1"] = {"status": states.SUCCESS, "result": None}
+        c = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
+        x.iterdeps = Mock()
+        x.iterdeps.return_value = (
+            (None, x),
+            (x, c[0]),
+            (c[0], c[1]),
+            (c[1], c[2])
+        )
+        x.backend.READY_STATES = states.READY_STATES
+        self.assertTrue(x.graph)
+
+        self.assertIs(x.get_leaf(), 2)
+
+        it = x.collect()
+        self.assertListEqual(list(it), [
+            (x, None),
+            (c[0], 0),
+            (c[1], 1),
+            (c[2], 2),
+        ])
+
+    def test_iterdeps(self):
+        x = AsyncResult("1")
+        x.backend._cache["1"] = {"status": states.SUCCESS, "result": None}
+        c = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
+        for child in c:
+            child.backend = Mock()
+            child.backend.get_children.return_value = []
+        x.backend.get_children = Mock()
+        x.backend.get_children.return_value = c
+        it = x.iterdeps()
+        self.assertListEqual(list(it), [
+            (None, x),
+            (x, c[0]),
+            (x, c[1]),
+            (x, c[2]),
+        ])
+        x.backend._cache.pop("1")
+        x.ready = Mock()
+        x.ready.return_value = False
+        with self.assertRaises(IncompleteStream):
+            list(x.iterdeps())
+        list(x.iterdeps(intermediate=True))
+
+    def test_eq_not_implemented(self):
+        self.assertFalse(AsyncResult("1") == object())
+
     def test_reduce(self):
     def test_reduce(self):
         a1 = AsyncResult("uuid", task_name=mytask.name)
         a1 = AsyncResult("uuid", task_name=mytask.name)
         restored = pickle.loads(pickle.dumps(a1))
         restored = pickle.loads(pickle.dumps(a1))
@@ -129,6 +204,7 @@ class test_AsyncResult(AppCase):
         self.assertEqual(ok2_res.get(), "quick")
         self.assertEqual(ok2_res.get(), "quick")
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             nok_res.get()
             nok_res.get()
+        self.assertTrue(nok_res.get(propagate=False))
         self.assertIsInstance(nok2_res.result, KeyError)
         self.assertIsInstance(nok2_res.result, KeyError)
         self.assertEqual(ok_res.info, "the")
         self.assertEqual(ok_res.info, "the")
 
 
@@ -159,6 +235,32 @@ class test_AsyncResult(AppCase):
 
 
 class test_ResultSet(AppCase):
 class test_ResultSet(AppCase):
 
 
+    def test_resultset_repr(self):
+        self.assertTrue(repr(ResultSet(map(AsyncResult, [1, 2, 3]))))
+
+    def test_eq_other(self):
+        self.assertFalse(ResultSet([1, 3, 3]) == 1)
+        self.assertTrue(ResultSet([1]) == ResultSet([1]))
+
+    def test_get(self):
+        x = ResultSet(map(AsyncResult, [1, 2, 3]))
+        b = x.results[0].backend = Mock()
+        b.supports_native_join = False
+        x.join_native = Mock()
+        x.join = Mock()
+        x.get()
+        self.assertTrue(x.join.called)
+        b.supports_native_join = True
+        x.get()
+        self.assertTrue(x.join_native.called)
+
+    def test_add(self):
+        x = ResultSet([1])
+        x.add(2)
+        self.assertEqual(len(x), 2)
+        x.add(2)
+        self.assertEqual(len(x), 2)
+
     def test_add_discard(self):
     def test_add_discard(self):
         x = ResultSet([])
         x = ResultSet([])
         x.add(AsyncResult("1"))
         x.add(AsyncResult("1"))
@@ -231,6 +333,21 @@ class test_TaskSetResult(AppCase):
         self.assertEqual(len(self.ts), self.size)
         self.assertEqual(len(self.ts), self.size)
         self.assertEqual(self.ts.total, self.size)
         self.assertEqual(self.ts.total, self.size)
 
 
+    def test_compat_properties(self):
+        self.assertEqual(self.ts.taskset_id, self.ts.id)
+        self.ts.taskset_id = "foo"
+        self.assertEqual(self.ts.taskset_id, "foo")
+
+    def test_eq_other(self):
+        self.assertFalse(self.ts == 1)
+
+    def test_reduce(self):
+        self.assertTrue(loads(dumps(self.ts)))
+
+    def test_compat_subtasks_kwarg(self):
+        x = TaskSetResult(uuid(), subtasks=[1, 2, 3])
+        self.assertEqual(x.results, [1, 2, 3])
+
     def test_iterate_raises(self):
     def test_iterate_raises(self):
         ar = MockAsyncResultFailure(uuid())
         ar = MockAsyncResultFailure(uuid())
         ts = TaskSetResult(uuid(), [ar])
         ts = TaskSetResult(uuid(), [ar])
@@ -432,6 +549,7 @@ class test_EagerResult(AppCase):
         res = RaisingTask.apply(args=[3, 3])
         res = RaisingTask.apply(args=[3, 3])
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             res.wait()
             res.wait()
+        self.assertTrue(res.wait(propagate=False))
 
 
     def test_wait(self):
     def test_wait(self):
         res = EagerResult("x", "x", states.RETRY)
         res = EagerResult("x", "x", states.RETRY)
@@ -439,6 +557,23 @@ class test_EagerResult(AppCase):
         self.assertEqual(res.state, states.RETRY)
         self.assertEqual(res.state, states.RETRY)
         self.assertEqual(res.status, states.RETRY)
         self.assertEqual(res.status, states.RETRY)
 
 
+    def test_forget(self):
+        res = EagerResult("x", "x", states.RETRY)
+        res.forget()
+
     def test_revoke(self):
     def test_revoke(self):
         res = RaisingTask.apply(args=[3, 3])
         res = RaisingTask.apply(args=[3, 3])
         self.assertFalse(res.revoke())
         self.assertFalse(res.revoke())
+
+
+class test_serializable(AppCase):
+
+    def test_AsyncResult(self):
+        x = AsyncResult(uuid())
+        self.assertEqual(x, from_serializable(x.serializable()))
+        self.assertEqual(x, from_serializable(x))
+
+    def test_TaskSetResult(self):
+        x = TaskSetResult(uuid(), [AsyncResult(uuid()) for _ in range(10)])
+        self.assertEqual(x, from_serializable(x.serializable()))
+        self.assertEqual(x, from_serializable(x))

+ 34 - 1
celery/tests/tasks/test_tasks.py

@@ -3,9 +3,11 @@ from __future__ import with_statement
 
 
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from functools import wraps
 from functools import wraps
+from mock import patch
+from pickle import loads, dumps
 
 
 from celery import task
 from celery import task
-from celery.task import current
+from celery.task import current, Task
 from celery.app import app_or_default
 from celery.app import app_or_default
 from celery.task import task as task_dec
 from celery.task import task as task_dec
 from celery.exceptions import RetryTaskError
 from celery.exceptions import RetryTaskError
@@ -57,6 +59,7 @@ def retry_task(arg1, arg2, kwarg=1, max_retries=None, care=True):
     current.iterations += 1
     current.iterations += 1
     rmax = current.max_retries if max_retries is None else max_retries
     rmax = current.max_retries if max_retries is None else max_retries
 
 
+    assert repr(current.request)
     retries = current.request.retries
     retries = current.request.retries
     if care and retries >= rmax:
     if care and retries >= rmax:
         return arg1
         return arg1
@@ -301,6 +304,22 @@ class test_tasks(Case):
     def test_task_class_repr(self):
     def test_task_class_repr(self):
         task = self.createTask("c.unittest.t.repr")
         task = self.createTask("c.unittest.t.repr")
         self.assertIn("class Task of", repr(task.app.Task))
         self.assertIn("class Task of", repr(task.app.Task))
+        prev, task.app.Task._app = task.app.Task._app, None
+        try:
+            self.assertIn("unbound", repr(task.app.Task, ))
+        finally:
+            task.app.Task._app = prev
+
+    def test_bind_no_magic_kwargs(self):
+        task = self.createTask("c.unittest.t.magic_kwargs")
+        task.__class__.accept_magic_kwargs = None
+        task.bind(task.app)
+
+    def test_annotate(self):
+        with patch("celery.app.task.resolve_all_annotations") as anno:
+            anno.return_value = [{"FOO": "BAR"}]
+            Task.annotate()
+            self.assertEqual(Task.FOO, "BAR")
 
 
     def test_after_return(self):
     def test_after_return(self):
         task = self.createTask("c.unittest.t.after_return")
         task = self.createTask("c.unittest.t.after_return")
@@ -436,6 +455,13 @@ class test_apply_task(Case):
         with self.assertRaises(KeyError):
         with self.assertRaises(KeyError):
             raising.apply(throw=True)
             raising.apply(throw=True)
 
 
+    def test_apply_no_magic_kwargs(self):
+        increment_counter.accept_magic_kwargs = False
+        try:
+            increment_counter.apply()
+        finally:
+            increment_counter.accept_magic_kwargs = True
+
     def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
     def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
         raising.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
         raising.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
         try:
         try:
@@ -551,6 +577,13 @@ def patch_crontab_nowfun(cls, retval):
 
 
 class test_crontab_parser(Case):
 class test_crontab_parser(Case):
 
 
+    def test_crontab_reduce(self):
+        self.assertTrue(loads(dumps(crontab("*"))))
+
+    def test_range_steps_not_enough(self):
+        with self.assertRaises(crontab_parser.ParseException):
+            crontab_parser(24)._range_steps([1])
+
     def test_parse_star(self):
     def test_parse_star(self):
         self.assertEqual(crontab_parser(24).parse('*'), set(range(24)))
         self.assertEqual(crontab_parser(24).parse('*'), set(range(24)))
         self.assertEqual(crontab_parser(60).parse('*'), set(range(60)))
         self.assertEqual(crontab_parser(60).parse('*'), set(range(60)))

+ 139 - 0
celery/tests/utilities/test_dispatcher.py

@@ -0,0 +1,139 @@
+from __future__ import absolute_import
+
+
+import gc
+import sys
+import time
+
+from celery.utils.dispatch import Signal
+from celery.tests.utils import Case
+
+
+if sys.platform.startswith('java'):
+
+    def garbage_collect():
+        # Some JVM GCs will execute finalizers in a different thread, meaning
+        # we need to wait for that to complete before we go on looking for the
+        # effects of that.
+        gc.collect()
+        time.sleep(0.1)
+
+elif hasattr(sys, "pypy_version_info"):
+
+    def garbage_collect():  # noqa
+        # Collecting weakreferences can take two collections on PyPy.
+        gc.collect()
+        gc.collect()
+else:
+
+    def garbage_collect():  # noqa
+        gc.collect()
+
+
+def receiver_1_arg(val, **kwargs):
+    return val
+
+
+class Callable(object):
+
+    def __call__(self, val, **kwargs):
+        return val
+
+    def a(self, val, **kwargs):
+        return val
+
+a_signal = Signal(providing_args=["val"])
+
+
+class DispatcherTests(Case):
+    """Test suite for dispatcher (barely started)"""
+
+    def _testIsClean(self, signal):
+        """Assert that everything has been cleaned up automatically"""
+        self.assertEqual(signal.receivers, [])
+
+        # force cleanup just in case
+        signal.receivers = []
+
+    def testExact(self):
+        a_signal.connect(receiver_1_arg, sender=self)
+        expected = [(receiver_1_arg, "test")]
+        result = a_signal.send(sender=self, val="test")
+        self.assertEqual(result, expected)
+        a_signal.disconnect(receiver_1_arg, sender=self)
+        self._testIsClean(a_signal)
+
+    def testIgnoredSender(self):
+        a_signal.connect(receiver_1_arg)
+        expected = [(receiver_1_arg, "test")]
+        result = a_signal.send(sender=self, val="test")
+        self.assertEqual(result, expected)
+        a_signal.disconnect(receiver_1_arg)
+        self._testIsClean(a_signal)
+
+    def testGarbageCollected(self):
+        a = Callable()
+        a_signal.connect(a.a, sender=self)
+        expected = []
+        del a
+        garbage_collect()
+        result = a_signal.send(sender=self, val="test")
+        self.assertEqual(result, expected)
+        self._testIsClean(a_signal)
+
+    def testMultipleRegistration(self):
+        a = Callable()
+        a_signal.connect(a)
+        a_signal.connect(a)
+        a_signal.connect(a)
+        a_signal.connect(a)
+        a_signal.connect(a)
+        a_signal.connect(a)
+        result = a_signal.send(sender=self, val="test")
+        self.assertEqual(len(result), 1)
+        self.assertEqual(len(a_signal.receivers), 1)
+        del a
+        del result
+        garbage_collect()
+        self._testIsClean(a_signal)
+
+    def testUidRegistration(self):
+
+        def uid_based_receiver_1(**kwargs):
+            pass
+
+        def uid_based_receiver_2(**kwargs):
+            pass
+
+        a_signal.connect(uid_based_receiver_1, dispatch_uid="uid")
+        a_signal.connect(uid_based_receiver_2, dispatch_uid="uid")
+        self.assertEqual(len(a_signal.receivers), 1)
+        a_signal.disconnect(dispatch_uid="uid")
+        self._testIsClean(a_signal)
+
+    def testRobust(self):
+        """Test the sendRobust function"""
+
+        def fails(val, **kwargs):
+            raise ValueError('this')
+
+        a_signal.connect(fails)
+        result = a_signal.send_robust(sender=self, val="test")
+        err = result[0][1]
+        self.assertTrue(isinstance(err, ValueError))
+        self.assertEqual(err.args, ('this',))
+        a_signal.disconnect(fails)
+        self._testIsClean(a_signal)
+
+    def testDisconnection(self):
+        receiver_1 = Callable()
+        receiver_2 = Callable()
+        receiver_3 = Callable()
+        a_signal.connect(receiver_1)
+        a_signal.connect(receiver_2)
+        a_signal.connect(receiver_3)
+        a_signal.disconnect(receiver_1)
+        del receiver_2
+        garbage_collect()
+        a_signal.disconnect(receiver_3)
+        self._testIsClean(a_signal)

+ 11 - 0
celery/tests/utilities/test_imports.py

@@ -1,4 +1,5 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
+from __future__ import with_statement
 
 
 from mock import Mock, patch
 from mock import Mock, patch
 
 
@@ -7,6 +8,8 @@ from celery.utils.imports import (
     symbol_by_name,
     symbol_by_name,
     reload_from_cwd,
     reload_from_cwd,
     module_file,
     module_file,
+    find_module,
+    NotAPackage,
 )
 )
 
 
 from celery.tests.utils import Case
 from celery.tests.utils import Case
@@ -14,6 +17,13 @@ from celery.tests.utils import Case
 
 
 class test_import_utils(Case):
 class test_import_utils(Case):
 
 
+    def test_find_module(self):
+        self.assertTrue(find_module("celery"))
+        imp = Mock()
+        imp.return_value = None
+        with self.assertRaises(NotAPackage):
+            find_module("foo.bar.baz", imp=imp)
+
     def test_qualname(self):
     def test_qualname(self):
         Class = type("Fox", (object, ), {"__module__": "quick.brown"})
         Class = type("Fox", (object, ), {"__module__": "quick.brown"})
         self.assertEqual(qualname(Class), "quick.brown.Fox")
         self.assertEqual(qualname(Class), "quick.brown.Fox")
@@ -32,6 +42,7 @@ class test_import_utils(Case):
         from celery.worker import WorkController
         from celery.worker import WorkController
         self.assertIs(symbol_by_name(".worker:WorkController",
         self.assertIs(symbol_by_name(".worker:WorkController",
                     package="celery"), WorkController)
                     package="celery"), WorkController)
+        self.assertTrue(symbol_by_name(":group", package="celery"))
 
 
     @patch("celery.utils.imports.reload")
     @patch("celery.utils.imports.reload")
     def test_reload_from_cwd(self, reload):
     def test_reload_from_cwd(self, reload):

+ 79 - 0
celery/tests/utilities/test_saferef.py

@@ -0,0 +1,79 @@
+from __future__ import absolute_import
+
+from celery.utils.dispatch.saferef import safe_ref
+from celery.tests.utils import Case
+
+
+class Class1(object):
+
+    def x(self):
+        pass
+
+
+def fun(obj):
+    pass
+
+
+class Class2(object):
+
+    def __call__(self, obj):
+        pass
+
+
+class SaferefTests(Case):
+
+    def setUp(self):
+        ts = []
+        ss = []
+        for x in xrange(5000):
+            t = Class1()
+            ts.append(t)
+            s = safe_ref(t.x, self._closure)
+            ss.append(s)
+        ts.append(fun)
+        ss.append(safe_ref(fun, self._closure))
+        for x in xrange(30):
+            t = Class2()
+            ts.append(t)
+            s = safe_ref(t, self._closure)
+            ss.append(s)
+        self.ts = ts
+        self.ss = ss
+        self.closureCount = 0
+
+    def tearDown(self):
+        del self.ts
+        del self.ss
+
+    def testIn(self):
+        """Test the "in" operator for safe references (cmp)"""
+        for t in self.ts[:50]:
+            self.assertTrue(safe_ref(t.x) in self.ss)
+
+    def testValid(self):
+        """Test that the references are valid (return instance methods)"""
+        for s in self.ss:
+            self.assertTrue(s())
+
+    def testShortCircuit(self):
+        """Test that creation short-circuits to reuse existing references"""
+        sd = {}
+        for s in self.ss:
+            sd[s] = 1
+        for t in self.ts:
+            if hasattr(t, 'x'):
+                self.assertIn(safe_ref(t.x), sd)
+            else:
+                self.assertIn(safe_ref(t), sd)
+
+    def testRepresentation(self):
+        """Test that the reference object's representation works
+
+        XXX Doesn't currently check the results, just that no error
+            is raised
+        """
+        repr(self.ss[-1])
+
+    def _closure(self, ref):
+        """Dumb utility mechanism to increment deletion counter"""
+        self.closureCount += 1

+ 29 - 1
celery/tests/utilities/test_timeutils.py

@@ -1,8 +1,13 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
+from __future__ import with_statement
 
 
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 
 
+from mock import Mock
+
+from celery.exceptions import ImproperlyConfigured
 from celery.utils import timeutils
 from celery.utils import timeutils
+from celery.utils.timeutils import timezone
 from celery.tests.utils import Case
 from celery.tests.utils import Case
 
 
 
 
@@ -54,10 +59,33 @@ class test_timeutils(Case):
         now = datetime.now()
         now = datetime.now()
         self.assertIs(timeutils.maybe_iso8601(now), now)
         self.assertIs(timeutils.maybe_iso8601(now), now)
 
 
-    def test_maybe_timdelta(self):
+    def test_maybe_timedelta(self):
         D = timeutils.maybe_timedelta
         D = timeutils.maybe_timedelta
 
 
         for i in (30, 30.6):
         for i in (30, 30.6):
             self.assertEqual(D(i), timedelta(seconds=i))
             self.assertEqual(D(i), timedelta(seconds=i))
 
 
         self.assertEqual(D(timedelta(days=2)), timedelta(days=2))
         self.assertEqual(D(timedelta(days=2)), timedelta(days=2))
+
+    def test_remaining_relative(self):
+        timeutils.remaining(datetime.utcnow(), timedelta(hours=1),
+                relative=True)
+
+
+class test_timezone(Case):
+
+    def test_get_timezone_with_pytz(self):
+        prev, timeutils.pytz = timeutils.pytz, Mock()
+        try:
+            self.assertTrue(timezone.get_timezone("UTC"))
+        finally:
+            timeutils.pytz = prev
+
+    def test_get_timezone_without_pytz(self):
+        prev, timeutils.pytz = timeutils.pytz, None
+        try:
+            self.assertTrue(timezone.get_timezone("UTC"))
+            with self.assertRaises(ImproperlyConfigured):
+                timezone.get_timezone("Europe/Oslo")
+        finally:
+            timeutils.pytz = prev

+ 16 - 1
celery/tests/utilities/test_utils.py

@@ -3,10 +3,12 @@ from __future__ import with_statement
 
 
 from kombu.utils.functional import promise
 from kombu.utils.functional import promise
 
 
+from mock import patch
+
 from celery import utils
 from celery import utils
 from celery.utils import text
 from celery.utils import text
 from celery.utils import functional
 from celery.utils import functional
-from celery.utils.functional import mpromise
+from celery.utils.functional import mpromise, maybe_list
 from celery.utils.threads import bgThread
 from celery.utils.threads import bgThread
 from celery.tests.utils import Case
 from celery.tests.utils import Case
 
 
@@ -112,6 +114,9 @@ class test_utils(Case):
         self.assertEqual(text.abbrtask("feeds.tasks.refresh", 30),
         self.assertEqual(text.abbrtask("feeds.tasks.refresh", 30),
                                         "feeds.tasks.refresh")
                                         "feeds.tasks.refresh")
 
 
+    def test_pretty(self):
+        self.assertTrue(text.pretty(("a", "b", "c")))
+
     def test_cached_property(self):
     def test_cached_property(self):
 
 
         def fun(obj):
         def fun(obj):
@@ -122,6 +127,16 @@ class test_utils(Case):
         self.assertIs(x.__set__(None, None), x)
         self.assertIs(x.__set__(None, None), x)
         self.assertIs(x.__delete__(None), x)
         self.assertIs(x.__delete__(None), x)
 
 
+    def test_maybe_list(self):
+        self.assertEqual(maybe_list(1), [1])
+        self.assertEqual(maybe_list([1]), [1])
+        self.assertIsNone(maybe_list(None))
+
+    @patch("warnings.warn")
+    def test_warn_deprecated(self, warn):
+        utils.warn_deprecated("Foo")
+        self.assertTrue(warn.called)
+
 
 
 class test_mpromise(Case):
 class test_mpromise(Case):
 
 

+ 19 - 0
celery/tests/utils.py

@@ -529,3 +529,22 @@ def mock_open(typ=WhateverIO, side_effect=None):
 
 
 def patch_many(*targets):
 def patch_many(*targets):
     return nested(*[mock.patch(target) for target in targets])
     return nested(*[mock.patch(target) for target in targets])
+
+
+@contextmanager
+def patch_settings(app=None, **config):
+    if app is None:
+        from celery import current_app
+        app = current_app
+    prev = {}
+    for key, value in config.iteritems():
+        try:
+            prev[key] = getattr(app.conf, key)
+        except AttributeError:
+            pass
+        setattr(app.conf, key, value)
+
+    yield app.conf
+
+    for key, value in prev.iteritems():
+        setattr(app.conf, key, value)

+ 2 - 4
celery/tests/worker/test_autoreload.py

@@ -72,18 +72,16 @@ class test_StatMonitor(Case):
             st_mtime = time()
             st_mtime = time()
         stat.return_value = st()
         stat.return_value = st()
         x = StatMonitor(["a", "b"])
         x = StatMonitor(["a", "b"])
-        calls = [0]
 
 
         def on_is_set():
         def on_is_set():
-            calls[0] += 1
-            if calls[0] > 2:
+            if x.shutdown_event.is_set.call_count > 3:
                 return True
                 return True
             return False
             return False
         x.shutdown_event = Mock()
         x.shutdown_event = Mock()
         x.shutdown_event.is_set.side_effect = on_is_set
         x.shutdown_event.is_set.side_effect = on_is_set
 
 
         x.start()
         x.start()
-        calls[0] = 0
+        x.shutdown_event = Mock()
         stat.side_effect = OSError()
         stat.side_effect = OSError()
         x.start()
         x.start()
 
 

+ 8 - 13
celery/tests/worker/test_worker.py

@@ -172,12 +172,13 @@ class test_QoS(Case):
         qos = QoS(consumer, 10)
         qos = QoS(consumer, 10)
         qos.update()
         qos.update()
         self.assertEqual(qos.value, 10)
         self.assertEqual(qos.value, 10)
-        self.assertIn({"prefetch_count": 10}, consumer.qos.call_args)
+        consumer.qos.assert_called_with(prefetch_count=10)
         qos.decrement()
         qos.decrement()
         self.assertEqual(qos.value, 9)
         self.assertEqual(qos.value, 9)
-        self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
+        consumer.qos.assert_called_with(prefetch_count=9)
         qos.decrement_eventually()
         qos.decrement_eventually()
         self.assertEqual(qos.value, 8)
         self.assertEqual(qos.value, 8)
+        consumer.qos.assert_called_with(prefetch_count=9)
         self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
         self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
 
 
         # Does not decrement 0 value
         # Does not decrement 0 value
@@ -675,17 +676,13 @@ class test_Consumer(Case):
     def test_open_connection_errback(self, sleep, connect):
     def test_open_connection_errback(self, sleep, connect):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                       send_events=False)
                       send_events=False)
-        calls = [0]
         from kombu.transport.memory import Transport
         from kombu.transport.memory import Transport
         Transport.connection_errors = (StdChannelError, )
         Transport.connection_errors = (StdChannelError, )
 
 
         def effect():
         def effect():
-            try:
-                if calls[0] > 1:
-                    return
-                raise StdChannelError()
-            finally:
-                calls[0] += 1
+            if connect.call_count > 1:
+                return
+            raise StdChannelError()
         connect.side_effect = effect
         connect.side_effect = effect
         l._open_connection()
         l._open_connection()
         connect.assert_called_with()
         connect.assert_called_with()
@@ -811,10 +808,8 @@ class test_WorkController(AppCase):
         app = Celery(loader=loader, set_as_current=False)
         app = Celery(loader=loader, set_as_current=False)
         app.conf = AttributeDict(DEFAULTS)
         app.conf = AttributeDict(DEFAULTS)
         process_initializer(app, "awesome.worker.com")
         process_initializer(app, "awesome.worker.com")
-        self.assertIn((tuple(WORKER_SIGIGNORE), {}),
-                      _signals.ignore.call_args_list)
-        self.assertIn((tuple(WORKER_SIGRESET), {}),
-                      _signals.reset.call_args_list)
+        _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
+        _signals.reset.assert_any_call(*WORKER_SIGRESET)
         self.assertTrue(app.loader.init_worker.call_count)
         self.assertTrue(app.loader.init_worker.call_count)
         self.assertTrue(on_worker_process_init.called)
         self.assertTrue(on_worker_process_init.called)
         self.assertIs(_tls.current_app, app)
         self.assertIs(_tls.current_app, app)

+ 19 - 19
celery/utils/compat.py

@@ -17,21 +17,21 @@ import sys
 is_py3k = sys.version_info[0] == 3
 is_py3k = sys.version_info[0] == 3
 
 
 try:
 try:
-    reload = reload                     # noqa
-except NameError:
-    from imp import reload              # noqa
+    reload = reload                         # noqa
+except NameError:                           # pragma: no cover
+    from imp import reload                  # noqa
 
 
 try:
 try:
-    from UserList import UserList       # noqa
-except ImportError:
-    from collections import UserList    # noqa
+    from UserList import UserList           # noqa
+except ImportError:                         # pragma: no cover
+    from collections import UserList        # noqa
 
 
 try:
 try:
-    from UserDict import UserDict       # noqa
-except ImportError:
-    from collections import UserDict    # noqa
+    from UserDict import UserDict           # noqa
+except ImportError:                         # pragma: no cover
+    from collections import UserDict        # noqa
 
 
-if is_py3k:
+if is_py3k:                                 # pragma: no cover
     from io import StringIO, BytesIO
     from io import StringIO, BytesIO
     from .encoding import bytes_to_str
     from .encoding import bytes_to_str
 
 
@@ -40,21 +40,21 @@ if is_py3k:
         def write(self, data):
         def write(self, data):
             StringIO.write(self, bytes_to_str(data))
             StringIO.write(self, bytes_to_str(data))
 else:
 else:
-    from StringIO import StringIO       # noqa
-    BytesIO = WhateverIO = StringIO     # noqa
+    from StringIO import StringIO           # noqa
+    BytesIO = WhateverIO = StringIO         # noqa
 
 
 
 
 ############## collections.OrderedDict ######################################
 ############## collections.OrderedDict ######################################
 try:
 try:
     from collections import OrderedDict
     from collections import OrderedDict
-except ImportError:
-    from ordereddict import OrderedDict  # noqa
+except ImportError:                         # pragma: no cover
+    from ordereddict import OrderedDict     # noqa
 
 
 ############## itertools.zip_longest #######################################
 ############## itertools.zip_longest #######################################
 
 
 try:
 try:
     from itertools import izip_longest as zip_longest
     from itertools import izip_longest as zip_longest
-except ImportError:
+except ImportError:                         # pragma: no cover
     import itertools
     import itertools
 
 
     def zip_longest(*args, **kwds):  # noqa
     def zip_longest(*args, **kwds):  # noqa
@@ -77,14 +77,14 @@ except ImportError:
 from itertools import chain
 from itertools import chain
 
 
 
 
-def _compat_chain_from_iterable(iterables):
+def _compat_chain_from_iterable(iterables):  # pragma: no cover
     for it in iterables:
     for it in iterables:
         for element in it:
         for element in it:
             yield element
             yield element
 
 
 try:
 try:
     chain_from_iterable = getattr(chain, "from_iterable")
     chain_from_iterable = getattr(chain, "from_iterable")
-except AttributeError:
+except AttributeError:   # pragma: no cover
     chain_from_iterable = _compat_chain_from_iterable
     chain_from_iterable = _compat_chain_from_iterable
 
 
 
 
@@ -94,13 +94,13 @@ import os
 from stat import ST_DEV, ST_INO
 from stat import ST_DEV, ST_INO
 import platform as _platform
 import platform as _platform
 
 
-if _platform.system() == "Windows":
+if _platform.system() == "Windows":  # pragma: no cover
     #since windows doesn't go with WatchedFileHandler use FileHandler instead
     #since windows doesn't go with WatchedFileHandler use FileHandler instead
     WatchedFileHandler = logging.FileHandler
     WatchedFileHandler = logging.FileHandler
 else:
 else:
     try:
     try:
         from logging.handlers import WatchedFileHandler
         from logging.handlers import WatchedFileHandler
-    except ImportError:
+    except ImportError:  # pragma: no cover
         class WatchedFileHandler(logging.FileHandler):  # noqa
         class WatchedFileHandler(logging.FileHandler):  # noqa
             """
             """
             A handler for logging to a file, which watches the file
             A handler for logging to a file, which watches the file

+ 4 - 4
celery/utils/dispatch/saferef.py

@@ -11,7 +11,7 @@ import weakref
 import traceback
 import traceback
 
 
 
 
-def safe_ref(target, on_delete=None):
+def safe_ref(target, on_delete=None):  # pragma: no cover
     """Return a *safe* weak reference to a callable target
     """Return a *safe* weak reference to a callable target
 
 
     :param target: the object to be weakly referenced, if it's a
     :param target: the object to be weakly referenced, if it's a
@@ -37,7 +37,7 @@ def safe_ref(target, on_delete=None):
         return weakref.ref(target)
         return weakref.ref(target)
 
 
 
 
-class BoundMethodWeakref(object):
+class BoundMethodWeakref(object):  # pragma: no cover
     """'Safe' and reusable weak references to instance methods.
     """'Safe' and reusable weak references to instance methods.
 
 
     BoundMethodWeakref objects provide a mechanism for
     BoundMethodWeakref objects provide a mechanism for
@@ -199,7 +199,7 @@ class BoundMethodWeakref(object):
                 return function.__get__(target)
                 return function.__get__(target)
 
 
 
 
-class BoundNonDescriptorMethodWeakref(BoundMethodWeakref):
+class BoundNonDescriptorMethodWeakref(BoundMethodWeakref):  # pragma: no cover
     """A specialized :class:`BoundMethodWeakref`, for platforms where
     """A specialized :class:`BoundMethodWeakref`, for platforms where
     instance methods are not descriptors.
     instance methods are not descriptors.
 
 
@@ -269,7 +269,7 @@ class BoundNonDescriptorMethodWeakref(BoundMethodWeakref):
                 return getattr(target, function.__name__)
                 return getattr(target, function.__name__)
 
 
 
 
-def get_bound_method_weakref(target, on_delete):
+def get_bound_method_weakref(target, on_delete):  # pragma: no cover
     """Instantiates the appropiate :class:`BoundMethodWeakRef`, depending
     """Instantiates the appropiate :class:`BoundMethodWeakRef`, depending
     on the details of the underlying class method implementation."""
     on the details of the underlying class method implementation."""
     if hasattr(target, '__get__'):
     if hasattr(target, '__get__'):

+ 2 - 7
celery/utils/dispatch/signal.py

@@ -3,23 +3,18 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 import weakref
 import weakref
-try:
-    set
-except NameError:
-    from sets import Set as set                 # Python 2.3 fallback
-
 from . import saferef
 from . import saferef
 
 
 WEAKREF_TYPES = (weakref.ReferenceType, saferef.BoundMethodWeakref)
 WEAKREF_TYPES = (weakref.ReferenceType, saferef.BoundMethodWeakref)
 
 
 
 
-def _make_id(target):
+def _make_id(target):  # pragma: no cover
     if hasattr(target, 'im_func'):
     if hasattr(target, 'im_func'):
         return (id(target.im_self), id(target.im_func))
         return (id(target.im_self), id(target.im_func))
     return id(target)
     return id(target)
 
 
 
 
-class Signal(object):
+class Signal(object):  # pragma: no cover
     """Base class for all signals
     """Base class for all signals
 
 
 
 

+ 4 - 4
celery/utils/functional.py

@@ -20,9 +20,9 @@ from threading import Lock, RLock
 
 
 try:
 try:
     from collections import Sequence
     from collections import Sequence
-except ImportError:
+except ImportError:             # pragma: no cover
     # <= Py2.5
     # <= Py2.5
-    Sequence = (list, tuple)  # noqa
+    Sequence = (list, tuple)    # noqa
 
 
 from kombu.utils.functional import promise, maybe_promise
 from kombu.utils.functional import promise, maybe_promise
 
 
@@ -76,7 +76,7 @@ class LRUCache(UserDict):
         for k in self:
         for k in self:
             try:
             try:
                 yield (k, self.data[k])
                 yield (k, self.data[k])
-            except KeyError:
+            except KeyError:  # pragma: no cover
                 pass
                 pass
     iteritems = _iterate_items
     iteritems = _iterate_items
 
 
@@ -100,7 +100,7 @@ class LRUCache(UserDict):
 def maybe_list(l):
 def maybe_list(l):
     if l is None:
     if l is None:
         return l
         return l
-    elif isinstance(l, Sequence):
+    elif not isinstance(l, basestring) and isinstance(l, Sequence):
         return l
         return l
     return [l]
     return [l]
 
 

+ 2 - 2
celery/utils/log.py

@@ -45,8 +45,8 @@ class ColorFormatter(logging.Formatter):
     colors = {"DEBUG": COLORS["blue"], "WARNING": COLORS["yellow"],
     colors = {"DEBUG": COLORS["blue"], "WARNING": COLORS["yellow"],
               "ERROR": COLORS["red"], "CRITICAL": COLORS["magenta"]}
               "ERROR": COLORS["red"], "CRITICAL": COLORS["magenta"]}
 
 
-    def __init__(self, msg, use_color=True):
-        logging.Formatter.__init__(self, msg)
+    def __init__(self, fmt=None, use_color=True):
+        logging.Formatter.__init__(self, fmt)
         self.use_color = use_color
         self.use_color = use_color
 
 
     def formatException(self, ei):
     def formatException(self, ei):

+ 2 - 4
celery/utils/mail.py

@@ -18,6 +18,7 @@ import warnings
 
 
 from email.mime.text import MIMEText
 from email.mime.text import MIMEText
 
 
+from .functional import maybe_list
 from .imports import symbol_by_name
 from .imports import symbol_by_name
 
 
 supports_timeout = sys.version_info >= (2, 6)
 supports_timeout = sys.version_info >= (2, 6)
@@ -31,15 +32,12 @@ class Message(object):
 
 
     def __init__(self, to=None, sender=None, subject=None, body=None,
     def __init__(self, to=None, sender=None, subject=None, body=None,
             charset="us-ascii"):
             charset="us-ascii"):
-        self.to = to
+        self.to = maybe_list(to)
         self.sender = sender
         self.sender = sender
         self.subject = subject
         self.subject = subject
         self.body = body
         self.body = body
         self.charset = charset
         self.charset = charset
 
 
-        if not isinstance(self.to, (list, tuple)):
-            self.to = [self.to]
-
     def __repr__(self):
     def __repr__(self):
         return "<Email: To:%r Subject:%r>" % (self.to, self.subject)
         return "<Email: To:%r Subject:%r>" % (self.to, self.subject)
 
 

+ 1 - 1
celery/utils/serialization.py

@@ -74,7 +74,7 @@ def find_nearest_pickleable_exception(exc):
     getmro_ = getattr(cls, "mro", None)
     getmro_ = getattr(cls, "mro", None)
 
 
     # old-style classes doesn't have mro()
     # old-style classes doesn't have mro()
-    if not getmro_:
+    if not getmro_:  # pragma: no cover
         # all Py2.4 exceptions has a baseclass.
         # all Py2.4 exceptions has a baseclass.
         if not getattr(cls, "__bases__", ()):
         if not getattr(cls, "__bases__", ()):
             return
             return

+ 0 - 1
celery/utils/timeutils.py

@@ -71,7 +71,6 @@ class _Zone(object):
     @cached_property
     @cached_property
     def utc(self):
     def utc(self):
         return self.get_timezone("UTC")
         return self.get_timezone("UTC")
-
 timezone = _Zone()
 timezone = _Zone()
 
 
 
 

+ 2 - 1
contrib/release/doc4allmods

@@ -2,7 +2,8 @@
 
 
 PACKAGE="$1"
 PACKAGE="$1"
 SKIP_PACKAGES="$PACKAGE tests management urls"
 SKIP_PACKAGES="$PACKAGE tests management urls"
-SKIP_FILES="celery.backends.pyredis.rst
+SKIP_FILES="celery.__compat__.rst
+            celery.backends.pyredis.rst
             celery.bin.rst
             celery.bin.rst
             celery.bin.celeryd_detach.rst
             celery.bin.celeryd_detach.rst
             celery.concurrency.processes._win.rst
             celery.concurrency.processes._win.rst

+ 3 - 6
contrib/release/py3k-run-tests

@@ -8,9 +8,6 @@ nosetests -vd celery.tests                                      \
             --cover3-html                                       \
             --cover3-html                                       \
             --cover3-html-dir="$base/cover"                     \
             --cover3-html-dir="$base/cover"                     \
             --cover3-package=celery                             \
             --cover3-package=celery                             \
-            --cover3-exclude="                                  \
-              celery.tests.*                                    \
-              celery.utils.compat                               \
-              celery.utils.dispatch*"                           \
-            --with-xunit                                        \
-              --xunit-file="$base/nosetests.xml"
+            --cover3-exclude="celery.tests.*"                   \
+          --with-xunit                                          \
+            --xunit-file="$base/nosetests.xml"

+ 1 - 1
requirements/test.txt

@@ -2,7 +2,7 @@ unittest2>=0.4.0
 nose
 nose
 nose-cover3
 nose-cover3
 coverage>=3.0
 coverage>=3.0
-mock>=0.7.0
+mock==dev
 redis
 redis
 pymongo
 pymongo
 SQLAlchemy
 SQLAlchemy

+ 0 - 2
setup.cfg

@@ -4,8 +4,6 @@ cover3-branch = 1
 cover3-html = 1
 cover3-html = 1
 cover3-package = celery
 cover3-package = celery
 cover3-exclude = celery.tests.*
 cover3-exclude = celery.tests.*
-                 celery.utils.compat
-                 celery.utils.dispatch*
 
 
 [build_sphinx]
 [build_sphinx]
 source-dir = docs/
 source-dir = docs/

+ 1 - 1
setup.py

@@ -126,7 +126,7 @@ elif py_version[0:2] == (2, 5):
 
 
 # -*- Tests Requires -*-
 # -*- Tests Requires -*-
 
 
-tests_require = ["nose", "nose-cover3", "sqlalchemy", "mock"]
+tests_require = ["nose", "nose-cover3", "sqlalchemy", "mock==dev"]
 if sys.version_info < (2, 7):
 if sys.version_info < (2, 7):
     tests_require.append("unittest2")
     tests_require.append("unittest2")
 elif sys.version_info <= (2, 5):
 elif sys.version_info <= (2, 5):