test_worker.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import unittest
  2. from Queue import Queue, Empty
  3. from carrot.connection import BrokerConnection
  4. from celery.messaging import TaskConsumer
  5. from celery.worker.job import TaskWrapper
  6. from celery.worker import AMQPListener, WorkController
  7. from multiprocessing import get_logger
  8. from carrot.backends.base import BaseMessage
  9. from celery import registry
  10. from celery.serialization import pickle
  11. from celery.utils import gen_unique_id
  12. from datetime import datetime, timedelta
  13. def foo_task(x, y, z, **kwargs):
  14. return x * y * z
  15. registry.tasks.register(foo_task, name="c.u.foo")
  16. class MockLogger(object):
  17. def critical(self, *args, **kwargs):
  18. pass
  19. def info(self, *args, **kwargs):
  20. pass
  21. def error(self, *args, **kwargs):
  22. pass
  23. def debug(self, *args, **kwargs):
  24. pass
  25. class MockBackend(object):
  26. _acked = False
  27. def ack(self, delivery_tag):
  28. self._acked = True
  29. class MockPool(object):
  30. def __init__(self, *args, **kwargs):
  31. self.raise_regular = kwargs.get("raise_regular", False)
  32. self.raise_base = kwargs.get("raise_base", False)
  33. def apply_async(self, *args, **kwargs):
  34. if self.raise_regular:
  35. raise KeyError("some exception")
  36. if self.raise_base:
  37. raise KeyboardInterrupt("Ctrl+c")
  38. def start(self):
  39. pass
  40. def stop(self):
  41. pass
  42. return True
  43. class MockController(object):
  44. def __init__(self, w, *args, **kwargs):
  45. self._w = w
  46. self._stopped = False
  47. def start(self):
  48. self._w["started"] = True
  49. self._stopped = False
  50. def stop(self):
  51. self._stopped = True
  52. def create_message(backend, **data):
  53. data["id"] = gen_unique_id()
  54. return BaseMessage(backend, body=pickle.dumps(dict(**data)),
  55. content_type="application/x-python-serialize",
  56. content_encoding="binary")
  57. class TestAMQPListener(unittest.TestCase):
  58. def setUp(self):
  59. self.bucket_queue = Queue()
  60. self.hold_queue = Queue()
  61. self.logger = get_logger()
  62. self.logger.setLevel(0)
  63. def test_connection(self):
  64. l = AMQPListener(self.bucket_queue, self.hold_queue, self.logger)
  65. c = l.reset_connection()
  66. self.assertTrue(isinstance(l.amqp_connection, BrokerConnection))
  67. l.close_connection()
  68. self.assertTrue(l.amqp_connection is None)
  69. self.assertTrue(l.task_consumer is None)
  70. c = l.reset_connection()
  71. self.assertTrue(isinstance(l.amqp_connection, BrokerConnection))
  72. l.stop()
  73. self.assertTrue(l.amqp_connection is None)
  74. self.assertTrue(l.task_consumer is None)
  75. def test_receieve_message(self):
  76. l = AMQPListener(self.bucket_queue, self.hold_queue, self.logger)
  77. backend = MockBackend()
  78. m = create_message(backend, task="c.u.foo", args=[2, 4, 8], kwargs={})
  79. l.receive_message(m.decode(), m)
  80. in_bucket = self.bucket_queue.get_nowait()
  81. self.assertTrue(isinstance(in_bucket, TaskWrapper))
  82. self.assertEquals(in_bucket.task_name, "c.u.foo")
  83. self.assertEquals(in_bucket.execute(), 2 * 4 * 8)
  84. self.assertRaises(Empty, self.hold_queue.get_nowait)
  85. def test_receieve_message_not_registered(self):
  86. l = AMQPListener(self.bucket_queue, self.hold_queue, self.logger)
  87. backend = MockBackend()
  88. m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
  89. self.assertFalse(l.receive_message(m.decode(), m))
  90. self.assertRaises(Empty, self.bucket_queue.get_nowait)
  91. self.assertRaises(Empty, self.hold_queue.get_nowait)
  92. def test_receieve_message_eta(self):
  93. l = AMQPListener(self.bucket_queue, self.hold_queue, self.logger)
  94. backend = MockBackend()
  95. m = create_message(backend, task="c.u.foo", args=[2, 4, 8], kwargs={},
  96. eta=datetime.now() + timedelta(days=1))
  97. l.receive_message(m.decode(), m)
  98. in_hold = self.hold_queue.get_nowait()
  99. self.assertEquals(len(in_hold), 2)
  100. task, eta = in_hold
  101. self.assertTrue(isinstance(task, TaskWrapper))
  102. self.assertTrue(isinstance(eta, datetime))
  103. self.assertEquals(task.task_name, "c.u.foo")
  104. self.assertEquals(task.execute(), 2 * 4 * 8)
  105. self.assertRaises(Empty, self.bucket_queue.get_nowait)
  106. class TestWorkController(unittest.TestCase):
  107. def setUp(self):
  108. self.worker = WorkController(concurrency=1, loglevel=0,
  109. is_detached=False)
  110. self.worker.logger = MockLogger()
  111. def test_attrs(self):
  112. worker = self.worker
  113. self.assertTrue(isinstance(worker.bucket_queue, Queue))
  114. self.assertTrue(isinstance(worker.hold_queue, Queue))
  115. self.assertTrue(worker.periodic_work_controller)
  116. self.assertTrue(worker.pool)
  117. self.assertTrue(worker.amqp_listener)
  118. self.assertTrue(worker.mediator)
  119. self.assertTrue(worker.components)
  120. def test_safe_process_task(self):
  121. worker = self.worker
  122. worker.pool = MockPool()
  123. backend = MockBackend()
  124. m = create_message(backend, task="c.u.foo", args=[4, 8, 10],
  125. kwargs={})
  126. task = TaskWrapper.from_message(m, m.decode())
  127. worker.safe_process_task(task)
  128. worker.pool.stop()
  129. def test_safe_process_task_raise_base(self):
  130. worker = self.worker
  131. worker.pool = MockPool(raise_base=True)
  132. backend = MockBackend()
  133. m = create_message(backend, task="c.u.foo", args=[4, 8, 10],
  134. kwargs={})
  135. task = TaskWrapper.from_message(m, m.decode())
  136. worker.safe_process_task(task)
  137. worker.pool.stop()
  138. def test_safe_process_task_raise_regular(self):
  139. worker = self.worker
  140. worker.pool = MockPool(raise_regular=True)
  141. backend = MockBackend()
  142. m = create_message(backend, task="c.u.foo", args=[4, 8, 10],
  143. kwargs={})
  144. task = TaskWrapper.from_message(m, m.decode())
  145. worker.safe_process_task(task)
  146. worker.pool.stop()
  147. def test_start_stop(self):
  148. worker = self.worker
  149. w1 = {"started": False}
  150. w2 = {"started": False}
  151. w3 = {"started": False}
  152. w4 = {"started": False}
  153. worker.components = [MockController(w1), MockController(w2),
  154. MockController(w3), MockController(w4)]
  155. worker.start()
  156. for w in (w1, w2, w3, w4):
  157. self.assertTrue(w["started"])
  158. for component in worker.components:
  159. self.assertTrue(component._stopped)