test_worker.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. import unittest2 as unittest
  2. from Queue import Empty
  3. from datetime import datetime, timedelta
  4. from multiprocessing import get_logger
  5. from carrot.connection import BrokerConnection
  6. from carrot.backends.base import BaseMessage
  7. from billiard.serialization import pickle
  8. from celery import conf
  9. from celery.utils import gen_unique_id
  10. from celery.worker import WorkController
  11. from celery.worker.job import TaskWrapper
  12. from celery.worker.buckets import FastQueue
  13. from celery.worker.listener import CarrotListener, RUN
  14. from celery.worker.scheduler import Scheduler
  15. from celery.decorators import task as task_dec
  16. from celery.decorators import periodic_task as periodic_task_dec
  17. from celery.tests.utils import execute_context
  18. from celery.tests.compat import catch_warnings
  19. class PlaceHolder(object):
  20. pass
  21. class MockControlDispatch(object):
  22. commands = []
  23. def dispatch_from_message(self, message):
  24. self.commands.append(message.pop("command", None))
  25. class MockEventDispatcher(object):
  26. sent = []
  27. closed = False
  28. def send(self, event, *args, **kwargs):
  29. self.sent.append(event)
  30. def close(self):
  31. self.closed = True
  32. class MockHeart(object):
  33. closed = False
  34. def stop(self):
  35. self.closed = True
  36. @task_dec()
  37. def foo_task(x, y, z, **kwargs):
  38. return x * y * z
  39. @periodic_task_dec(run_every=60)
  40. def foo_periodic_task():
  41. return "foo"
  42. class MockLogger(object):
  43. def critical(self, *args, **kwargs):
  44. pass
  45. def info(self, *args, **kwargs):
  46. pass
  47. def error(self, *args, **kwargs):
  48. pass
  49. def debug(self, *args, **kwargs):
  50. pass
  51. class MockBackend(object):
  52. _acked = False
  53. def ack(self, delivery_tag):
  54. self._acked = True
  55. class MockPool(object):
  56. def __init__(self, *args, **kwargs):
  57. self.raise_regular = kwargs.get("raise_regular", False)
  58. self.raise_base = kwargs.get("raise_base", False)
  59. def apply_async(self, *args, **kwargs):
  60. if self.raise_regular:
  61. raise KeyError("some exception")
  62. if self.raise_base:
  63. raise KeyboardInterrupt("Ctrl+c")
  64. def start(self):
  65. pass
  66. def stop(self):
  67. pass
  68. return True
  69. class MockController(object):
  70. def __init__(self, w, *args, **kwargs):
  71. self._w = w
  72. self._stopped = False
  73. def start(self):
  74. self._w["started"] = True
  75. self._stopped = False
  76. def stop(self):
  77. self._stopped = True
  78. def create_message(backend, **data):
  79. data.setdefault("id", gen_unique_id())
  80. return BaseMessage(backend, body=pickle.dumps(dict(**data)),
  81. content_type="application/x-python-serialize",
  82. content_encoding="binary")
  83. class TestCarrotListener(unittest.TestCase):
  84. def setUp(self):
  85. self.ready_queue = FastQueue()
  86. self.eta_schedule = Scheduler(self.ready_queue)
  87. self.logger = get_logger()
  88. self.logger.setLevel(0)
  89. def test_mainloop(self):
  90. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  91. send_events=False)
  92. class MockConnection(object):
  93. def drain_events(self):
  94. return "draining"
  95. l.connection = MockConnection()
  96. l.connection.connection = MockConnection()
  97. it = l._mainloop()
  98. self.assertTrue(it.next(), "draining")
  99. records = {}
  100. def create_recorder(key):
  101. def _recorder(*args, **kwargs):
  102. records[key] = True
  103. return _recorder
  104. l.task_consumer = PlaceHolder()
  105. l.task_consumer.iterconsume = create_recorder("consume_tasks")
  106. l.broadcast_consumer = PlaceHolder()
  107. l.broadcast_consumer.register_callback = create_recorder(
  108. "broadcast_callback")
  109. l.broadcast_consumer.iterconsume = create_recorder(
  110. "consume_broadcast")
  111. l.task_consumer.add_consumer = create_recorder("consumer_add")
  112. records.clear()
  113. self.assertEqual(l._detect_wait_method(), l._mainloop)
  114. for record in ("broadcast_callback", "consume_broadcast",
  115. "consume_tasks"):
  116. self.assertTrue(records.get(record))
  117. records.clear()
  118. l.connection.connection = PlaceHolder()
  119. self.assertIs(l._detect_wait_method(), l.task_consumer.iterconsume)
  120. self.assertTrue(records.get("consumer_add"))
  121. def test_connection(self):
  122. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  123. send_events=False)
  124. l.reset_connection()
  125. self.assertIsInstance(l.connection, BrokerConnection)
  126. l.stop_consumers()
  127. self.assertIsNone(l.connection)
  128. self.assertIsNone(l.task_consumer)
  129. l.reset_connection()
  130. self.assertIsInstance(l.connection, BrokerConnection)
  131. l.stop()
  132. l.close_connection()
  133. self.assertIsNone(l.connection)
  134. self.assertIsNone(l.task_consumer)
  135. def test_receive_message_control_command(self):
  136. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  137. send_events=False)
  138. backend = MockBackend()
  139. m = create_message(backend, control={"command": "shutdown"})
  140. l.event_dispatcher = MockEventDispatcher()
  141. l.control_dispatch = MockControlDispatch()
  142. l.receive_message(m.decode(), m)
  143. self.assertIn("shutdown", l.control_dispatch.commands)
  144. def test_close_connection(self):
  145. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  146. send_events=False)
  147. l._state = RUN
  148. l.close_connection()
  149. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  150. send_events=False)
  151. eventer = l.event_dispatcher = MockEventDispatcher()
  152. heart = l.heart = MockHeart()
  153. l._state = RUN
  154. l.stop_consumers()
  155. self.assertTrue(eventer.closed)
  156. self.assertTrue(heart.closed)
  157. def test_receive_message_unknown(self):
  158. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  159. send_events=False)
  160. backend = MockBackend()
  161. m = create_message(backend, unknown={"baz": "!!!"})
  162. l.event_dispatcher = MockEventDispatcher()
  163. l.control_dispatch = MockControlDispatch()
  164. def with_catch_warnings(log):
  165. l.receive_message(m.decode(), m)
  166. self.assertTrue(log)
  167. self.assertIn("unknown message", log[0].message.args[0])
  168. context = catch_warnings(record=True)
  169. execute_context(context, with_catch_warnings)
  170. def test_receieve_message(self):
  171. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  172. send_events=False)
  173. backend = MockBackend()
  174. m = create_message(backend, task=foo_task.name,
  175. args=[2, 4, 8], kwargs={})
  176. l.event_dispatcher = MockEventDispatcher()
  177. l.receive_message(m.decode(), m)
  178. in_bucket = self.ready_queue.get_nowait()
  179. self.assertIsInstance(in_bucket, TaskWrapper)
  180. self.assertEqual(in_bucket.task_name, foo_task.name)
  181. self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
  182. self.assertTrue(self.eta_schedule.empty())
  183. def test_receieve_message_eta_isoformat(self):
  184. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  185. send_events=False)
  186. backend = MockBackend()
  187. m = create_message(backend, task=foo_task.name,
  188. eta=datetime.now().isoformat(),
  189. args=[2, 4, 8], kwargs={})
  190. l.event_dispatcher = MockEventDispatcher()
  191. l.receive_message(m.decode(), m)
  192. items = [entry[2] for entry in self.eta_schedule.queue]
  193. found = 0
  194. for item in items:
  195. if item.task_name == foo_task.name:
  196. found = True
  197. self.assertTrue(found)
  198. def test_revoke(self):
  199. ready_queue = FastQueue()
  200. l = CarrotListener(ready_queue, self.eta_schedule, self.logger,
  201. send_events=False)
  202. backend = MockBackend()
  203. id = gen_unique_id()
  204. c = create_message(backend, control={"command": "revoke",
  205. "task_id": id})
  206. t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
  207. kwargs={}, id=id)
  208. l.event_dispatcher = MockEventDispatcher()
  209. l.receive_message(c.decode(), c)
  210. from celery.worker.revoke import revoked
  211. self.assertIn(id, revoked)
  212. l.receive_message(t.decode(), t)
  213. self.assertTrue(ready_queue.empty())
  214. def test_receieve_message_not_registered(self):
  215. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  216. send_events=False)
  217. backend = MockBackend()
  218. m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
  219. l.event_dispatcher = MockEventDispatcher()
  220. self.assertFalse(l.receive_message(m.decode(), m))
  221. self.assertRaises(Empty, self.ready_queue.get_nowait)
  222. self.assertTrue(self.eta_schedule.empty())
  223. def test_receieve_message_eta(self):
  224. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  225. send_events=False)
  226. backend = MockBackend()
  227. m = create_message(backend, task=foo_task.name,
  228. args=[2, 4, 8], kwargs={},
  229. eta=(datetime.now() +
  230. timedelta(days=1)).isoformat())
  231. l.reset_connection()
  232. p, conf.BROKER_CONNECTION_RETRY = conf.BROKER_CONNECTION_RETRY, False
  233. try:
  234. l.reset_connection()
  235. finally:
  236. conf.BROKER_CONNECTION_RETRY = p
  237. l.receive_message(m.decode(), m)
  238. in_hold = self.eta_schedule.queue[0]
  239. self.assertEqual(len(in_hold), 4)
  240. eta, priority, task, on_accept = in_hold
  241. self.assertIsInstance(task, TaskWrapper)
  242. self.assertTrue(callable(on_accept))
  243. self.assertEqual(task.task_name, foo_task.name)
  244. self.assertEqual(task.execute(), 2 * 4 * 8)
  245. self.assertRaises(Empty, self.ready_queue.get_nowait)
  246. class TestWorkController(unittest.TestCase):
  247. def setUp(self):
  248. self.worker = WorkController(concurrency=1, loglevel=0)
  249. self.worker.logger = MockLogger()
  250. def test_attrs(self):
  251. worker = self.worker
  252. self.assertIsInstance(worker.eta_schedule, Scheduler)
  253. self.assertTrue(worker.scheduler)
  254. self.assertTrue(worker.pool)
  255. self.assertTrue(worker.listener)
  256. self.assertTrue(worker.mediator)
  257. self.assertTrue(worker.components)
  258. def test_process_task(self):
  259. worker = self.worker
  260. worker.pool = MockPool()
  261. backend = MockBackend()
  262. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  263. kwargs={})
  264. task = TaskWrapper.from_message(m, m.decode())
  265. worker.process_task(task)
  266. worker.pool.stop()
  267. def test_process_task_raise_base(self):
  268. worker = self.worker
  269. worker.pool = MockPool(raise_base=True)
  270. backend = MockBackend()
  271. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  272. kwargs={})
  273. task = TaskWrapper.from_message(m, m.decode())
  274. worker.process_task(task)
  275. worker.pool.stop()
  276. def test_process_task_raise_regular(self):
  277. worker = self.worker
  278. worker.pool = MockPool(raise_regular=True)
  279. backend = MockBackend()
  280. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  281. kwargs={})
  282. task = TaskWrapper.from_message(m, m.decode())
  283. worker.process_task(task)
  284. worker.pool.stop()
  285. def test_start_stop(self):
  286. worker = self.worker
  287. w1 = {"started": False}
  288. w2 = {"started": False}
  289. w3 = {"started": False}
  290. w4 = {"started": False}
  291. worker.components = [MockController(w1), MockController(w2),
  292. MockController(w3), MockController(w4)]
  293. worker.start()
  294. for w in (w1, w2, w3, w4):
  295. self.assertTrue(w["started"])
  296. for component in worker.components:
  297. self.assertTrue(component._stopped)