test_control.py 23 KB


  1. import pytest
  2. import sys
  3. import socket
  4. from collections import defaultdict
  5. from datetime import datetime, timedelta
  6. from queue import Queue as FastQueue
  7. from case import Mock, call, patch
  8. from kombu import pidbox
  9. from kombu.utils.uuid import uuid
  10. from celery.utils.timer2 import Timer
  11. from celery.worker import WorkController as _WC # noqa
  12. from celery.worker import consumer
  13. from celery.worker import control
  14. from celery.worker import state as worker_state
  15. from celery.worker.request import Request
  16. from celery.worker.state import revoked
  17. from celery.worker.pidbox import Pidbox, gPidbox
  18. from celery.utils.collections import AttributeDict
# Host name of the machine running the tests. Panels created below are named
# after it, and control messages use it as their ``destination``.
hostname = socket.gethostname()
  20. class WorkController:
  21. autoscaler = None
  22. def stats(self):
  23. return {'total': worker_state.total_count}
  24. class Consumer(consumer.Consumer):
  25. def __init__(self, app):
  26. self.app = app
  27. self.buffer = FastQueue()
  28. self.timer = Timer()
  29. self.event_dispatcher = Mock()
  30. self.controller = WorkController()
  31. self.task_consumer = Mock()
  32. self.prefetch_multiplier = 1
  33. self.initial_prefetch_count = 1
  34. from celery.concurrency.base import BasePool
  35. self.pool = BasePool(10)
  36. self.task_buckets = defaultdict(lambda: None)
  37. self.hub = None
  38. def call_soon(self, p, *args, **kwargs):
  39. return p(*args, **kwargs)
  40. class test_Pidbox:
  41. def test_shutdown(self):
  42. with patch('celery.worker.pidbox.ignore_errors') as eig:
  43. parent = Mock()
  44. pbox = Pidbox(parent)
  45. pbox._close_channel = Mock()
  46. assert pbox.c is parent
  47. pconsumer = pbox.consumer = Mock()
  48. cancel = pconsumer.cancel
  49. pbox.shutdown(parent)
  50. eig.assert_called_with(parent, cancel)
  51. pbox._close_channel.assert_called_with(parent)
  52. class test_Pidbox_green:
  53. def test_stop(self):
  54. parent = Mock()
  55. g = gPidbox(parent)
  56. stopped = g._node_stopped = Mock()
  57. shutdown = g._node_shutdown = Mock()
  58. close_chan = g._close_channel = Mock()
  59. g.stop(parent)
  60. shutdown.set.assert_called_with()
  61. stopped.wait.assert_called_with()
  62. close_chan.assert_called_with(parent)
  63. assert g._node_stopped is None
  64. assert g._node_shutdown is None
  65. close_chan.reset()
  66. g.stop(parent)
  67. close_chan.assert_called_with(parent)
  68. def test_resets(self):
  69. parent = Mock()
  70. g = gPidbox(parent)
  71. g._resets = 100
  72. g.reset()
  73. assert g._resets == 101
  74. def test_loop(self):
  75. parent = Mock()
  76. conn = self.app.connection_for_read()
  77. parent.connection_for_read.return_value = conn
  78. drain = conn.drain_events = Mock()
  79. g = gPidbox(parent)
  80. parent.connection = Mock()
  81. do_reset = g._do_reset = Mock()
  82. call_count = [0]
  83. def se(*args, **kwargs):
  84. if call_count[0] > 2:
  85. g._node_shutdown.set()
  86. g.reset()
  87. call_count[0] += 1
  88. drain.side_effect = se
  89. g.loop(parent)
  90. assert do_reset.call_count == 4
  91. class test_ControlPanel:
  92. def setup(self):
  93. self.panel = self.create_panel(consumer=Consumer(self.app))
  94. @self.app.task(name='c.unittest.mytask', rate_limit=200, shared=False)
  95. def mytask():
  96. pass
  97. self.mytask = mytask
  98. def create_state(self, **kwargs):
  99. kwargs.setdefault('app', self.app)
  100. kwargs.setdefault('hostname', hostname)
  101. kwargs.setdefault('tset', set)
  102. return AttributeDict(kwargs)
  103. def create_panel(self, **kwargs):
  104. return self.app.control.mailbox.Node(
  105. hostname=hostname,
  106. state=self.create_state(**kwargs),
  107. handlers=control.Panel.data,
  108. )
  109. def test_enable_events(self):
  110. consumer = Consumer(self.app)
  111. panel = self.create_panel(consumer=consumer)
  112. evd = consumer.event_dispatcher
  113. evd.groups = set()
  114. panel.handle('enable_events')
  115. assert not evd.groups
  116. evd.groups = {'worker'}
  117. panel.handle('enable_events')
  118. assert 'task' in evd.groups
  119. evd.groups = {'task'}
  120. assert 'already enabled' in panel.handle('enable_events')['ok']
  121. def test_disable_events(self):
  122. consumer = Consumer(self.app)
  123. panel = self.create_panel(consumer=consumer)
  124. evd = consumer.event_dispatcher
  125. evd.enabled = True
  126. evd.groups = {'task'}
  127. panel.handle('disable_events')
  128. assert 'task' not in evd.groups
  129. assert 'already disabled' in panel.handle('disable_events')['ok']
  130. def test_clock(self):
  131. consumer = Consumer(self.app)
  132. panel = self.create_panel(consumer=consumer)
  133. panel.state.app.clock.value = 313
  134. x = panel.handle('clock')
  135. assert x['clock'] == 313
  136. def test_hello(self):
  137. consumer = Consumer(self.app)
  138. panel = self.create_panel(consumer=consumer)
  139. panel.state.app.clock.value = 313
  140. panel.state.hostname = 'elaine@vandelay.com'
  141. worker_state.revoked.add('revoked1')
  142. try:
  143. assert panel.handle('hello', {
  144. 'from_node': 'elaine@vandelay.com',
  145. }) is None
  146. x = panel.handle('hello', {
  147. 'from_node': 'george@vandelay.com',
  148. })
  149. assert x['clock'] == 314 # incremented
  150. x = panel.handle('hello', {
  151. 'from_node': 'george@vandelay.com',
  152. 'revoked': {'1234', '4567', '891'}
  153. })
  154. assert 'revoked1' in x['revoked']
  155. assert '1234' in x['revoked']
  156. assert '4567' in x['revoked']
  157. assert '891' in x['revoked']
  158. assert x['clock'] == 315 # incremented
  159. finally:
  160. worker_state.revoked.discard('revoked1')
  161. def test_conf(self):
  162. consumer = Consumer(self.app)
  163. panel = self.create_panel(consumer=consumer)
  164. panel.app = self.app
  165. panel.app.finalize()
  166. self.app.conf.some_key6 = 'hello world'
  167. x = panel.handle('dump_conf')
  168. assert 'some_key6' in x
  169. def test_election(self):
  170. consumer = Consumer(self.app)
  171. panel = self.create_panel(consumer=consumer)
  172. consumer.gossip = Mock()
  173. panel.handle(
  174. 'election', {'id': 'id', 'topic': 'topic', 'action': 'action'},
  175. )
  176. consumer.gossip.election.assert_called_with('id', 'topic', 'action')
  177. def test_election__no_gossip(self):
  178. consumer = Mock(name='consumer')
  179. consumer.gossip = None
  180. panel = self.create_panel(consumer=consumer)
  181. panel.handle(
  182. 'election', {'id': 'id', 'topic': 'topic', 'action': 'action'},
  183. )
  184. def test_heartbeat(self):
  185. consumer = Consumer(self.app)
  186. panel = self.create_panel(consumer=consumer)
  187. event_dispatcher = consumer.event_dispatcher
  188. event_dispatcher.enabled = True
  189. panel.handle('heartbeat')
  190. assert ('worker-heartbeat',) in event_dispatcher.send.call_args
  191. def test_time_limit(self):
  192. panel = self.create_panel(consumer=Mock())
  193. r = panel.handle('time_limit', arguments=dict(
  194. task_name=self.mytask.name, hard=30, soft=10))
  195. assert self.mytask.time_limit == 30
  196. assert self.mytask.soft_time_limit == 10
  197. assert 'ok' in r
  198. r = panel.handle('time_limit', arguments=dict(
  199. task_name=self.mytask.name, hard=None, soft=None))
  200. assert self.mytask.time_limit is None
  201. assert self.mytask.soft_time_limit is None
  202. assert 'ok' in r
  203. r = panel.handle('time_limit', arguments=dict(
  204. task_name='248e8afya9s8dh921eh928', hard=30))
  205. assert 'error' in r
  206. def test_active_queues(self):
  207. import kombu
  208. x = kombu.Consumer(self.app.connection_for_read(),
  209. [kombu.Queue('foo', kombu.Exchange('foo'), 'foo'),
  210. kombu.Queue('bar', kombu.Exchange('bar'), 'bar')],
  211. auto_declare=False)
  212. consumer = Mock()
  213. consumer.task_consumer = x
  214. panel = self.create_panel(consumer=consumer)
  215. r = panel.handle('active_queues')
  216. assert list(sorted(q['name'] for q in r)) == ['bar', 'foo']
  217. def test_active_queues__empty(self):
  218. consumer = Mock(name='consumer')
  219. panel = self.create_panel(consumer=consumer)
  220. consumer.task_consumer = None
  221. assert not panel.handle('active_queues')
  222. def test_dump_tasks(self):
  223. info = '\n'.join(self.panel.handle('dump_tasks'))
  224. assert 'mytask' in info
  225. assert 'rate_limit=200' in info
  226. def test_dump_tasks2(self):
  227. prev, control.DEFAULT_TASK_INFO_ITEMS = (
  228. control.DEFAULT_TASK_INFO_ITEMS, [])
  229. try:
  230. info = '\n'.join(self.panel.handle('dump_tasks'))
  231. assert 'mytask' in info
  232. assert 'rate_limit=200' not in info
  233. finally:
  234. control.DEFAULT_TASK_INFO_ITEMS = prev
  235. def test_stats(self):
  236. prev_count, worker_state.total_count = worker_state.total_count, 100
  237. try:
  238. assert self.panel.handle('stats')['total'] == 100
  239. finally:
  240. worker_state.total_count = prev_count
  241. def test_report(self):
  242. self.panel.handle('report')
  243. def test_active(self):
  244. r = Request(
  245. self.TaskMessage(self.mytask.name, 'do re mi'),
  246. app=self.app,
  247. )
  248. worker_state.active_requests.add(r)
  249. try:
  250. assert self.panel.handle('dump_active')
  251. finally:
  252. worker_state.active_requests.discard(r)
  253. def test_pool_grow(self):
  254. class MockPool:
  255. def __init__(self, size=1):
  256. self.size = size
  257. def grow(self, n=1):
  258. self.size += n
  259. def shrink(self, n=1):
  260. self.size -= n
  261. @property
  262. def num_processes(self):
  263. return self.size
  264. consumer = Consumer(self.app)
  265. consumer.prefetch_multiplier = 8
  266. consumer.qos = Mock(name='qos')
  267. consumer.pool = MockPool(1)
  268. panel = self.create_panel(consumer=consumer)
  269. panel.handle('pool_grow')
  270. assert consumer.pool.size == 2
  271. consumer.qos.increment_eventually.assert_called_with(8)
  272. assert consumer.initial_prefetch_count == 16
  273. panel.handle('pool_shrink')
  274. assert consumer.pool.size == 1
  275. consumer.qos.decrement_eventually.assert_called_with(8)
  276. assert consumer.initial_prefetch_count == 8
  277. panel.state.consumer = Mock()
  278. panel.state.consumer.controller = Mock()
  279. sc = panel.state.consumer.controller.autoscaler = Mock()
  280. panel.handle('pool_grow')
  281. sc.force_scale_up.assert_called()
  282. panel.handle('pool_shrink')
  283. sc.force_scale_down.assert_called()
  284. def test_add__cancel_consumer(self):
  285. class MockConsumer:
  286. queues = []
  287. canceled = []
  288. consuming = False
  289. hub = Mock(name='hub')
  290. def add_queue(self, queue):
  291. self.queues.append(queue.name)
  292. def consume(self):
  293. self.consuming = True
  294. def cancel_by_queue(self, queue):
  295. self.canceled.append(queue)
  296. def consuming_from(self, queue):
  297. return queue in self.queues
  298. consumer = Consumer(self.app)
  299. consumer.task_consumer = MockConsumer()
  300. panel = self.create_panel(consumer=consumer)
  301. panel.handle('add_consumer', {'queue': 'MyQueue'})
  302. assert 'MyQueue' in consumer.task_consumer.queues
  303. assert consumer.task_consumer.consuming
  304. panel.handle('add_consumer', {'queue': 'MyQueue'})
  305. panel.handle('cancel_consumer', {'queue': 'MyQueue'})
  306. assert 'MyQueue' in consumer.task_consumer.canceled
  307. def test_revoked(self):
  308. worker_state.revoked.clear()
  309. worker_state.revoked.add('a1')
  310. worker_state.revoked.add('a2')
  311. try:
  312. assert sorted(self.panel.handle('dump_revoked')) == ['a1', 'a2']
  313. finally:
  314. worker_state.revoked.clear()
  315. def test_dump_schedule(self):
  316. consumer = Consumer(self.app)
  317. panel = self.create_panel(consumer=consumer)
  318. assert not panel.handle('dump_schedule')
  319. r = Request(
  320. self.TaskMessage(self.mytask.name, 'CAFEBABE'),
  321. app=self.app,
  322. )
  323. consumer.timer.schedule.enter_at(
  324. consumer.timer.Entry(lambda x: x, (r,)),
  325. datetime.now() + timedelta(seconds=10))
  326. consumer.timer.schedule.enter_at(
  327. consumer.timer.Entry(lambda x: x, (object(),)),
  328. datetime.now() + timedelta(seconds=10))
  329. assert panel.handle('dump_schedule')
  330. def test_dump_reserved(self):
  331. consumer = Consumer(self.app)
  332. req = Request(
  333. self.TaskMessage(self.mytask.name, args=(2, 2)), app=self.app,
  334. ) # ^ need to keep reference for reserved_tasks WeakSet.
  335. worker_state.task_reserved(req)
  336. try:
  337. panel = self.create_panel(consumer=consumer)
  338. response = panel.handle('dump_reserved', {'safe': True})
  339. assert response[0]['name'] == self.mytask.name
  340. assert response[0]['hostname'] == socket.gethostname()
  341. worker_state.reserved_requests.clear()
  342. assert not panel.handle('dump_reserved')
  343. finally:
  344. worker_state.reserved_requests.clear()
  345. def test_rate_limit_invalid_rate_limit_string(self):
  346. e = self.panel.handle('rate_limit', arguments=dict(
  347. task_name='tasks.add', rate_limit='x1240301#%!'))
  348. assert 'Invalid rate limit string' in e.get('error')
  349. def test_rate_limit(self):
  350. class xConsumer:
  351. reset = False
  352. def reset_rate_limits(self):
  353. self.reset = True
  354. consumer = xConsumer()
  355. panel = self.create_panel(app=self.app, consumer=consumer)
  356. task = self.app.tasks[self.mytask.name]
  357. panel.handle('rate_limit', arguments=dict(task_name=task.name,
  358. rate_limit='100/m'))
  359. assert task.rate_limit == '100/m'
  360. assert consumer.reset
  361. consumer.reset = False
  362. panel.handle('rate_limit', arguments=dict(
  363. task_name=task.name,
  364. rate_limit=0,
  365. ))
  366. assert task.rate_limit == 0
  367. assert consumer.reset
  368. def test_rate_limit_nonexistant_task(self):
  369. self.panel.handle('rate_limit', arguments={
  370. 'task_name': 'xxxx.does.not.exist',
  371. 'rate_limit': '1000/s'})
  372. def test_unexposed_command(self):
  373. with pytest.raises(KeyError):
  374. self.panel.handle('foo', arguments={})
  375. def test_revoke_with_name(self):
  376. tid = uuid()
  377. m = {
  378. 'method': 'revoke',
  379. 'destination': hostname,
  380. 'arguments': {
  381. 'task_id': tid,
  382. 'task_name': self.mytask.name,
  383. },
  384. }
  385. self.panel.handle_message(m, None)
  386. assert tid in revoked
  387. def test_revoke_with_name_not_in_registry(self):
  388. tid = uuid()
  389. m = {
  390. 'method': 'revoke',
  391. 'destination': hostname,
  392. 'arguments': {
  393. 'task_id': tid,
  394. 'task_name': 'xxxxxxxxx33333333388888',
  395. },
  396. }
  397. self.panel.handle_message(m, None)
  398. assert tid in revoked
  399. def test_revoke(self):
  400. tid = uuid()
  401. m = {
  402. 'method': 'revoke',
  403. 'destination': hostname,
  404. 'arguments': {
  405. 'task_id': tid,
  406. },
  407. }
  408. self.panel.handle_message(m, None)
  409. assert tid in revoked
  410. m = {
  411. 'method': 'revoke',
  412. 'destination': 'does.not.exist',
  413. 'arguments': {
  414. 'task_id': tid + 'xxx',
  415. },
  416. }
  417. self.panel.handle_message(m, None)
  418. assert tid + 'xxx' not in revoked
  419. def test_revoke_terminate(self):
  420. request = Mock()
  421. request.id = tid = uuid()
  422. state = self.create_state()
  423. state.consumer = Mock()
  424. worker_state.task_reserved(request)
  425. try:
  426. r = control.revoke(state, tid, terminate=True)
  427. assert tid in revoked
  428. assert request.terminate.call_count
  429. assert 'terminate:' in r['ok']
  430. # unknown task id only revokes
  431. r = control.revoke(state, uuid(), terminate=True)
  432. assert 'tasks unknown' in r['ok']
  433. finally:
  434. worker_state.task_ready(request)
  435. def test_autoscale(self):
  436. self.panel.state.consumer = Mock()
  437. self.panel.state.consumer.controller = Mock()
  438. sc = self.panel.state.consumer.controller.autoscaler = Mock()
  439. sc.update.return_value = 10, 2
  440. m = {'method': 'autoscale',
  441. 'destination': hostname,
  442. 'arguments': {'max': '10', 'min': '2'}}
  443. r = self.panel.handle_message(m, None)
  444. assert 'ok' in r
  445. self.panel.state.consumer.controller.autoscaler = None
  446. r = self.panel.handle_message(m, None)
  447. assert 'error' in r
  448. def test_ping(self):
  449. m = {'method': 'ping',
  450. 'destination': hostname}
  451. r = self.panel.handle_message(m, None)
  452. assert r == {'ok': 'pong'}
  453. def test_shutdown(self):
  454. m = {'method': 'shutdown',
  455. 'destination': hostname}
  456. with pytest.raises(SystemExit):
  457. self.panel.handle_message(m, None)
  458. def test_panel_reply(self):
  459. replies = []
  460. class _Node(pidbox.Node):
  461. def reply(self, data, exchange, routing_key, **kwargs):
  462. replies.append(data)
  463. panel = _Node(
  464. hostname=hostname,
  465. state=self.create_state(consumer=Consumer(self.app)),
  466. handlers=control.Panel.data,
  467. mailbox=self.app.control.mailbox,
  468. )
  469. r = panel.dispatch('ping', reply_to={
  470. 'exchange': 'x',
  471. 'routing_key': 'x',
  472. })
  473. assert r == {'ok': 'pong'}
  474. assert replies[0] == {panel.hostname: {'ok': 'pong'}}
  475. def test_pool_restart(self):
  476. consumer = Consumer(self.app)
  477. consumer.controller = _WC(app=self.app)
  478. consumer.controller.consumer = consumer
  479. consumer.controller.pool.restart = Mock()
  480. consumer.reset_rate_limits = Mock(name='reset_rate_limits()')
  481. consumer.update_strategies = Mock(name='update_strategies()')
  482. consumer.event_dispatcher = Mock(name='evd')
  483. panel = self.create_panel(consumer=consumer)
  484. assert panel.state.consumer.controller.consumer is consumer
  485. panel.app = self.app
  486. _import = panel.app.loader.import_from_cwd = Mock()
  487. _reload = Mock()
  488. with pytest.raises(ValueError):
  489. panel.handle('pool_restart', {'reloader': _reload})
  490. self.app.conf.worker_pool_restarts = True
  491. panel.handle('pool_restart', {'reloader': _reload})
  492. consumer.controller.pool.restart.assert_called()
  493. consumer.reset_rate_limits.assert_called_with()
  494. consumer.update_strategies.assert_called_with()
  495. _reload.assert_not_called()
  496. _import.assert_not_called()
  497. consumer.controller.pool.restart.side_effect = NotImplementedError()
  498. panel.handle('pool_restart', {'reloader': _reload})
  499. consumer.controller.consumer = None
  500. panel.handle('pool_restart', {'reloader': _reload})
  501. @patch('celery.worker.worker.logger.debug')
  502. def test_pool_restart_import_modules(self, _debug):
  503. consumer = Consumer(self.app)
  504. consumer.controller = _WC(app=self.app)
  505. consumer.controller.consumer = consumer
  506. consumer.controller.pool.restart = Mock()
  507. consumer.reset_rate_limits = Mock(name='reset_rate_limits()')
  508. consumer.update_strategies = Mock(name='update_strategies()')
  509. panel = self.create_panel(consumer=consumer)
  510. panel.app = self.app
  511. assert panel.state.consumer.controller.consumer is consumer
  512. _import = consumer.controller.app.loader.import_from_cwd = Mock()
  513. _reload = Mock()
  514. self.app.conf.worker_pool_restarts = True
  515. with patch('sys.modules'):
  516. panel.handle('pool_restart', {
  517. 'modules': ['foo', 'bar'],
  518. 'reloader': _reload,
  519. })
  520. consumer.controller.pool.restart.assert_called()
  521. consumer.reset_rate_limits.assert_called_with()
  522. consumer.update_strategies.assert_called_with()
  523. _reload.assert_not_called()
  524. _import.assert_has_calls([call('bar'), call('foo')], any_order=True)
  525. assert _import.call_count == 2
  526. def test_pool_restart_reload_modules(self):
  527. consumer = Consumer(self.app)
  528. consumer.controller = _WC(app=self.app)
  529. consumer.controller.consumer = consumer
  530. consumer.controller.pool.restart = Mock()
  531. consumer.reset_rate_limits = Mock(name='reset_rate_limits()')
  532. consumer.update_strategies = Mock(name='update_strategies()')
  533. panel = self.create_panel(consumer=consumer)
  534. panel.app = self.app
  535. _import = panel.app.loader.import_from_cwd = Mock()
  536. _reload = Mock()
  537. self.app.conf.worker_pool_restarts = True
  538. with patch.dict(sys.modules, {'foo': None}):
  539. panel.handle('pool_restart', {
  540. 'modules': ['foo'],
  541. 'reload': False,
  542. 'reloader': _reload,
  543. })
  544. consumer.controller.pool.restart.assert_called()
  545. _reload.assert_not_called()
  546. _import.assert_not_called()
  547. _import.reset_mock()
  548. _reload.reset_mock()
  549. consumer.controller.pool.restart.reset_mock()
  550. panel.handle('pool_restart', {
  551. 'modules': ['foo'],
  552. 'reload': True,
  553. 'reloader': _reload,
  554. })
  555. consumer.controller.pool.restart.assert_called()
  556. _reload.assert_called()
  557. _import.assert_not_called()
  558. def test_query_task(self):
  559. consumer = Consumer(self.app)
  560. consumer.controller = _WC(app=self.app)
  561. consumer.controller.consumer = consumer
  562. panel = self.create_panel(consumer=consumer)
  563. panel.app = self.app
  564. req1 = Request(
  565. self.TaskMessage(self.mytask.name, args=(2, 2)),
  566. app=self.app,
  567. )
  568. worker_state.task_reserved(req1)
  569. try:
  570. assert not panel.handle('query_task', {'ids': {'1daa'}})
  571. ret = panel.handle('query_task', {'ids': {req1.id}})
  572. assert req1.id in ret
  573. assert ret[req1.id][0] == 'reserved'
  574. worker_state.active_requests.add(req1)
  575. try:
  576. ret = panel.handle('query_task', {'ids': {req1.id}})
  577. assert ret[req1.id][0] == 'active'
  578. finally:
  579. worker_state.active_requests.clear()
  580. ret = panel.handle('query_task', {'ids': {req1.id}})
  581. assert ret[req1.id][0] == 'reserved'
  582. finally:
  583. worker_state.reserved_requests.clear()