test_worker.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. import unittest
  2. from Queue import Queue, Empty
  3. from datetime import datetime, timedelta
  4. from multiprocessing import get_logger
  5. from carrot.connection import BrokerConnection
  6. from carrot.backends.base import BaseMessage
  7. from billiard.serialization import pickle
  8. from celery import registry
  9. from celery.utils import gen_unique_id
  10. from celery.worker import CarrotListener, WorkController
  11. from celery.worker.job import TaskWrapper
  12. from celery.worker.scheduler import Scheduler
  13. from celery.decorators import task as task_dec
  14. from celery.decorators import periodic_task as periodic_task_dec
  15. class MockEventDispatcher(object):
  16. def send(self, *args, **kwargs):
  17. pass
  18. def close(self):
  19. pass
  20. @task_dec()
  21. def foo_task(x, y, z, **kwargs):
  22. return x * y * z
  23. @periodic_task_dec(run_every=60)
  24. def foo_periodic_task():
  25. return "foo"
  26. class MockLogger(object):
  27. def critical(self, *args, **kwargs):
  28. pass
  29. def info(self, *args, **kwargs):
  30. pass
  31. def error(self, *args, **kwargs):
  32. pass
  33. def debug(self, *args, **kwargs):
  34. pass
  35. class MockBackend(object):
  36. _acked = False
  37. def ack(self, delivery_tag):
  38. self._acked = True
  39. class MockPool(object):
  40. def __init__(self, *args, **kwargs):
  41. self.raise_regular = kwargs.get("raise_regular", False)
  42. self.raise_base = kwargs.get("raise_base", False)
  43. def apply_async(self, *args, **kwargs):
  44. if self.raise_regular:
  45. raise KeyError("some exception")
  46. if self.raise_base:
  47. raise KeyboardInterrupt("Ctrl+c")
  48. def start(self):
  49. pass
  50. def stop(self):
  51. pass
  52. return True
  53. class MockController(object):
  54. def __init__(self, w, *args, **kwargs):
  55. self._w = w
  56. self._stopped = False
  57. def start(self):
  58. self._w["started"] = True
  59. self._stopped = False
  60. def stop(self):
  61. self._stopped = True
  62. def create_message(backend, **data):
  63. data.setdefault("id", gen_unique_id())
  64. return BaseMessage(backend, body=pickle.dumps(dict(**data)),
  65. content_type="application/x-python-serialize",
  66. content_encoding="binary")
  67. class TestCarrotListener(unittest.TestCase):
  68. def setUp(self):
  69. self.ready_queue = Queue()
  70. self.eta_schedule = Scheduler(self.ready_queue)
  71. self.logger = get_logger()
  72. self.logger.setLevel(0)
  73. def test_connection(self):
  74. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  75. send_events=False)
  76. c = l.reset_connection()
  77. self.assertTrue(isinstance(l.amqp_connection, BrokerConnection))
  78. l.close_connection()
  79. self.assertTrue(l.amqp_connection is None)
  80. self.assertTrue(l.task_consumer is None)
  81. c = l.reset_connection()
  82. self.assertTrue(isinstance(l.amqp_connection, BrokerConnection))
  83. l.stop()
  84. self.assertTrue(l.amqp_connection is None)
  85. self.assertTrue(l.task_consumer is None)
  86. def test_receieve_message(self):
  87. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  88. send_events=False)
  89. backend = MockBackend()
  90. m = create_message(backend, task=foo_task.name,
  91. args=[2, 4, 8], kwargs={})
  92. l.event_dispatcher = MockEventDispatcher()
  93. l.receive_message(m.decode(), m)
  94. in_bucket = self.ready_queue.get_nowait()
  95. self.assertTrue(isinstance(in_bucket, TaskWrapper))
  96. self.assertEquals(in_bucket.task_name, foo_task.name)
  97. self.assertEquals(in_bucket.execute(), 2 * 4 * 8)
  98. self.assertTrue(self.eta_schedule.empty())
  99. def test_revoke(self):
  100. ready_queue = Queue()
  101. l = CarrotListener(ready_queue, self.eta_schedule, self.logger,
  102. send_events=False)
  103. backend = MockBackend()
  104. id = gen_unique_id()
  105. c = create_message(backend, control={"command": "revoke",
  106. "task_id": id})
  107. t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
  108. kwargs={}, id=id)
  109. l.event_dispatcher = MockEventDispatcher()
  110. l.receive_message(c.decode(), c)
  111. from celery.worker.revoke import revoked
  112. self.assertTrue(id in revoked)
  113. l.receive_message(t.decode(), t)
  114. self.assertTrue(ready_queue.empty())
  115. def test_receieve_message_not_registered(self):
  116. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  117. send_events=False)
  118. backend = MockBackend()
  119. m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
  120. l.event_dispatcher = MockEventDispatcher()
  121. self.assertFalse(l.receive_message(m.decode(), m))
  122. self.assertRaises(Empty, self.ready_queue.get_nowait)
  123. self.assertTrue(self.eta_schedule.empty())
  124. def test_receieve_message_eta(self):
  125. l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
  126. send_events=False)
  127. backend = MockBackend()
  128. m = create_message(backend, task=foo_task.name,
  129. args=[2, 4, 8], kwargs={},
  130. eta=(datetime.now() +
  131. timedelta(days=1)).isoformat())
  132. l.reset_connection()
  133. l.receive_message(m.decode(), m)
  134. in_hold = self.eta_schedule.queue[0]
  135. self.assertEquals(len(in_hold), 4)
  136. eta, priority, task, on_accept = in_hold
  137. self.assertTrue(isinstance(task, TaskWrapper))
  138. self.assertTrue(callable(on_accept))
  139. self.assertEquals(task.task_name, foo_task.name)
  140. self.assertEquals(task.execute(), 2 * 4 * 8)
  141. self.assertRaises(Empty, self.ready_queue.get_nowait)
  142. class TestWorkController(unittest.TestCase):
  143. def setUp(self):
  144. self.worker = WorkController(concurrency=1,
  145. loglevel=0,
  146. is_detached=False)
  147. self.worker.logger = MockLogger()
  148. def test_attrs(self):
  149. worker = self.worker
  150. self.assertTrue(isinstance(worker.eta_schedule, Scheduler))
  151. self.assertTrue(worker.scheduler)
  152. self.assertTrue(worker.pool)
  153. self.assertTrue(worker.listener)
  154. self.assertTrue(worker.mediator)
  155. self.assertTrue(worker.components)
  156. def test_process_task(self):
  157. worker = self.worker
  158. worker.pool = MockPool()
  159. backend = MockBackend()
  160. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  161. kwargs={})
  162. task = TaskWrapper.from_message(m, m.decode())
  163. worker.process_task(task)
  164. worker.pool.stop()
  165. def test_process_task_raise_base(self):
  166. worker = self.worker
  167. worker.pool = MockPool(raise_base=True)
  168. backend = MockBackend()
  169. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  170. kwargs={})
  171. task = TaskWrapper.from_message(m, m.decode())
  172. worker.process_task(task)
  173. worker.pool.stop()
  174. def test_process_task_raise_regular(self):
  175. worker = self.worker
  176. worker.pool = MockPool(raise_regular=True)
  177. backend = MockBackend()
  178. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  179. kwargs={})
  180. task = TaskWrapper.from_message(m, m.decode())
  181. worker.process_task(task)
  182. worker.pool.stop()
  183. def test_start_stop(self):
  184. worker = self.worker
  185. w1 = {"started": False}
  186. w2 = {"started": False}
  187. w3 = {"started": False}
  188. w4 = {"started": False}
  189. worker.components = [MockController(w1), MockController(w2),
  190. MockController(w3), MockController(w4)]
  191. worker.start()
  192. for w in (w1, w2, w3, w4):
  193. self.assertTrue(w["started"])
  194. for component in worker.components:
  195. self.assertTrue(component._stopped)