Browse Source

100% coverage for celery.contrib.*

Ask Solem 13 years ago
parent
commit
c65922900f

+ 1 - 0
celery/app/amqp.py

@@ -335,6 +335,7 @@ class AMQP(object):
         queues."""
         return self.ConsumerSet(connection,
                                 from_dict=queues or self.queues.consume_from,
+                                channel=connection.default_channel,
                                 **kwargs)
 
     def get_default_queue(self):

+ 4 - 4
celery/contrib/migrate.py

@@ -70,11 +70,11 @@ def migrate_tasks(source, dest, timeout=1.0, app=None,
     producer = app.amqp.TaskPublisher(dest)
     if migrate is None:
         migrate = partial(migrate_task, producer)
-    if callback is not None:
-        callback = partial(callback, state)
     consumer = app.amqp.get_task_consumer(source)
     consumer.register_callback(update_state)
-    consumer.register_callback(callback)
+    if callback is not None:
+        callback = partial(callback, state)
+        consumer.register_callback(callback)
     consumer.register_callback(migrate)
 
     # declare all queues on the new broker.
@@ -90,7 +90,7 @@ def migrate_tasks(source, dest, timeout=1.0, app=None,
     # start migrating messages.
     with consumer:
         try:
-            for _ in eventloop(source, timeout=timeout):
+            for _ in eventloop(source, timeout=timeout):  # pragma: no cover
                 pass
         except socket.timeout:
             return

+ 31 - 24
celery/contrib/rdb.py

@@ -65,46 +65,53 @@ class Rdb(Pdb):
     _sock = None
 
     def __init__(self, host=CELERY_RDB_HOST, port=CELERY_RDB_PORT,
-            port_search_limit=100, port_skew=+0):
+            port_search_limit=100, port_skew=+0, out=sys.stdout):
         self.active = True
+        self.out = out
 
+        self._prev_handles = sys.stdin, sys.stdout
+
+        self._sock, this_port = self.get_avail_port(host, port,
+            port_search_limit, port_skew)
+        self._sock.listen(1)
+        me = "%s:%s" % (self.me, this_port)
+        context = self.context = {"me": me, "host": host, "port": this_port}
+        self.say("%(me)s: Please telnet %(host)s %(port)s."
+                 "  Type `exit` in session to continue." % context)
+        self.say("%(me)s: Waiting for client..." % context)
+
+        self._client, address = self._sock.accept()
+        context["remote_addr"] = ":".join(map(str, address))
+        self.say("%(me)s: In session with %(remote_addr)s" % context)
+        self._handle = sys.stdin = sys.stdout = self._client.makefile("rw")
+        Pdb.__init__(self, completekey="tab",
+                           stdin=self._handle, stdout=self._handle)
+
+    def get_avail_port(self, host, port, search_limit=100, skew=+0):
         try:
-            _, port_skew = current_process().name.split('-')
-            port_skew = int(port_skew)
+            _, skew = current_process().name.split('-')
+            skew = int(skew)
         except ValueError:
             pass
-
-        self._prev_handles = sys.stdin, sys.stdout
         this_port = None
-        for i in xrange(port_search_limit):
-            self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-            this_port = port + port_skew + i
+        for i in xrange(search_limit):
+            _sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            this_port = port + skew + i
             try:
-                self._sock.bind((host, this_port))
+                _sock.bind((host, this_port))
             except socket.error, exc:
                 if exc.errno in [errno.EADDRINUSE, errno.EINVAL]:
                     continue
                 raise
             else:
-                break
+                return _sock, this_port
         else:
             raise Exception(
                 "%s: Could not find available port. Please set using "
                 "environment variable CELERY_RDB_PORT" % (self.me, ))
 
-        self._sock.listen(1)
-        me = "%s:%s" % (self.me, this_port)
-        context = self.context = {"me": me, "host": host, "port": this_port}
-        print("%(me)s: Please telnet %(host)s %(port)s."
-              "  Type `exit` in session to continue." % context)
-        print("%(me)s: Waiting for client..." % context)
-
-        self._client, address = self._sock.accept()
-        context["remote_addr"] = ":".join(map(str, address))
-        print("%(me)s: In session with %(remote_addr)s" % context)
-        self._handle = sys.stdin = sys.stdout = self._client.makefile("rw")
-        Pdb.__init__(self, completekey="tab",
-                           stdin=self._handle, stdout=self._handle)
+    def say(self, m):
+        self.out.write(m + "\n")
 
     def _close_session(self):
         self.stdin, self.stdout = sys.stdin, sys.stdout = self._prev_handles
@@ -112,7 +119,7 @@ class Rdb(Pdb):
         self._client.close()
         self._sock.close()
         self.active = False
-        print("%(me)s: Session %(remote_addr)s ended." % self.context)
+        self.say("%(me)s: Session %(remote_addr)s ended." % self.context)
 
     def do_continue(self, arg):
         self._close_session()

+ 2 - 1
celery/tests/test_backends/test_cassandra.py

@@ -3,7 +3,7 @@ from __future__ import with_statement
 
 import socket
 
-from mock import Mock, patch
+from mock import Mock
 from pickle import loads, dumps
 
 from celery import Celery
@@ -96,6 +96,7 @@ class test_CassandraBackend(AppCase):
 
             calls = [0]
             end = [10]
+
             def work_eventually(*arg):
                 try:
                     if calls[0] > end[0]:

+ 0 - 1
celery/tests/test_backends/test_mongodb.py

@@ -326,4 +326,3 @@ class test_MongoBackend(AppCase):
         with self.assertRaises(ImproperlyConfigured):
             x._get_database()
         db.authenticate.assert_called_with("jerry", "cere4l")
-

+ 1 - 1
celery/tests/test_concurrency/test_gevent.py

@@ -5,7 +5,7 @@ import os
 import sys
 
 from nose import SkipTest
-from mock import patch, Mock
+from mock import Mock
 
 from celery.concurrency.gevent import (
     Schedule,

+ 0 - 0
celery/tests/test_contrib/__init__.py


+ 7 - 0
celery/tests/test_task/test_task_abortable.py → celery/tests/test_contrib/test_abortable.py

@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 
 from celery.contrib.abortable import AbortableTask, AbortableAsyncResult
+from celery.result import AsyncResult
 from celery.tests.utils import Case
 
 
@@ -24,6 +25,12 @@ class test_AbortableTask(Case):
         tid = result.id
         self.assertFalse(t.is_aborted(task_id=tid))
 
+    def test_is_aborted_not_abort_result(self):
+        t = MyAbortableTask()
+        t.AsyncResult = AsyncResult
+        t.request.id = "foo"
+        self.assertFalse(t.is_aborted())
+
     def test_abort_yields_aborted(self):
         t = MyAbortableTask()
         result = t.apply_async()

+ 101 - 0
celery/tests/test_contrib/test_migrate.py

@@ -0,0 +1,101 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+from kombu import BrokerConnection, Producer, Queue, Exchange
+from kombu.exceptions import StdChannelError
+from mock import patch
+
+from celery.contrib.migrate import (
+    State,
+    migrate_task,
+    migrate_tasks,
+)
+from celery.utils.encoding import bytes_t
+from celery.tests.utils import AppCase, Case, Mock
+
+
+def Message(body, exchange="exchange", routing_key="rkey",
+        compression=None, content_type="application/json",
+        content_encoding="utf-8"):
+    return Mock(attrs=dict(body=body,
+        delivery_info=dict(exchange=exchange, routing_key=routing_key),
+        headers=dict(compression=compression),
+        content_type=content_type, content_encoding=content_encoding,
+        properties={}))
+
+
+class test_State(Case):
+
+    def test_strtotal(self):
+        x = State()
+        self.assertEqual(x.strtotal, u"?")
+        x.total_apx = 100
+        self.assertEqual(x.strtotal, u"100")
+
+
+class test_migrate_task(Case):
+
+    def test_removes_compression_header(self):
+        x = Message("foo", compression="zlib")
+        producer = Mock()
+        migrate_task(producer, x.body, x)
+        self.assertTrue(producer.publish.called)
+        args, kwargs = producer.publish.call_args
+        self.assertIsInstance(args[0], bytes_t)
+        self.assertNotIn("compression", kwargs["headers"])
+        self.assertEqual(kwargs["compression"], "zlib")
+        self.assertEqual(kwargs["content_type"], "application/json")
+        self.assertEqual(kwargs["content_encoding"], "utf-8")
+        self.assertEqual(kwargs["exchange"], "exchange")
+        self.assertEqual(kwargs["routing_key"], "rkey")
+
+
+class test_migrate_tasks(AppCase):
+
+    def test_migrate(self, name="testcelery"):
+        x = BrokerConnection("memory://foo")
+        y = BrokerConnection("memory://foo")
+        # use separate state
+        x.default_channel.queues = {}
+        y.default_channel.queues = {}
+
+        ex = Exchange(name, "direct")
+        q = Queue(name, exchange=ex, routing_key=name)
+        q(x.default_channel).declare()
+        Producer(x).publish("foo", exchange=name, routing_key=name)
+        Producer(x).publish("bar", exchange=name, routing_key=name)
+        Producer(x).publish("baz", exchange=name, routing_key=name)
+        self.assertTrue(x.default_channel.queues)
+        self.assertFalse(y.default_channel.queues)
+
+        migrate_tasks(x, y)
+
+        yq = q(y.default_channel)
+        self.assertEqual(yq.get().body, "foo")
+        self.assertEqual(yq.get().body, "bar")
+        self.assertEqual(yq.get().body, "baz")
+
+        Producer(x).publish("foo", exchange=name, routing_key=name)
+        callback = Mock()
+        migrate_tasks(x, y, callback=callback)
+        self.assertTrue(callback.called)
+        migrate = Mock()
+        Producer(x).publish("baz", exchange=name, routing_key=name)
+        migrate_tasks(x, y, callback=callback, migrate=migrate)
+        self.assertTrue(migrate.called)
+
+        with patch("kombu.transport.virtual.Channel.queue_declare") as qd:
+
+            def effect(*args, **kwargs):
+                if kwargs.get("passive"):
+                    raise StdChannelError()
+                return 0, 3, 0
+            qd.side_effect = effect
+            migrate_tasks(x, y)
+
+        x = BrokerConnection("memory://")
+        x.default_channel.queues = {}
+        y.default_channel.queues = {}
+        callback = Mock()
+        migrate_tasks(x, y, callback=callback)
+        self.assertFalse(callback.called)

+ 99 - 0
celery/tests/test_contrib/test_rdb.py

@@ -0,0 +1,99 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+import errno
+import socket
+
+from mock import Mock, patch
+
+from celery.contrib.rdb import (
+    Rdb,
+    debugger,
+    set_trace,
+)
+from celery.tests.utils import Case, WhateverIO
+
+
+class test_Rdb(Case):
+
+    @patch("celery.contrib.rdb.Rdb")
+    def test_debugger(self, Rdb):
+        x = debugger()
+        self.assertTrue(x)
+        self.assertIs(x, debugger())
+
+    @patch("celery.contrib.rdb.debugger")
+    @patch("celery.contrib.rdb._frame")
+    def test_set_trace(self, _frame, debugger):
+        self.assertTrue(set_trace(Mock()))
+        self.assertTrue(set_trace())
+        self.assertTrue(debugger.return_value.set_trace.called)
+
+    @patch("celery.contrib.rdb.Rdb.get_avail_port")
+    def test_rdb(self, get_avail_port):
+        sock = Mock()
+        get_avail_port.return_value = (sock, 8000)
+        sock.accept.return_value = (Mock(), ["helu"])
+        out = WhateverIO()
+        rdb = Rdb(out=out)
+        self.assertTrue(get_avail_port.called)
+        self.assertIn("helu", out.getvalue())
+
+        # set_quit
+        with patch("sys.settrace") as settrace:
+            rdb.set_quit()
+            settrace.assert_called_with(None)
+
+        # set_trace
+        with patch("celery.contrib.rdb.Pdb.set_trace") as pset:
+            with patch("celery.contrib.rdb._frame"):
+                rdb.set_trace()
+                rdb.set_trace(Mock())
+                pset.side_effect = socket.error
+                pset.side_effect.errno = errno.ECONNRESET
+                rdb.set_trace()
+                pset.side_effect.errno = errno.ENOENT
+                with self.assertRaises(socket.error):
+                    rdb.set_trace()
+
+        # _close_session
+        rdb._close_session()
+
+        # do_continue
+        rdb.set_continue = Mock()
+        rdb.do_continue(Mock())
+        rdb.set_continue.assert_called_with()
+
+        # do_quit
+        rdb.set_quit = Mock()
+        rdb.do_quit(Mock())
+        rdb.set_quit.assert_called_with()
+
+    @patch("socket.socket")
+    def test_get_avail_port(self, sock):
+        out = WhateverIO()
+        sock.return_value.accept.return_value = (Mock(), ["helu"])
+        Rdb(out=out)
+
+        with patch("celery.contrib.rdb.current_process") as curproc:
+            curproc.return_value.name = "PoolWorker-10"
+            Rdb(out=out)
+
+        err = sock.return_value.bind.side_effect = socket.error()
+        err.errno = errno.ENOENT
+        with self.assertRaises(socket.error):
+            Rdb(out=out)
+        err.errno = errno.EADDRINUSE
+        with self.assertRaises(Exception):
+            Rdb(out=out)
+        called = [0]
+
+        def effect(*a, **kw):
+            try:
+                if called[0] > 50:
+                    return True
+                raise err
+            finally:
+                called[0] += 1
+        sock.return_value.bind.side_effect = effect
+        Rdb(out=out)

+ 1 - 1
celery/tests/test_security/__init__.py

@@ -125,6 +125,7 @@ class test_security(SecurityCase):
     @patch("celery.security.disable_untrusted_serializers")
     def test_setup_registry_complete(self, dis, reg, key="KEY", cert="CERT"):
         calls = [0]
+
         def effect(*args):
             try:
                 m = Mock()
@@ -139,7 +140,6 @@ class test_security(SecurityCase):
             dis.assert_called_with(["json"])
             reg.assert_called_with("A", "B", store)
 
-
     def test_security_conf(self):
         current_app.conf.CELERY_TASK_SERIALIZER = 'auth'
 

+ 1 - 0
celery/tests/test_security/test_certificate.py

@@ -1,4 +1,5 @@
 from __future__ import absolute_import
+from __future__ import with_statement
 
 from celery.exceptions import SecurityError
 from celery.security.certificate import Certificate, CertStore, FSCertStore

+ 1 - 1
celery/tests/test_worker/test_worker_autoreload.py

@@ -20,7 +20,7 @@ from celery.worker.autoreload import (
     Autoreloader,
 )
 
-from celery.tests.utils import AppCase, Case, WhateverIO, mock_open
+from celery.tests.utils import AppCase, Case, mock_open
 
 
 class test_WorkerComponent(AppCase):

+ 1 - 0
celery/tests/utils.py

@@ -1,4 +1,5 @@
 from __future__ import absolute_import
+from __future__ import with_statement
 
 try:
     import unittest

+ 1 - 10
contrib/release/py3k-run-tests

@@ -9,17 +9,8 @@ nosetests -vd celery.tests                                      \
             --cover3-html-dir="$base/cover"                     \
             --cover3-package=celery                             \
             --cover3-exclude="                                  \
-              celery                                            \
               celery.tests.*                                    \
               celery.utils.compat                               \
-              celery.utils.dispatch*                            \
-              celery.db.a805d4bd                                \
-              celery.db.dfd042c7                                \
-              celery.contrib*                                   \
-              celery.concurrency.threads                        \
-              celery.concurrency.gevent                         \
-              celery.backends.mongodb                           \
-              celery.backends.cassandra                         \
-              celery.events.cursesmon"                          \
+              celery.utils.dispatch*"                           \
             --with-xunit                                        \
               --xunit-file="$base/nosetests.xml"

+ 1 - 3
setup.cfg

@@ -3,11 +3,9 @@ where = celery/tests
 cover3-branch = 1
 cover3-html = 1
 cover3-package = celery
-cover3-exclude = celery
-                 celery.tests.*
+cover3-exclude = celery.tests.*
                  celery.utils.compat
                  celery.utils.dispatch*
-                 celery.contrib*
 
 [build_sphinx]
 source-dir = docs/