123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382 |
- import unittest
- from Queue import Queue, Empty
- from datetime import datetime, timedelta
- from multiprocessing import get_logger
- from carrot.connection import BrokerConnection
- from carrot.backends.base import BaseMessage
- from billiard.serialization import pickle
- from celery import conf
- from celery.utils import gen_unique_id, noop
- from celery.worker import WorkController
- from celery.worker.listener import CarrotListener, RUN, CLOSE
- from celery.worker.job import TaskWrapper
- from celery.worker.scheduler import Scheduler
- from celery.decorators import task as task_dec
- from celery.decorators import periodic_task as periodic_task_dec
- from testunits.utils import execute_context
- from testunits.compat import catch_warnings
class PlaceHolder(object):
    """Empty object that tests hang ad-hoc attributes on."""
class MockControlDispatch(object):
    """Control dispatcher stand-in that records the commands it receives."""

    def __init__(self):
        # One list per instance: a class-level list would be shared by every
        # dispatcher, letting commands recorded in one test leak into others.
        self.commands = []

    def dispatch_from_message(self, message):
        """Record the ``"command"`` entry of ``message``.

        The entry is popped from ``message``; ``None`` is recorded when the
        key is absent.
        """
        self.commands.append(message.pop("command", None))
class MockEventDispatcher(object):
    """Event dispatcher stand-in that records sent events and closure."""

    closed = False

    def __init__(self):
        # One list per instance: a class-level list would be shared by every
        # dispatcher, letting events recorded in one test leak into others.
        self.sent = []

    def send(self, event, *args, **kwargs):
        """Record ``event``; any further arguments are ignored."""
        self.sent.append(event)

    def close(self):
        """Mark this dispatcher as closed."""
        self.closed = True
class MockHeart(object):
    """Heart stand-in whose ``closed`` flag is set once ``stop`` is called."""

    # Default for instances that have not been stopped yet.
    closed = False

    def stop(self):
        """Flag this instance as closed."""
        self.closed = True
@task_dec()
def foo_task(x, y, z, **kwargs):
    """Return the product of ``x``, ``y`` and ``z``; extra kwargs are unused."""
    partial = x * y
    return partial * z
@periodic_task_dec(run_every=60)
def foo_periodic_task():
    """Periodic task fixture registered with ``run_every=60``."""
    outcome = "foo"
    return outcome
class MockLogger(object):
    """Logger stand-in: every level method accepts anything and does nothing."""

    def critical(self, *args, **kwargs):
        """Discard a critical-level message."""

    def info(self, *args, **kwargs):
        """Discard an info-level message."""

    def error(self, *args, **kwargs):
        """Discard an error-level message."""

    def debug(self, *args, **kwargs):
        """Discard a debug-level message."""
class MockBackend(object):
    """Messaging backend stand-in that remembers whether ``ack`` was called."""

    # Stays False until the first ack() on the instance.
    _acked = False

    def ack(self, delivery_tag):
        """Note the acknowledgement; ``delivery_tag`` itself is not stored."""
        self._acked = True
class MockPool(object):
    """Pool stand-in whose ``apply_async`` can be told to raise.

    Keyword flags read at construction:
      * ``raise_regular`` -- ``apply_async`` raises ``KeyError``.
      * ``raise_base`` -- ``apply_async`` raises ``KeyboardInterrupt``.
    """

    def __init__(self, *args, **kwargs):
        # Only the two flags are read; everything else is accepted and ignored.
        self.raise_regular = kwargs.get("raise_regular", False)
        self.raise_base = kwargs.get("raise_base", False)

    def apply_async(self, *args, **kwargs):
        """Raise according to the configured flags, otherwise do nothing."""
        # The regular exception takes precedence when both flags are set.
        if self.raise_regular:
            raise KeyError("some exception")
        elif self.raise_base:
            raise KeyboardInterrupt("Ctrl+c")

    def start(self):
        """No-op."""
        return None

    def stop(self):
        """Always return True."""
        return True
