Pārlūkot izejas kodu

celery.tests.functional: Functional test suite added.

Ask Solem 14 gadi atpakaļ
vecāks
revīzija
d7b846f445

+ 7 - 0
celery/tests/functional/__init__.py

@@ -0,0 +1,7 @@
+import os
+
+config = os.environ.setdefault("CELERY_FUNTEST_CONFIG_MODULE",
+                               "celery.tests.functional.config")
+
+os.environ["CELERY_CONFIG_MODULE"] = config
+os.environ["CELERY_LOADER"] = "default"

+ 120 - 0
celery/tests/functional/case.py

@@ -0,0 +1,120 @@
+import atexit
+import logging
+import os
+import signal
+import socket
+import sys
+import traceback
+import unittest2 as unittest
+
+from itertools import count
+
+from celery.task.control import broadcast, ping
+from celery.utils import get_full_cls_name
+
+HOSTNAME = socket.gethostname()
+
+def say(msg):
+    sys.stderr.write("%s\n" % msg)
+
+def flatten_response(response):
+    flat = {}
+    for item in response:
+        flat.update(item)
+    return flat
+
+class Worker(object):
+    started = False
+    next_worker_id = count(1).next
+    _shutdown_called = False
+
+    def __init__(self, hostname, loglevel="error"):
+        self.hostname = hostname
+        self.loglevel = loglevel
+
+    def start(self):
+        if not self.started:
+            self._fork_and_exec()
+            self.started = True
+
+    def _fork_and_exec(self):
+        pid = os.fork()
+        if pid == 0:
+            os.execv(sys.executable,
+                    [sys.executable] + ["-m", "celery.bin.celeryd",
+                                        "-l", self.loglevel,
+                                        "-n", self.hostname])
+            os.exit()
+        self.pid = pid
+
+    def is_alive(self, timeout=1):
+        r = ping(destination=[self.hostname],
+                 timeout=timeout)
+        return self.hostname in flatten_response(r)
+
+    def wait_until_started(self, timeout=10, interval=0.2):
+        for iteration in count(0):
+            if iteration * interval >= timeout:
+                raise Exception(
+                        "Worker won't start (after %s secs.)" % timeout)
+            if self.is_alive(interval):
+                break
+        say("--WORKER %s IS ONLINE--" % self.hostname)
+
+    def ensure_shutdown(self, timeout=10, interval=0.5):
+        os.kill(self.pid, signal.SIGTERM)
+        for iteration in count(0):
+            if iteration * interval >= timeout:
+                raise Exception(
+                        "Worker won't shutdown (after %s secs.)" % timeout)
+            broadcast("shutdown", destination=[self.hostname])
+            if not self.is_alive(interval):
+                break
+        say("--WORKER %s IS SHUTDOWN--" % self.hostname)
+        self._shutdown_called = True
+
+    def ensure_started(self):
+        self.start()
+        self.wait_until_started()
+
+    @classmethod
+    def managed(cls, hostname=None, caller=None):
+        hostname = hostname or socket.gethostname()
+        if caller:
+            hostname = ".".join([get_full_cls_name(caller), hostname])
+        else:
+            hostname += str(cls.next_worker_id())
+        worker = cls(hostname)
+        worker.ensure_started()
+        stack = traceback.format_stack()
+
+        @atexit.register
+        def _ensure_shutdown_once():
+            if not worker._shutdown_called:
+                say("-- Found worker not stopped at shutdown: %s\n%s" % (
+                        worker.hostname,
+                        "\n".join(stack)))
+                worker.ensure_shutdown()
+
+        return worker
+
+
+class WorkerCase(unittest.TestCase):
+    hostname = HOSTNAME
+    worker = None
+
+    @classmethod
+    def setUpClass(cls):
+        logging.getLogger("amqplib").setLevel(logging.ERROR)
+        cls.worker = Worker.managed(cls.hostname, caller=cls)
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.worker.ensure_shutdown()
+
+    def assertWorkerAlive(self, timeout=1):
+        self.assertTrue(self.worker.is_alive)
+
+    def my_response(self, response):
+        return flatten_response(response)[self.worker.hostname]
+

+ 28 - 0
celery/tests/functional/config.py

@@ -0,0 +1,28 @@
+import atexit
+import os
+
+CARROT_BACKEND = os.environ.get("CARROT_BACKEND") or "amqplib"
+
+BROKER_HOST = os.environ.get("BROKER_HOST") or "localhost"
+BROKER_USER = os.environ.get("BROKER_USER") or "guest"
+BROKER_PASSWORD = os.environ.get("BROKER_PASSWORD") or "guest"
+BROKER_VHOST = os.environ.get("BROKER_VHOST") or "/"
+
+
+CELERY_RESULT_BACKEND = "amqp"
+CELERY_SEND_TASK_ERROR_EMAILS = False
+
+CELERY_DEFAULT_QUEUE = "testcelery"
+CELERY_DEFAULT_EXCHANGE = "testcelery"
+CELERY_DEFAULT_ROUTING_KEY = "testcelery"
+CELERY_QUEUES = {"testcelery": {"binding_key": "testcelery"}}
+
+CELERYD_LOG_COLOR = False
+
+CELERY_IMPORTS = ("celery.tests.functional.tasks", )
+
+@atexit.register
+def teardown_testdb():
+    import os
+    if os.path.exists("test.db"):
+        os.remove("test.db")

+ 23 - 0
celery/tests/functional/tasks.py

@@ -0,0 +1,23 @@
+import time
+
+from celery.decorators import task
+from celery.task.sets import subtask
+
+
+@task
+def add(x, y):
+    return x + y
+
+
+@task
+def add_cb(x, y, callback=None):
+    result = x + y
+    if callback:
+        return subtask(callback).apply_async(result)
+    return result
+
+
+@task
+def sleeptask(i):
+    time.sleep(i)
+    return i

+ 2 - 0
celery/tests/functional/test.cfg

@@ -0,0 +1,2 @@
+[nose]
+where = celery/tests/functional

+ 32 - 0
celery/tests/functional/test_basic.py

@@ -0,0 +1,32 @@
+import operator
+import time
+
+from celery.task.control import broadcast
+
+from celery.tests.functional import tasks
+from celery.tests.functional.case import WorkerCase
+
+
+class test_basic(WorkerCase):
+
+    def test_started(self):
+        self.assertWorkerAlive()
+
+    def test_roundtrip_simple_task(self):
+        publisher = tasks.add.get_publisher()
+        results = [(tasks.add.apply_async(i, publisher=publisher), i)
+                        for i in zip(xrange(100), xrange(100))]
+        for result, i in results:
+            self.assertEqual(result.get(timeout=10), operator.add(*i))
+
+    def test_dump_active(self):
+        tasks.sleeptask.delay(2)
+        tasks.sleeptask.delay(2)
+        time.sleep(0.2)
+        r = broadcast("dump_active",
+                           arguments={"safe": True}, reply=True)
+        active = self.my_response(r)
+        self.assertEqual(len(active), 2)
+        self.assertEqual(active[0]["name"], tasks.sleeptask.name)
+        self.assertEqual(active[0]["args"], [2])
+