Browse Source

That's 99% coverage, probably don't want to take it further

Ask Solem 15 years ago
parent
commit
8e6c9ba22c

+ 2 - 5
celery/backends/base.py

@@ -29,7 +29,8 @@ def find_nearest_pickleable_exception(exc):
     unwanted = (Exception, BaseException, object)
     is_unwanted = lambda exc: any(map(curry(operator.is_, exc), unwanted))
 
-    for supercls in exc.__class__.mro():
+    mro_ = getattr(exc.__class__, "mro", lambda: [])
+    for supercls in mro_():
         if is_unwanted(supercls):
             # only BaseException and object, from here on down,
             # we don't care about these.
@@ -134,10 +135,6 @@ class BaseBackend(object):
             return exc_cls(*exc.exc_args)
         return exc
 
-    def mark_as_retry(self, task_id, exc):
-        """Mark task for retry."""
-        return self.store_result(task_id, exc, status="RETRY")
-
     def get_status(self, task_id):
         """Get the status of a task."""
         raise NotImplementedError(

+ 2 - 42
celery/pool.py

@@ -14,46 +14,6 @@ from celery.utils import gen_unique_id
 from functools import partial as curry
 
 
-class DynamicPool(Pool):
-    """Version of :class:`multiprocessing.Pool` that can dynamically grow
-    in size."""
-
-    def __init__(self, processes=None, initializer=None, initargs=()):
-        super(DynamicPool, self).__init__(processes=processes,
-                                          initializer=initializer,
-                                          initargs=initargs)
-        self._initializer = initializer
-        self._initargs = initargs
-
-    def add_worker(self):
-        """Add another worker to the pool."""
-        w = self.Process(target=worker,
-                         args=(self._inqueue, self._outqueue,
-                               self._initializer, self._initargs))
-        self._pool.append(w)
-        w.name = w.name.replace("Process", "PoolWorker")
-        w.daemon = True
-        w.start()
-
-    def grow(self, size=1):
-        """Add ``increment`` new workers to the pool."""
-        [self._add_worker() for i in xrange(size)]
-
-    def get_worker_pids(self):
-        """Returns the process id's of all the pool workers."""
-        return [process.pid for process in self._pool]
-
-    def reap_dead_workers(self):
-        dead = [process for process in self._pool
-                            if not process.is_alive()]
-        self._pool = [process for process in self._pool
-                            if process not in dead]
-        return dead
-
-    def replace_dead_workers(self):
-        self.grow(len(self.find_dead_workers()))
-
-
 class TaskPool(object):
     """Pool of running child processes, which starts waiting for the
     processes to finish when the queue limit has been reached.
@@ -88,7 +48,7 @@ class TaskPool(object):
 
         """
         self._processes = {}
-        self._pool = DynamicPool(processes=self.limit)
+        self._pool = Pool(processes=self.limit)
 
     def stop(self):
         """Terminate the pool."""
@@ -173,7 +133,7 @@ class TaskPool(object):
 
     def get_worker_pids(self):
         """Returns the process id's of all the pool workers."""
-        return self._pool.get_worker_pids()
+        return [process.pid for process in self._pool._pool]
 
     def on_ready(self, callbacks, errbacks, meta, ret_value):
         """What to do when a worker task is ready and its return value has

+ 9 - 0
celery/tests/test_backends/test_base.py

@@ -1,4 +1,5 @@
 import unittest
+import types
 from celery.backends.base import find_nearest_pickleable_exception as fnpe
 from celery.backends.base import BaseBackend, KeyValueStoreBackend
 from celery.backends.base import UnpickleableExceptionWrapper
@@ -11,6 +12,7 @@ class wrapobject(object):
         self.args = args
 
 
+Oldstyle = types.ClassType("Oldstyle", (), {})
 Unpickleable = subclass_exception("Unpickleable", KeyError, "foo.module")
 Impossible = subclass_exception("Impossible", object, "foo.module")
 Lookalike = subclass_exception("Lookalike", wrapobject, "foo.module")
@@ -27,9 +29,16 @@ class TestBaseBackendInterface(unittest.TestCase):
         self.assertRaises(NotImplementedError,
                 b.store_result, "SOMExx-N0nex1stant-IDxx-", 42, "DONE")
 
+    def test_get_result(self):
+        self.assertRaises(NotImplementedError,
+                b.get_result, "SOMExx-N0nex1stant-IDxx-")
+
 
 class TestPickleException(unittest.TestCase):
 
+    def test_oldstyle(self):
+        self.assertTrue(fnpe(Oldstyle()) is None)
+
     def test_BaseException(self):
         self.assertTrue(fnpe(Exception()) is None)
 

+ 24 - 0
celery/tests/test_backends/test_database.py

@@ -1,6 +1,10 @@
 import unittest
 from celery.backends.database import Backend
 from celery.utils import gen_unique_id
+from celery.task import PeriodicTask
+from celery import registry
+from celery.models import PeriodicTaskMeta
+from datetime import datetime, timedelta
 
 
 class SomeClass(object):
@@ -9,8 +13,28 @@ class SomeClass(object):
         self.data = data
 
 
+class MyPeriodicTask(PeriodicTask):
+    name = "c.u.my-periodic-task-244"
+    run_every = timedelta(seconds=1)
+
+    def run(self, **kwargs):
+        return 42
+registry.tasks.register(MyPeriodicTask)
+
+
 class TestDatabaseBackend(unittest.TestCase):
 
+    def test_run_periodic_tasks(self):
+        #obj, created = PeriodicTaskMeta.objects.get_or_create(
+        #                    name=MyPeriodicTask.name,
+        #                    defaults={"last_run_at": datetime.now() -
+        #                        timedelta(days=-4)})
+        #if not created:
+        #    obj.last_run_at = datetime.now() - timedelta(days=4)
+        #    obj.save()
+        b = Backend()
+        b.run_periodic_tasks()
+
     def test_backend(self):
         b = Backend()
         tid = gen_unique_id()

+ 7 - 1
celery/tests/test_task_builtins.py

@@ -1,5 +1,5 @@
 import unittest
-from celery.task.builtins import PingTask
+from celery.task.builtins import PingTask, DeleteExpiredTaskMetaTask
 from celery.task.base import ExecuteRemoteTask
 from celery.serialization import pickle
 
@@ -20,3 +20,9 @@ class TestRemoteExecuteTask(unittest.TestCase):
         self.assertEquals(ExecuteRemoteTask.apply(
                             args=[pickle.dumps(some_func), [10], {}]).get(),
                           100)
+
+
+class TestDeleteExpiredTaskMetaTask(unittest.TestCase):
+
+    def test_run(self):
+        DeleteExpiredTaskMetaTask.apply()

+ 6 - 0
celery/tests/test_worker_controllers.py

@@ -89,3 +89,9 @@ class TestPeriodicWorkController(unittest.TestCase):
         hold_queue = Queue()
         m = PeriodicWorkController(bucket_queue, hold_queue)
         m.run_periodic_tasks()
+
+    def test_on_iteration(self):
+        bucket_queue = Queue()
+        hold_queue = Queue()
+        m = PeriodicWorkController(bucket_queue, hold_queue)
+        m.on_iteration()

+ 71 - 1
celery/tests/test_worker_job.py

@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+import sys
 import unittest
 from celery.worker.job import jail
 from celery.worker.job import TaskWrapper
@@ -8,7 +9,11 @@ from celery.registry import tasks, NotRegistered
 from celery.pool import TaskPool
 from celery.utils import gen_unique_id
 from carrot.backends.base import BaseMessage
+from StringIO import StringIO
+from celery.log import setup_logger
+from django.core import cache
 import simplejson
+import logging
 
 scratch = {"ACK": False}
 
@@ -48,7 +53,7 @@ class TestJail(unittest.TestCase):
         from django.db import connection
         connection._was_closed = False
         old_connection_close = connection.close
-
+        
         def monkeypatched_connection_close(*args, **kwargs):
             connection._was_closed = True
             return old_connection_close(*args, **kwargs)
@@ -61,6 +66,47 @@ class TestJail(unittest.TestCase):
 
         connection.close = old_connection_close
 
+    def test_django_cache_connection_is_closed(self):
+        old_cache_close = getattr(cache.cache, "close", None)
+        old_backend = cache.settings.CACHE_BACKEND
+        cache.settings.CACHE_BACKEND = "libmemcached"
+        cache._was_closed = False
+        old_cache_parse_backend = getattr(cache, "parse_backend_uri", None)
+
+        def monkeypatched_cache_close(*args, **kwargs):
+            cache._was_closed = True
+
+        cache.cache.close = monkeypatched_cache_close
+
+        jail(gen_unique_id(), gen_unique_id(), mytask, [4], {})
+        self.assertTrue(cache._was_closed)
+        cache.cache.close = old_cache_close
+        cache.settings.CACHE_BACKEND = old_backend
+        if old_cache_parse_backend:
+            cache.parse_backend_uri = old_cache_parse_backend
+
+    def test_django_cache_connection_is_closed_django_1_1(self):
+        old_cache_close = getattr(cache.cache, "close", None)
+        old_backend = cache.settings.CACHE_BACKEND
+        cache.settings.CACHE_BACKEND = "libmemcached"
+        cache._was_closed = False
+        old_cache_parse_backend = getattr(cache, "parse_backend_uri", None)
+        cache.parse_backend_uri = lambda uri: ["libmemcached", "1", "2"]
+
+        def monkeypatched_cache_close(*args, **kwargs):
+            cache._was_closed = True
+
+        cache.cache.close = monkeypatched_cache_close
+
+        jail(gen_unique_id(), gen_unique_id(), mytask, [4], {})
+        self.assertTrue(cache._was_closed)
+        cache.cache.close = old_cache_close
+        cache.settings.CACHE_BACKEND = old_backend
+        if old_cache_parse_backend:
+            cache.parse_backend_uri = old_cache_parse_backend
+        else:
+            del(cache.parse_backend_uri)
+
 
 class TestTaskWrapper(unittest.TestCase):
 
@@ -163,3 +209,27 @@ class TestTaskWrapper(unittest.TestCase):
             "loglevel": 10,
             "task_id": tw.task_id,
             "task_name": tw.task_name})
+
+    def test_on_failure(self):
+        tid = gen_unique_id()
+        tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f": "x"})
+        try:
+            raise Exception("Inside unit tests")
+        except Exception:
+            exc_info = ExceptionInfo(sys.exc_info())
+
+        logfh = StringIO()
+        tw.logger.handlers = []
+        tw.logger = setup_logger(logfile=logfh, loglevel=logging.INFO)
+
+        from celery import conf
+        conf.SEND_CELERY_TASK_ERROR_EMAILS = True
+
+        tw.on_failure(exc_info, {"task_id": tid, "task_name": "cu.mytask"})
+        logvalue = logfh.getvalue()
+        self.assertTrue("cu.mytask" in logvalue)
+        self.assertTrue(tid in logvalue)
+        self.assertTrue("ERROR" in logvalue)
+
+        conf.SEND_CELERY_TASK_ERROR_EMAILS = False
+         

+ 2 - 1
celery/worker/controllers.py

@@ -18,6 +18,7 @@ class InfinityThread(threading.Thread):
     the :meth:`stop` method.
 
     """
+    is_infinite = True
 
     def __init__(self):
         super(InfinityThread, self).__init__()
@@ -31,7 +32,7 @@ class InfinityThread(threading.Thread):
         To start the thread use :meth:`start` instead.
 
         """
-        while True:
+        while self.is_infinite:
             if self._shutdown.isSet():
                 break
             self.on_iteration()

+ 2 - 1
celery/worker/job.py

@@ -3,7 +3,6 @@
 Jobs Executable by the Worker Server.
 
 """
-from celery.conf import SEND_CELERY_TASK_ERROR_EMAILS
 from celery.registry import tasks, NotRegistered
 from celery.datastructures import ExceptionInfo
 from celery.backends import default_backend
@@ -233,6 +232,8 @@ class TaskWrapper(object):
 
     def on_failure(self, exc_info, meta):
         """The handler used if the task raised an exception."""
+        from celery.conf import SEND_CELERY_TASK_ERROR_EMAILS
+
         task_id = meta.get("task_id")
         task_name = meta.get("task_name")
         context = {