test_worker_control.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. import socket
  2. from celery.tests.utils import unittest
  3. from datetime import datetime, timedelta
  4. from kombu import pidbox
  5. from celery.utils.timer2 import Timer
  6. from celery.app import app_or_default
  7. from celery.datastructures import AttributeDict
  8. from celery.task import task
  9. from celery.registry import tasks
  10. from celery.task import PingTask
  11. from celery.utils import gen_unique_id
  12. from celery.worker.buckets import FastQueue
  13. from celery.worker.job import TaskRequest
  14. from celery.worker.state import revoked
  15. from celery.worker.control.registry import Panel
  # Name of this machine. The panel under test is created for this hostname,
  # so control messages that should be handled use it as "destination".
  hostname = socket.gethostname()


  @task(rate_limit=200)  # the rate limit is checked in dump_tasks output
  def mytask():
      """No-op task; gives the panel a known task name to report on."""
  class Dispatcher(object):
      """Fake event dispatcher that records events instead of sending them.

      ``enabled`` is toggled by :meth:`enable` / :meth:`disable`, and every
      event handed to :meth:`send` is appended to ``sent``.
      """

      # Class-level default: stays None until enable()/disable() is called.
      enabled = None

      def __init__(self, *args, **kwargs):
          # Constructor arguments are accepted and ignored.
          self.sent = list()

      def _set_enabled(self, flag):
          """Store *flag* as this dispatcher's enabled state."""
          self.enabled = flag

      def enable(self):
          """Mark the dispatcher as enabled."""
          self._set_enabled(True)

      def disable(self):
          """Mark the dispatcher as disabled."""
          self._set_enabled(False)

      def send(self, event):
          """Record *event* in ``sent``; nothing is actually published."""
          self.sent += [event]
  class Consumer(object):
      """Minimal stand-in for the worker consumer that the control panel uses.

      Exposes these attributes:

      * ``ready_queue``: a ``FastQueue`` pre-loaded with one
        ``mytask(2, 2)`` request.
      * ``eta_schedule``: a ``Timer``.
      * ``app``: the default app.
      * ``event_dispatcher``: a recording ``Dispatcher``.
      * ``pool``: a ``BasePool(10)``.
      * ``info``: a static property.
      """

      def __init__(self):
          queue = FastQueue()
          queue.put(TaskRequest(task_name=mytask.name,
                                task_id=gen_unique_id(),
                                args=(2, 2),
                                kwargs={}))
          self.ready_queue = queue
          self.eta_schedule = Timer()
          self.app = app_or_default()
          self.event_dispatcher = Dispatcher()
          # The pool class is imported locally, inside the constructor.
          from celery.concurrency.base import BasePool
          self.pool = BasePool(10)

      @property
      def info(self):
          """Static consumer info; expected under "consumer" in ``stats``."""
          return dict(xyz="XYZ")
  class test_ControlPanel(unittest.TestCase):
      """Tests for the worker remote-control handlers registered in Panel.

      Each test builds a mailbox node for ``hostname`` whose state carries a
      fake ``Consumer``. Commands are invoked either directly with
      ``panel.handle(...)`` or through ``dispatch_from_message``.
      """

      def setUp(self):
          self.app = app_or_default()
          self.panel = self.create_panel(consumer=Consumer())

      def create_state(self, **kwargs):
          """Return an ``AttributeDict`` used as the panel's state.

          ``logger`` defaults to the app's default logger and ``app`` defaults
          to ``self.app``. Other keywords (e.g. ``consumer``) pass through.
          """
          kwargs.setdefault("logger", self.app.log.get_default_logger())
          kwargs.setdefault("app", self.app)
          return AttributeDict(kwargs)

      def create_panel(self, **kwargs):
          """Return a mailbox node for ``hostname`` with the ``Panel.data``
          handlers and state built from *kwargs*."""
          return self.app.control.mailbox.Node(hostname=hostname,
                                               state=self.create_state(**kwargs),
                                               handlers=Panel.data)

      def test_enable_events(self):
          consumer = Consumer()
          panel = self.create_panel(consumer=consumer)
          consumer.event_dispatcher.enabled = False
          panel.handle("enable_events")
          self.assertEqual(consumer.event_dispatcher.enabled, True)
          self.assertIn("worker-online", consumer.event_dispatcher.sent)
          self.assertTrue(panel.handle("enable_events")["ok"])

      def test_disable_events(self):
          consumer = Consumer()
          panel = self.create_panel(consumer=consumer)
          consumer.event_dispatcher.enabled = True
          panel.handle("disable_events")
          self.assertEqual(consumer.event_dispatcher.enabled, False)
          self.assertIn("worker-offline", consumer.event_dispatcher.sent)
          self.assertTrue(panel.handle("disable_events")["ok"])

      def test_heartbeat(self):
          consumer = Consumer()
          panel = self.create_panel(consumer=consumer)
          consumer.event_dispatcher.enabled = True
          panel.handle("heartbeat")
          self.assertIn("worker-heartbeat", consumer.event_dispatcher.sent)

      def test_dump_tasks(self):
          info = "\n".join(self.panel.handle("dump_tasks"))
          self.assertIn("mytask", info)
          self.assertIn("rate_limit=200", info)

      def test_stats(self):
          from celery.worker import state
          # Temporarily replace the global counter; restore it no matter what.
          prev_count, state.total_count = state.total_count, 100
          try:
              self.assertDictContainsSubset({"total": 100,
                                             "consumer": {"xyz": "XYZ"}},
                                            self.panel.handle("stats"))
          finally:
              state.total_count = prev_count

      def test_active(self):
          from celery.worker import state
          # TaskRequest and PingTask are already imported at module level.
          r = TaskRequest(PingTask.name, "do re mi", (), {})
          state.active_requests.add(r)
          try:
              self.assertTrue(self.panel.handle("dump_active"))
          finally:
              state.active_requests.discard(r)

      def test_pool_grow(self):

          class MockPool(object):

              def __init__(self, size=1):
                  self.size = size

              def grow(self, n=1):
                  self.size += n

              def shrink(self, n=1):
                  self.size -= n

          consumer = Consumer()
          consumer.pool = MockPool()
          panel = self.create_panel(consumer=consumer)
          panel.handle("pool_grow")
          self.assertEqual(consumer.pool.size, 2)
          panel.handle("pool_shrink")
          self.assertEqual(consumer.pool.size, 1)

      def test_add__cancel_consumer(self):

          class MockConsumer(object):
              """Records add/cancel calls. The state lives on the instance,
              not in shared class-level lists."""

              def __init__(self):
                  self.queues = []
                  self.cancelled = []
                  self.consuming = False

              def add_consumer_from_dict(self, **declaration):
                  self.queues.append(declaration["queue"])

              def consume(self):
                  self.consuming = True

              def cancel_by_queue(self, queue):
                  self.cancelled.append(queue)

          consumer = Consumer()
          consumer.task_consumer = MockConsumer()
          panel = self.create_panel(consumer=consumer)
          panel.handle("add_consumer", {"queue": "MyQueue"})
          self.assertIn("MyQueue", consumer.task_consumer.queues)
          self.assertTrue(consumer.task_consumer.consuming)
          panel.handle("cancel_consumer", {"queue": "MyQueue"})
          self.assertIn("MyQueue", consumer.task_consumer.cancelled)

      def test_revoked(self):
          from celery.worker import state
          state.revoked.clear()
          state.revoked.add("a1")
          state.revoked.add("a2")
          try:
              # ``revoked`` is a set-like collection, so the order of the
              # dumped ids is not guaranteed: compare them sorted.
              self.assertListEqual(sorted(self.panel.handle("dump_revoked")),
                                   ["a1", "a2"])
          finally:
              state.revoked.clear()

      def test_dump_schedule(self):
          consumer = Consumer()
          panel = self.create_panel(consumer=consumer)
          # A fresh consumer has nothing scheduled.
          self.assertFalse(panel.handle("dump_schedule"))
          r = TaskRequest("celery.ping", "CAFEBABE", (), {})
          consumer.eta_schedule.schedule.enter(
                  consumer.eta_schedule.Entry(lambda x: x, (r, )),
                  datetime.now() + timedelta(seconds=10))
          self.assertTrue(panel.handle("dump_schedule"))

      def test_dump_reserved(self):
          consumer = Consumer()
          panel = self.create_panel(consumer=consumer)
          response = panel.handle("dump_reserved", {"safe": True})
          self.assertDictContainsSubset({"name": mytask.name,
                                         "args": (2, 2),
                                         "kwargs": {},
                                         "hostname": hostname},
                                        response[0])
          # With an empty ready queue, nothing is reported as reserved.
          consumer.ready_queue = FastQueue()
          self.assertFalse(panel.handle("dump_reserved"))

      def test_rate_limit_when_disabled(self):
          # Use the same app the panel state was built with, and restore the
          # previous setting rather than forcing it to False afterwards.
          conf = self.app.conf
          prev = getattr(conf, "CELERY_DISABLE_RATE_LIMITS", False)
          conf.CELERY_DISABLE_RATE_LIMITS = True
          try:
              e = self.panel.handle("rate_limit", arguments=dict(
                      task_name=mytask.name, rate_limit="100/m"))
              self.assertIn("rate limits disabled", e.get("error"))
          finally:
              conf.CELERY_DISABLE_RATE_LIMITS = prev

      def test_rate_limit_invalid_rate_limit_string(self):
          e = self.panel.handle("rate_limit", arguments=dict(
                  task_name="tasks.add", rate_limit="x1240301#%!"))
          self.assertIn("Invalid rate limit string", e.get("error"))

      def test_rate_limit(self):

          # Named so it does not shadow the module-level ``Consumer`` fixture.
          class RateLimitConsumer(object):

              class ReadyQueue(object):
                  fresh = False

                  def refresh(self):
                      self.fresh = True

              def __init__(self):
                  self.ready_queue = self.ReadyQueue()

          consumer = RateLimitConsumer()
          panel = self.create_panel(consumer=consumer)

          # Named so it does not shadow the imported ``task`` decorator.
          ping_task = tasks[PingTask.name]
          old_rate_limit = ping_task.rate_limit
          try:
              panel.handle("rate_limit", arguments=dict(task_name=ping_task.name,
                                                        rate_limit="100/m"))
              self.assertEqual(ping_task.rate_limit, "100/m")
              self.assertTrue(consumer.ready_queue.fresh)
              consumer.ready_queue.fresh = False
              panel.handle("rate_limit", arguments=dict(task_name=ping_task.name,
                                                        rate_limit=0))
              self.assertEqual(ping_task.rate_limit, 0)
              self.assertTrue(consumer.ready_queue.fresh)
          finally:
              ping_task.rate_limit = old_rate_limit

      def test_rate_limit_nonexistant_task(self):
          # Only checks that an unknown task name does not make the handler
          # raise; the reply itself is not asserted.
          self.panel.handle("rate_limit", arguments={
              "task_name": "xxxx.does.not.exist",
              "rate_limit": "1000/s"})

      def test_unexposed_command(self):
          self.assertRaises(KeyError, self.panel.handle, "foo", arguments={})

      def test_revoke_with_name(self):
          uuid = gen_unique_id()
          m = {"method": "revoke",
               "destination": hostname,
               "arguments": {"task_id": uuid,
                             "task_name": mytask.name}}
          self.panel.dispatch_from_message(m)
          self.assertIn(uuid, revoked)

      def test_revoke_with_name_not_in_registry(self):
          uuid = gen_unique_id()
          m = {"method": "revoke",
               "destination": hostname,
               "arguments": {"task_id": uuid,
                             "task_name": "xxxxxxxxx33333333388888"}}
          self.panel.dispatch_from_message(m)
          self.assertIn(uuid, revoked)

      def test_revoke(self):
          uuid = gen_unique_id()
          m = {"method": "revoke",
               "destination": hostname,
               "arguments": {"task_id": uuid}}
          self.panel.dispatch_from_message(m)
          self.assertIn(uuid, revoked)
          # A message addressed to another node must not be acted upon.
          m = {"method": "revoke",
               "destination": "does.not.exist",
               "arguments": {"task_id": uuid + "xxx"}}
          self.panel.dispatch_from_message(m)
          self.assertNotIn(uuid + "xxx", revoked)

      def test_ping(self):
          m = {"method": "ping",
               "destination": hostname}
          r = self.panel.dispatch_from_message(m)
          self.assertEqual(r, "pong")

      def test_shutdown(self):
          m = {"method": "shutdown",
               "destination": hostname}
          self.assertRaises(SystemExit, self.panel.dispatch_from_message, m)

      def test_panel_reply(self):
          replies = []

          class _Node(pidbox.Node):
              # Capture replies instead of publishing them.

              def reply(self, data, exchange, routing_key, **kwargs):
                  replies.append(data)

          panel = _Node(hostname=hostname,
                        state=self.create_state(consumer=Consumer()),
                        handlers=Panel.data,
                        mailbox=self.app.control.mailbox)
          r = panel.dispatch("ping", reply_to={"exchange": "x",
                                               "routing_key": "x"})
          self.assertEqual(r, "pong")
          # The reply is keyed by the node's hostname.
          self.assertDictEqual(replies[0], {panel.hostname: "pong"})