test_worker.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. import unittest
  2. from Queue import Queue, Empty
  3. from datetime import datetime, timedelta
  4. from multiprocessing import get_logger
  5. from carrot.connection import BrokerConnection
  6. from carrot.backends.base import BaseMessage
  7. from billiard.serialization import pickle
  8. from celery import conf
  9. from celery.utils import gen_unique_id, noop
  10. from celery.worker import WorkController
  11. from celery.worker.listener import CarrotListener, RUN, CLOSE
  12. from celery.worker.job import TaskWrapper
  13. from celery.worker.scheduler import Scheduler
  14. from celery.decorators import task as task_dec
  15. from celery.decorators import periodic_task as periodic_task_dec
  16. from testunits.utils import execute_context
  17. from testunits.compat import catch_warnings
  class PlaceHolder(object):
      """Empty attribute bag.

      Tests hang arbitrary attributes on instances (for example
      ``connection`` or ``task_consumer``) to stand in for real objects.
      """
  class MockControlDispatch(object):
      """Stand-in for the listener's control dispatcher.

      Instead of executing control commands it records the ``"command"``
      entry of every message handed to ``dispatch_from_message``.
      """

      def __init__(self):
          # Per-instance list: a class-level list would be shared by every
          # instance, so commands recorded in one test would leak into others.
          self.commands = []

      def dispatch_from_message(self, message):
          """Pop ``"command"`` from *message* and record it.

          The key is removed from *message*. ``None`` is recorded when the
          key is absent.
          """
          self.commands.append(message.pop("command", None))
  class MockEventDispatcher(object):
      """Stand-in event dispatcher that records the events sent through it."""

      closed = False

      def __init__(self):
          # Per-instance list so events sent in one test are not visible to
          # dispatchers created in another.
          self.sent = []

      def send(self, event, *args, **kwargs):
          """Record *event*. Any extra arguments are accepted and ignored."""
          self.sent.append(event)

      def close(self):
          """Mark this dispatcher as closed."""
          self.closed = True
  class MockHeart(object):
      """Heartbeat stand-in; ``stop()`` merely flags it as closed."""

      # Class-level default; stop() shadows it with an instance attribute.
      closed = False

      def stop(self):
          """Flag this heart as stopped."""
          self.closed = True
  @task_dec()
  def foo_task(x, y, z, **kwargs):
      """Task fixture returning ``x * y * z``.

      Extra keyword arguments are accepted and ignored.
      """
      product = x * y * z
      return product
  @periodic_task_dec(run_every=60)
  def foo_periodic_task():
      """Periodic task fixture (``run_every=60``) that returns ``"foo"``."""
      result = "foo"
      return result
  class MockLogger(object):
      """Logger stand-in that silently discards every message.

      ``critical``, ``info``, ``error`` and ``debug`` accept any arguments
      and return ``None``.
      """

      def _discard(self, *args, **kwargs):
          """Ignore the log call."""

      critical = info = error = debug = _discard
  class MockBackend(object):
      """Fake message backend that remembers whether ``ack`` was called."""

      # Class-level default; ack() shadows it on the instance.
      _acked = False

      def ack(self, delivery_tag):
          """Record the acknowledgement. *delivery_tag* is ignored."""
          self._acked = True
  class MockPool(object):
      """Fake worker pool whose ``apply_async`` can be told to raise.

      Keyword options (both default to ``False``):
        ``raise_regular`` -- ``apply_async`` raises ``KeyError``.
        ``raise_base``    -- ``apply_async`` raises ``KeyboardInterrupt``.
      If both are set, ``raise_regular`` wins because it is checked first.
      Positional arguments are accepted and ignored.
      """

      def __init__(self, *args, **kwargs):
          option = kwargs.get
          self.raise_regular = option("raise_regular", False)
          self.raise_base = option("raise_base", False)

      def apply_async(self, *args, **kwargs):
          """Raise the configured exception, otherwise return ``None``."""
          if self.raise_regular:
              raise KeyError("some exception")
          elif self.raise_base:
              raise KeyboardInterrupt("Ctrl+c")

      def start(self):
          """Do nothing."""

      def stop(self):
          """Do nothing and report success with ``True``."""
          return True
  class MockController(object):
      """Fake worker component that records start and stop calls.

      *w* is a mapping shared with the test; ``start()`` writes
      ``w["started"] = True``. Extra arguments are accepted and ignored.
      """

      def __init__(self, w, *args, **kwargs):
          self._w, self._stopped = w, False

      def start(self):
          """Mark the shared mapping as started and clear the stopped flag."""
          shared = self._w
          shared["started"] = True
          self._stopped = False

      def stop(self):
          """Set the stopped flag."""
          self._stopped = True
  def create_message(backend, **data):
      """Build a pickle-serialized ``BaseMessage`` whose body is *data*.

      ``data["id"]`` defaults to a freshly generated unique id.
      ``gen_unique_id()`` is called even when the caller passes ``id``.
      """
      fallback_id = gen_unique_id()
      data.setdefault("id", fallback_id)
      payload = pickle.dumps(dict(**data))
      return BaseMessage(backend, body=payload,
                         content_type="application/x-python-serialize",
                         content_encoding="binary")
  class TestCarrotListener(unittest.TestCase):
      """Tests for ``CarrotListener``: main loop, connections and messages."""

      def setUp(self):
          self.ready_queue = Queue()
          self.eta_schedule = Scheduler(self.ready_queue)
          self.logger = get_logger()
          self.logger.setLevel(0)

      def _create_listener(self, ready_queue=None):
          """Return a ``CarrotListener`` with events disabled.

          Uses ``self.ready_queue`` unless *ready_queue* is given.
          """
          if ready_queue is None:
              ready_queue = self.ready_queue
          return CarrotListener(ready_queue, self.eta_schedule, self.logger,
                                send_events=False)

      def test_mainloop(self):
          listener = self._create_listener()

          class MockConnection(object):
              def drain_events(self):
                  return "draining"

          listener.connection = PlaceHolder()
          listener.connection.connection = MockConnection()
          it = listener._mainloop()
          # assertEqual, not assertTrue(value, "draining"): the latter treats
          # the string as the failure message and passes for any truthy value.
          self.assertEqual(it.next(), "draining")

          records = {}

          def create_recorder(key):
              """Return a callable that sets ``records[key]`` when invoked."""
              def _recorder(*args, **kwargs):
                  records[key] = True
              return _recorder

          listener.task_consumer = PlaceHolder()
          listener.task_consumer.iterconsume = create_recorder("consume_tasks")
          listener.broadcast_consumer = PlaceHolder()
          listener.broadcast_consumer.register_callback = create_recorder(
              "broadcast_callback")
          listener.broadcast_consumer.iterconsume = create_recorder(
              "consume_broadcast")
          listener.task_consumer.add_consumer = create_recorder("consumer_add")

          # The connection provides drain_events: expect the main loop.
          records.clear()
          self.assertEqual(listener._detect_wait_method(), listener._mainloop)
          self.assertTrue(records.get("broadcast_callback"))
          self.assertTrue(records.get("consume_broadcast"))
          self.assertTrue(records.get("consume_tasks"))

          # The connection lacks drain_events: expect the task consumer's
          # iterconsume.
          records.clear()
          listener.connection.connection = PlaceHolder()
          self.assertTrue(listener._detect_wait_method() is
                          listener.task_consumer.iterconsume)
          self.assertTrue(records.get("consumer_add"))

      def test_connection(self):
          listener = self._create_listener()

          listener.reset_connection()
          self.assertTrue(isinstance(listener.connection, BrokerConnection))

          listener.stop_consumers()
          self.assertTrue(listener.connection is None)
          self.assertTrue(listener.task_consumer is None)

          listener.reset_connection()
          self.assertTrue(isinstance(listener.connection, BrokerConnection))

          listener.stop()
          listener.close_connection()
          self.assertTrue(listener.connection is None)
          self.assertTrue(listener.task_consumer is None)

      def test_receive_message_control_command(self):
          listener = self._create_listener()
          backend = MockBackend()
          m = create_message(backend, control={"command": "shutdown"})
          listener.event_dispatcher = MockEventDispatcher()
          listener.control_dispatch = MockControlDispatch()
          listener.receive_message(m.decode(), m)
          self.assertTrue("shutdown" in listener.control_dispatch.commands)

      def test_close_connection(self):
          # close_connection() on a running listener must not raise.
          listener = self._create_listener()
          listener._state = RUN
          listener.close_connection()

          listener = self._create_listener()
          eventer = listener.event_dispatcher = MockEventDispatcher()
          heart = listener.heart = MockHeart()
          listener._state = RUN
          listener.stop_consumers()
          self.assertTrue(eventer.closed)
          self.assertTrue(heart.closed)

      def test_receive_message_unknown(self):
          listener = self._create_listener()
          backend = MockBackend()
          m = create_message(backend, unknown={"baz": "!!!"})
          listener.event_dispatcher = MockEventDispatcher()
          listener.control_dispatch = MockControlDispatch()

          def with_catch_warnings(log):
              listener.receive_message(m.decode(), m)
              self.assertTrue(log)
              self.assertTrue("unknown message" in log[0].message.args[0])

          context = catch_warnings(record=True)
          execute_context(context, with_catch_warnings)

      def test_receieve_message(self):
          listener = self._create_listener()
          backend = MockBackend()
          m = create_message(backend, task=foo_task.name,
                             args=[2, 4, 8], kwargs={})
          listener.event_dispatcher = MockEventDispatcher()
          listener.receive_message(m.decode(), m)

          in_bucket = self.ready_queue.get_nowait()
          self.assertTrue(isinstance(in_bucket, TaskWrapper))
          self.assertEqual(in_bucket.task_name, foo_task.name)
          self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
          self.assertTrue(self.eta_schedule.empty())

      def test_receieve_message_eta_isoformat(self):
          listener = self._create_listener()
          backend = MockBackend()
          m = create_message(backend, task=foo_task.name,
                             eta=datetime.now().isoformat(),
                             args=[2, 4, 8], kwargs={})
          listener.event_dispatcher = MockEventDispatcher()
          listener.receive_message(m.decode(), m)

          # Schedule entries are (eta, priority, task, on_accept) tuples.
          found = any(entry[2].task_name == foo_task.name
                      for entry in self.eta_schedule.queue)
          self.assertTrue(found)

      def test_revoke(self):
          ready_queue = Queue()
          listener = self._create_listener(ready_queue)
          backend = MockBackend()
          task_id = gen_unique_id()
          c = create_message(backend, control={"command": "revoke",
                                               "task_id": task_id})
          t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                             kwargs={}, id=task_id)
          listener.event_dispatcher = MockEventDispatcher()

          listener.receive_message(c.decode(), c)
          from celery.worker.revoke import revoked
          self.assertTrue(task_id in revoked)

          # A task with a revoked id must not reach the ready queue.
          listener.receive_message(t.decode(), t)
          self.assertTrue(ready_queue.empty())

      def test_receieve_message_not_registered(self):
          listener = self._create_listener()
          backend = MockBackend()
          m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
          listener.event_dispatcher = MockEventDispatcher()
          self.assertFalse(listener.receive_message(m.decode(), m))
          self.assertRaises(Empty, self.ready_queue.get_nowait)
          self.assertTrue(self.eta_schedule.empty())

      def test_receieve_message_eta(self):
          listener = self._create_listener()
          backend = MockBackend()
          m = create_message(backend, task=foo_task.name,
                             args=[2, 4, 8], kwargs={},
                             eta=(datetime.now() +
                                  timedelta(days=1)).isoformat())
          listener.reset_connection()
          # Reconnect once more with broker connection retries disabled,
          # restoring the global setting even if reset_connection() raises.
          prev_retry = conf.BROKER_CONNECTION_RETRY
          conf.BROKER_CONNECTION_RETRY = False
          try:
              listener.reset_connection()
          finally:
              conf.BROKER_CONNECTION_RETRY = prev_retry

          listener.receive_message(m.decode(), m)
          in_hold = self.eta_schedule.queue[0]
          self.assertEqual(len(in_hold), 4)
          eta, priority, task, on_accept = in_hold
          self.assertTrue(isinstance(task, TaskWrapper))
          self.assertTrue(callable(on_accept))
          self.assertEqual(task.task_name, foo_task.name)
          self.assertEqual(task.execute(), 2 * 4 * 8)
          self.assertRaises(Empty, self.ready_queue.get_nowait)
  class TestWorkController(unittest.TestCase):
      """Tests for ``WorkController`` attributes, task dispatch and start/stop."""

      def setUp(self):
          self.worker = WorkController(concurrency=1, loglevel=0)
          self.worker.logger = MockLogger()

      def _foo_task_wrapper(self):
          """Build a ``TaskWrapper`` for a ``foo_task(4, 8, 10)`` message."""
          message = create_message(MockBackend(), task=foo_task.name,
                                   args=[4, 8, 10], kwargs={})
          return TaskWrapper.from_message(message, message.decode())

      def _process_with_pool(self, pool):
          """Install *pool* on the worker, process one task, then stop the pool."""
          worker = self.worker
          worker.pool = pool
          worker.process_task(self._foo_task_wrapper())
          worker.pool.stop()

      def test_attrs(self):
          worker = self.worker
          self.assertTrue(isinstance(worker.eta_schedule, Scheduler))
          for attr in ("scheduler", "pool", "listener", "mediator",
                       "components"):
              self.assertTrue(getattr(worker, attr))

      def test_process_task(self):
          self._process_with_pool(MockPool())

      def test_process_task_raise_base(self):
          self._process_with_pool(MockPool(raise_base=True))

      def test_process_task_raise_regular(self):
          self._process_with_pool(MockPool(raise_regular=True))

      def test_start_stop(self):
          worker = self.worker
          flags = [{"started": False} for _ in range(4)]
          worker.components = [MockController(flag) for flag in flags]
          worker.start()
          for flag in flags:
              self.assertTrue(flag["started"])
          for component in worker.components:
              self.assertTrue(component._stopped)