# test_worker.py (19 KB)


  1. import socket
  2. import unittest2 as unittest
  3. from datetime import datetime, timedelta
  4. from Queue import Empty
  5. from carrot.backends.base import BaseMessage
  6. from carrot.connection import BrokerConnection
  7. from celery.utils.timer2 import Timer
  8. from celery import conf
  9. from celery import log
  10. from celery.decorators import task as task_dec
  11. from celery.decorators import periodic_task as periodic_task_dec
  12. from celery.serialization import pickle
  13. from celery.utils import gen_unique_id
  14. from celery.worker import WorkController
  15. from celery.worker.buckets import FastQueue
  16. from celery.worker.job import TaskRequest
  17. from celery.worker import listener
  18. from celery.worker.listener import CarrotListener, QoS, RUN
  19. from celery.tests.compat import catch_warnings
  20. from celery.tests.utils import execute_context
  21. class PlaceHolder(object):
  22. pass
  23. class MyCarrotListener(CarrotListener):
  24. def restart_heartbeat(self):
  25. self.heart = None
  26. class MockControlDispatch(object):
  27. commands = []
  28. def dispatch_from_message(self, message):
  29. self.commands.append(message.pop("command", None))
  30. class MockEventDispatcher(object):
  31. sent = []
  32. closed = False
  33. flushed = False
  34. def send(self, event, *args, **kwargs):
  35. self.sent.append(event)
  36. def close(self):
  37. self.closed = True
  38. def flush(self):
  39. self.flushed = True
  40. class MockHeart(object):
  41. closed = False
  42. def stop(self):
  43. self.closed = True
  44. @task_dec()
  45. def foo_task(x, y, z, **kwargs):
  46. return x * y * z
  47. @periodic_task_dec(run_every=60)
  48. def foo_periodic_task():
  49. return "foo"
  50. class MockLogger(object):
  51. def __init__(self):
  52. self.logged = []
  53. def critical(self, msg, *args, **kwargs):
  54. self.logged.append(msg)
  55. def info(self, msg, *args, **kwargs):
  56. self.logged.append(msg)
  57. def error(self, msg, *args, **kwargs):
  58. self.logged.append(msg)
  59. def debug(self, msg, *args, **kwargs):
  60. self.logged.append(msg)
  61. class MockBackend(object):
  62. _acked = False
  63. def ack(self, delivery_tag):
  64. self._acked = True
  65. class MockPool(object):
  66. _terminated = False
  67. _stopped = False
  68. def __init__(self, *args, **kwargs):
  69. self.raise_regular = kwargs.get("raise_regular", False)
  70. self.raise_base = kwargs.get("raise_base", False)
  71. def apply_async(self, *args, **kwargs):
  72. if self.raise_regular:
  73. raise KeyError("some exception")
  74. if self.raise_base:
  75. raise KeyboardInterrupt("Ctrl+c")
  76. def start(self):
  77. pass
  78. def stop(self):
  79. self._stopped = True
  80. return True
  81. def terminate(self):
  82. self._terminated = True
  83. self.stop()
  84. class MockController(object):
  85. def __init__(self, w, *args, **kwargs):
  86. self._w = w
  87. self._stopped = False
  88. def start(self):
  89. self._w["started"] = True
  90. self._stopped = False
  91. def stop(self):
  92. self._stopped = True
  93. def create_message(backend, **data):
  94. data.setdefault("id", gen_unique_id())
  95. return BaseMessage(backend, body=pickle.dumps(dict(**data)),
  96. content_type="application/x-python-serialize",
  97. content_encoding="binary")
  98. class test_QoS(unittest.TestCase):
  99. class MockConsumer(object):
  100. prefetch_count = 0
  101. def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
  102. self.prefetch_count = prefetch_count
  103. def test_decrement(self):
  104. consumer = self.MockConsumer()
  105. qos = QoS(consumer, 10, log.get_default_logger())
  106. qos.update()
  107. self.assertEqual(int(qos.value), 10)
  108. self.assertEqual(consumer.prefetch_count, 10)
  109. qos.decrement()
  110. self.assertEqual(int(qos.value), 9)
  111. self.assertEqual(consumer.prefetch_count, 9)
  112. qos.decrement_eventually()
  113. self.assertEqual(int(qos.value), 8)
  114. self.assertEqual(consumer.prefetch_count, 9)
  115. class test_CarrotListener(unittest.TestCase):
  116. def setUp(self):
  117. self.ready_queue = FastQueue()
  118. self.eta_schedule = Timer()
  119. self.logger = log.get_default_logger()
  120. self.logger.setLevel(0)
  121. def tearDown(self):
  122. self.eta_schedule.stop()
  123. def test_mainloop(self):
  124. l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  125. send_events=False)
  126. class MockConnection(object):
  127. def drain_events(self):
  128. return "draining"
  129. l.connection = MockConnection()
  130. l.connection.connection = MockConnection()
  131. it = l._mainloop()
  132. self.assertTrue(it.next(), "draining")
  133. records = {}
  134. def create_recorder(key):
  135. def _recorder(*args, **kwargs):
  136. records[key] = True
  137. return _recorder
  138. l.task_consumer = PlaceHolder()
  139. l.task_consumer.iterconsume = create_recorder("consume_tasks")
  140. l.broadcast_consumer = PlaceHolder()
  141. l.broadcast_consumer.register_callback = create_recorder(
  142. "broadcast_callback")
  143. l.broadcast_consumer.iterconsume = create_recorder(
  144. "consume_broadcast")
  145. l.task_consumer.add_consumer = create_recorder("consumer_add")
  146. records.clear()
  147. self.assertEqual(l._detect_wait_method(), l._mainloop)
  148. for record in ("broadcast_callback", "consume_broadcast",
  149. "consume_tasks"):
  150. self.assertTrue(records.get(record))
  151. records.clear()
  152. l.connection.connection = PlaceHolder()
  153. self.assertIs(l._detect_wait_method(), l.task_consumer.iterconsume)
  154. self.assertTrue(records.get("consumer_add"))
  155. def test_connection(self):
  156. l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  157. send_events=False)
  158. l.reset_connection()
  159. self.assertIsInstance(l.connection, BrokerConnection)
  160. l.stop_consumers()
  161. self.assertIsNone(l.connection)
  162. self.assertIsNone(l.task_consumer)
  163. l.reset_connection()
  164. self.assertIsInstance(l.connection, BrokerConnection)
  165. l.stop_consumers()
  166. l.stop()
  167. l.close_connection()
  168. self.assertIsNone(l.connection)
  169. self.assertIsNone(l.task_consumer)
  170. def test_receive_message_control_command(self):
  171. l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  172. send_events=False)
  173. backend = MockBackend()
  174. m = create_message(backend, control={"command": "shutdown"})
  175. l.event_dispatcher = MockEventDispatcher()
  176. l.control_dispatch = MockControlDispatch()
  177. l.receive_message(m.decode(), m)
  178. self.assertIn("shutdown", l.control_dispatch.commands)
  179. def test_close_connection(self):
  180. l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  181. send_events=False)
  182. l._state = RUN
  183. l.close_connection()
  184. l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  185. send_events=False)
  186. eventer = l.event_dispatcher = MockEventDispatcher()
  187. heart = l.heart = MockHeart()
  188. l._state = RUN
  189. l.stop_consumers()
  190. self.assertTrue(eventer.closed)
  191. self.assertTrue(heart.closed)
  192. def test_receive_message_unknown(self):
  193. l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  194. send_events=False)
  195. backend = MockBackend()
  196. m = create_message(backend, unknown={"baz": "!!!"})
  197. l.event_dispatcher = MockEventDispatcher()
  198. l.control_dispatch = MockControlDispatch()
  199. def with_catch_warnings(log):
  200. l.receive_message(m.decode(), m)
  201. self.assertTrue(log)
  202. self.assertIn("unknown message", log[0].message.args[0])
  203. context = catch_warnings(record=True)
  204. execute_context(context, with_catch_warnings)
  205. def test_receive_message_eta_OverflowError(self):
  206. l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  207. send_events=False)
  208. backend = MockBackend()
  209. called = [False]
  210. def to_timestamp(d):
  211. called[0] = True
  212. raise OverflowError()
  213. m = create_message(backend, task=foo_task.name,
  214. args=("2, 2"),
  215. kwargs={},
  216. eta=datetime.now().isoformat())
  217. l.event_dispatcher = MockEventDispatcher()
  218. l.control_dispatch = MockControlDispatch()
  219. prev, listener.to_timestamp = listener.to_timestamp, to_timestamp
  220. try:
  221. l.receive_message(m.decode(), m)
  222. self.assertTrue(m.acknowledged)
  223. self.assertTrue(called[0])
  224. finally:
  225. listener.to_timestamp = prev
  226. def test_receive_message_InvalidTaskError(self):
  227. logger = MockLogger()
  228. l = MyCarrotListener(self.ready_queue, self.eta_schedule, logger,
  229. send_events=False)
  230. backend = MockBackend()
  231. m = create_message(backend, task=foo_task.name,
  232. args=(1, 2), kwargs="foobarbaz", id=1)
  233. l.event_dispatcher = MockEventDispatcher()
  234. l.control_dispatch = MockControlDispatch()
  235. l.receive_message(m.decode(), m)
  236. self.assertIn("Invalid task ignored", logger.logged[0])
  237. def test_on_decode_error(self):
  238. logger = MockLogger()
  239. l = MyCarrotListener(self.ready_queue, self.eta_schedule, logger,
  240. send_events=False)
  241. class MockMessage(object):
  242. content_type = "application/x-msgpack"
  243. content_encoding = "binary"
  244. body = "foobarbaz"
  245. acked = False
  246. def ack(self):
  247. self.acked = True
  248. message = MockMessage()
  249. l.on_decode_error(message, KeyError("foo"))
  250. self.assertTrue(message.acked)
  251. self.assertIn("Message decoding error", logger.logged[0])
  252. def test_receieve_message(self):
  253. l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  254. send_events=False)
  255. backend = MockBackend()
  256. m = create_message(backend, task=foo_task.name,
  257. args=[2, 4, 8], kwargs={})
  258. l.event_dispatcher = MockEventDispatcher()
  259. l.receive_message(m.decode(), m)
  260. in_bucket = self.ready_queue.get_nowait()
  261. self.assertIsInstance(in_bucket, TaskRequest)
  262. self.assertEqual(in_bucket.task_name, foo_task.name)
  263. self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
  264. self.assertTrue(self.eta_schedule.empty())
  265. def test_receieve_message_eta_isoformat(self):
  266. class MockConsumer(object):
  267. prefetch_count_incremented = False
  268. def qos(self, **kwargs):
  269. self.prefetch_count_incremented = True
  270. l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  271. send_events=False)
  272. backend = MockBackend()
  273. m = create_message(backend, task=foo_task.name,
  274. eta=datetime.now().isoformat(),
  275. args=[2, 4, 8], kwargs={})
  276. l.task_consumer = MockConsumer()
  277. l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
  278. l.event_dispatcher = MockEventDispatcher()
  279. l.receive_message(m.decode(), m)
  280. l.eta_schedule.stop()
  281. items = [entry[2] for entry in self.eta_schedule.queue]
  282. found = 0
  283. for item in items:
  284. if item.args[0].task_name == foo_task.name:
  285. found = True
  286. self.assertTrue(found)
  287. self.assertTrue(l.task_consumer.prefetch_count_incremented)
  288. l.eta_schedule.stop()
  289. def test_revoke(self):
  290. ready_queue = FastQueue()
  291. l = MyCarrotListener(ready_queue, self.eta_schedule, self.logger,
  292. send_events=False)
  293. backend = MockBackend()
  294. id = gen_unique_id()
  295. c = create_message(backend, control={"command": "revoke",
  296. "task_id": id})
  297. t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
  298. kwargs={}, id=id)
  299. l.event_dispatcher = MockEventDispatcher()
  300. l.receive_message(c.decode(), c)
  301. from celery.worker.state import revoked
  302. self.assertIn(id, revoked)
  303. l.receive_message(t.decode(), t)
  304. self.assertTrue(ready_queue.empty())
  305. def test_receieve_message_not_registered(self):
  306. l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  307. send_events=False)
  308. backend = MockBackend()
  309. m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
  310. l.event_dispatcher = MockEventDispatcher()
  311. self.assertFalse(l.receive_message(m.decode(), m))
  312. self.assertRaises(Empty, self.ready_queue.get_nowait)
  313. self.assertTrue(self.eta_schedule.empty())
  314. def test_receieve_message_eta(self):
  315. l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  316. send_events=False)
  317. dispatcher = l.event_dispatcher = MockEventDispatcher()
  318. backend = MockBackend()
  319. m = create_message(backend, task=foo_task.name,
  320. args=[2, 4, 8], kwargs={},
  321. eta=(datetime.now() +
  322. timedelta(days=1)).isoformat())
  323. l.reset_connection()
  324. p, conf.BROKER_CONNECTION_RETRY = conf.BROKER_CONNECTION_RETRY, False
  325. try:
  326. l.reset_connection()
  327. finally:
  328. conf.BROKER_CONNECTION_RETRY = p
  329. l.stop_consumers()
  330. self.assertTrue(dispatcher.flushed)
  331. l.event_dispatcher = MockEventDispatcher()
  332. l.receive_message(m.decode(), m)
  333. l.eta_schedule.stop()
  334. in_hold = self.eta_schedule.queue[0]
  335. self.assertEqual(len(in_hold), 3)
  336. eta, priority, entry = in_hold
  337. task = entry.args[0]
  338. self.assertIsInstance(task, TaskRequest)
  339. self.assertEqual(task.task_name, foo_task.name)
  340. self.assertEqual(task.execute(), 2 * 4 * 8)
  341. self.assertRaises(Empty, self.ready_queue.get_nowait)
  342. def test_start__consume_messages(self):
  343. class _QoS(object):
  344. prev = 3
  345. next = 4
  346. def update(self):
  347. self.prev = self.next
  348. class _Listener(MyCarrotListener):
  349. iterations = 0
  350. wait_method = None
  351. def reset_connection(self):
  352. if self.iterations >= 1:
  353. raise KeyError("foo")
  354. def _detect_wait_method(self):
  355. return self.wait_method
  356. called_back = [False]
  357. def init_callback(listener):
  358. called_back[0] = True
  359. l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
  360. send_events=False, init_callback=init_callback)
  361. l.qos = _QoS()
  362. def raises_KeyError(limit=None):
  363. yield True
  364. l.iterations = 1
  365. raise KeyError("foo")
  366. l.wait_method = raises_KeyError
  367. self.assertRaises(KeyError, l.start)
  368. self.assertTrue(called_back[0])
  369. self.assertEqual(l.iterations, 1)
  370. self.assertEqual(l.qos.prev, l.qos.next)
  371. l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
  372. send_events=False, init_callback=init_callback)
  373. l.qos = _QoS()
  374. def raises_socket_error(limit=None):
  375. yield True
  376. l.iterations = 1
  377. raise socket.error("foo")
  378. l.wait_method = raises_socket_error
  379. self.assertRaises(KeyError, l.start)
  380. self.assertTrue(called_back[0])
  381. self.assertEqual(l.iterations, 1)
  382. class test_WorkController(unittest.TestCase):
  383. def setUp(self):
  384. self.worker = WorkController(concurrency=1, loglevel=0)
  385. self.worker.logger = MockLogger()
  386. def test_with_rate_limits_disabled(self):
  387. conf.DISABLE_RATE_LIMITS = True
  388. try:
  389. worker = WorkController(concurrency=1, loglevel=0)
  390. self.assertTrue(hasattr(worker.ready_queue, "put"))
  391. finally:
  392. conf.DISABLE_RATE_LIMITS = False
  393. def test_attrs(self):
  394. worker = self.worker
  395. self.assertIsInstance(worker.scheduler, Timer)
  396. self.assertTrue(worker.scheduler)
  397. self.assertTrue(worker.pool)
  398. self.assertTrue(worker.listener)
  399. self.assertTrue(worker.mediator)
  400. self.assertTrue(worker.components)
  401. def test_with_embedded_celerybeat(self):
  402. worker = WorkController(concurrency=1, loglevel=0,
  403. embed_clockservice=True)
  404. self.assertTrue(worker.beat)
  405. self.assertIn(worker.beat, worker.components)
  406. def test_process_task(self):
  407. worker = self.worker
  408. worker.pool = MockPool()
  409. backend = MockBackend()
  410. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  411. kwargs={})
  412. task = TaskRequest.from_message(m, m.decode())
  413. worker.process_task(task)
  414. worker.pool.stop()
  415. def test_process_task_raise_base(self):
  416. worker = self.worker
  417. worker.pool = MockPool(raise_base=True)
  418. backend = MockBackend()
  419. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  420. kwargs={})
  421. task = TaskRequest.from_message(m, m.decode())
  422. worker.process_task(task)
  423. worker.pool.stop()
  424. def test_process_task_raise_regular(self):
  425. worker = self.worker
  426. worker.pool = MockPool(raise_regular=True)
  427. backend = MockBackend()
  428. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  429. kwargs={})
  430. task = TaskRequest.from_message(m, m.decode())
  431. worker.process_task(task)
  432. worker.pool.stop()
  433. def test_start__stop(self):
  434. worker = self.worker
  435. w1 = {"started": False}
  436. w2 = {"started": False}
  437. w3 = {"started": False}
  438. w4 = {"started": False}
  439. worker.components = [MockController(w1), MockController(w2),
  440. MockController(w3), MockController(w4)]
  441. worker.start()
  442. for w in (w1, w2, w3, w4):
  443. self.assertTrue(w["started"])
  444. self.assertTrue(worker._running, len(worker.components))
  445. worker.stop()
  446. for component in worker.components:
  447. self.assertTrue(component._stopped)
  448. def test_start__terminate(self):
  449. worker = self.worker
  450. w1 = {"started": False}
  451. w2 = {"started": False}
  452. w3 = {"started": False}
  453. w4 = {"started": False}
  454. worker.components = [MockController(w1), MockController(w2),
  455. MockController(w3), MockController(w4),
  456. MockPool()]
  457. worker.start()
  458. for w in (w1, w2, w3, w4):
  459. self.assertTrue(w["started"])
  460. self.assertTrue(worker._running, len(worker.components))
  461. self.assertEqual(worker._state, RUN)
  462. worker.terminate()
  463. for component in worker.components:
  464. self.assertTrue(component._stopped)
  465. self.assertTrue(worker.components[4]._terminated)