class MockController(object):
    """Component stand-in that reports start/stop through simple flags.

    ``w`` is a mapping owned by the caller; ``start`` sets ``w["started"]``
    to True so the caller can observe that the component was started.
    """

    def __init__(self, w, *args, **kwargs):
        self._stopped = False
        self._w = w

    def start(self):
        """Flag the shared mapping as started and clear the stopped flag."""
        self._stopped = False
        self._w["started"] = True

    def stop(self):
        """Set the stopped flag."""
        self._stopped = True
def create_message(backend, **data):
    """Build a ``BaseMessage`` for ``backend`` whose body is pickled ``data``.

    An ``id`` from ``gen_unique_id()`` is filled in when the caller does not
    pass one.
    """
    data.setdefault("id", gen_unique_id())
    payload = pickle.dumps(dict(data))
    return BaseMessage(
        backend,
        body=payload,
        content_type="application/x-python-serialize",
        content_encoding="binary",
    )
class TestCarrotListener(unittest.TestCase):
    """Tests for ``CarrotListener`` message handling and connection upkeep."""

    def setUp(self):
        self.ready_queue = Queue()
        self.eta_schedule = Scheduler(self.ready_queue)
        self.logger = get_logger()
        self.logger.setLevel(0)

    def _listener(self, ready_queue=None):
        # Build a listener with event sending disabled. The fixture ready
        # queue is used unless the test supplies its own.
        if ready_queue is None:
            ready_queue = self.ready_queue
        return CarrotListener(ready_queue, self.eta_schedule, self.logger,
                              send_events=False)

    def test_mainloop(self):
        listener = self._listener()

        class MockConnection(object):

            def drain_events(self):
                return "draining"

        listener.connection = PlaceHolder()
        listener.connection.connection = MockConnection()

        it = listener._mainloop()
        # assertEqual rather than assertTrue(value, "draining"): the second
        # argument of assertTrue is only the failure message, so the yielded
        # value was never actually compared.
        self.assertEqual(next(it), "draining")

        records = {}

        def create_recorder(key):
            def _recorder(*args, **kwargs):
                records[key] = True
            return _recorder

        listener.task_consumer = PlaceHolder()
        listener.task_consumer.iterconsume = create_recorder("consume_tasks")
        listener.broadcast_consumer = PlaceHolder()
        listener.broadcast_consumer.register_callback = create_recorder(
            "broadcast_callback")
        listener.broadcast_consumer.iterconsume = create_recorder(
            "consume_broadcast")
        listener.task_consumer.add_consumer = create_recorder("consumer_add")

        records.clear()
        self.assertEqual(listener._detect_wait_method(), listener._mainloop)
        self.assertTrue(records.get("broadcast_callback"))
        self.assertTrue(records.get("consume_broadcast"))
        self.assertTrue(records.get("consume_tasks"))

        records.clear()
        # Swap in a connection object that has no drain_events attribute.
        listener.connection.connection = PlaceHolder()
        self.assertTrue(
            listener._detect_wait_method() is
            listener.task_consumer.iterconsume)
        self.assertTrue(records.get("consumer_add"))

    def test_connection(self):
        listener = self._listener()

        listener.reset_connection()
        self.assertTrue(isinstance(listener.connection, BrokerConnection))

        listener.stop_consumers()
        self.assertTrue(listener.connection is None)
        self.assertTrue(listener.task_consumer is None)

        listener.reset_connection()
        self.assertTrue(isinstance(listener.connection, BrokerConnection))

        listener.stop()
        listener.close_connection()
        self.assertTrue(listener.connection is None)
        self.assertTrue(listener.task_consumer is None)

    def test_receive_message_control_command(self):
        listener = self._listener()
        backend = MockBackend()
        m = create_message(backend, control={"command": "shutdown"})
        listener.event_dispatcher = MockEventDispatcher()
        listener.control_dispatch = MockControlDispatch()
        listener.receive_message(m.decode(), m)
        self.assertTrue("shutdown" in listener.control_dispatch.commands)

    def test_close_connection(self):
        listener = self._listener()
        listener._state = RUN
        listener.close_connection()

        listener = self._listener()
        eventer = listener.event_dispatcher = MockEventDispatcher()
        heart = listener.heart = MockHeart()
        listener._state = RUN
        listener.stop_consumers()
        self.assertTrue(eventer.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        listener = self._listener()
        backend = MockBackend()
        m = create_message(backend, unknown={"baz": "!!!"})
        listener.event_dispatcher = MockEventDispatcher()
        listener.control_dispatch = MockControlDispatch()

        def with_catch_warnings(log):
            listener.receive_message(m.decode(), m)
            self.assertTrue(log)
            self.assertTrue("unknown message" in log[0].message.args[0])

        context = catch_warnings(record=True)
        execute_context(context, with_catch_warnings)

    def test_receieve_message(self):
        listener = self._listener()
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={})

        listener.event_dispatcher = MockEventDispatcher()
        listener.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertTrue(isinstance(in_bucket, TaskWrapper))
        self.assertEqual(in_bucket.task_name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta_isoformat(self):
        listener = self._listener()
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8], kwargs={})

        listener.event_dispatcher = MockEventDispatcher()
        listener.receive_message(m.decode(), m)

        # The third field of each schedule entry is the task wrapper.
        items = [entry[2] for entry in self.eta_schedule.queue]
        self.assertTrue(
            any(item.task_name == foo_task.name for item in items))

    def test_revoke(self):
        ready_queue = Queue()
        listener = self._listener(ready_queue)
        backend = MockBackend()
        task_id = gen_unique_id()
        c = create_message(backend, control={"command": "revoke",
                                             "task_id": task_id})
        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                           kwargs={}, id=task_id)
        listener.event_dispatcher = MockEventDispatcher()
        listener.receive_message(c.decode(), c)
        from celery.worker.revoke import revoked
        self.assertTrue(task_id in revoked)

        # A task message carrying the revoked id must not be queued.
        listener.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        listener = self._listener()
        backend = MockBackend()
        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})

        listener.event_dispatcher = MockEventDispatcher()
        self.assertFalse(listener.receive_message(m.decode(), m))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        listener = self._listener()
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={},
                           eta=(datetime.now() +
                                timedelta(days=1)).isoformat())

        listener.reset_connection()
        # Reset once more with BROKER_CONNECTION_RETRY switched off, restoring
        # the previous setting even if the reset fails.
        p, conf.BROKER_CONNECTION_RETRY = conf.BROKER_CONNECTION_RETRY, False
        try:
            listener.reset_connection()
        finally:
            conf.BROKER_CONNECTION_RETRY = p
        listener.receive_message(m.decode(), m)

        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 4)
        eta, priority, task, on_accept = in_hold
        self.assertTrue(isinstance(task, TaskWrapper))
        self.assertTrue(callable(on_accept))
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)
class TestWorkController(unittest.TestCase):
    """Tests for ``WorkController`` attributes, task hand-off and start/stop."""

    def setUp(self):
        controller = WorkController(concurrency=1, loglevel=0)
        self.worker = controller
        self.worker.logger = MockLogger()

    def _process_with_pool(self, pool):
        # Shared body of the process_task tests: install ``pool`` on the
        # worker, push one foo_task message through it, then stop the pool.
        worker = self.worker
        worker.pool = pool
        backend = MockBackend()
        message = create_message(backend, task=foo_task.name,
                                 args=[4, 8, 10], kwargs={})
        wrapper = TaskWrapper.from_message(message, message.decode())
        worker.process_task(wrapper)
        worker.pool.stop()

    def test_attrs(self):
        worker = self.worker
        self.assertTrue(isinstance(worker.eta_schedule, Scheduler))
        for attr in ("scheduler", "pool", "listener", "mediator",
                     "components"):
            self.assertTrue(getattr(worker, attr))

    def test_process_task(self):
        self._process_with_pool(MockPool())

    def test_process_task_raise_base(self):
        # The pool raises KeyboardInterrupt from apply_async.
        self._process_with_pool(MockPool(raise_base=True))

    def test_process_task_raise_regular(self):
        # The pool raises KeyError from apply_async.
        self._process_with_pool(MockPool(raise_regular=True))

    def test_start_stop(self):
        worker = self.worker
        states = [{"started": False} for _ in range(4)]
        worker.components = [MockController(state) for state in states]

        worker.start()
        for state in states:
            self.assertTrue(state["started"])
        for component in worker.components:
            self.assertTrue(component._stopped)
|