# case.py (5.3 KB, 170 lines)

from __future__ import absolute_import

import atexit
import logging
import os
import signal
import socket
import sys
import traceback

from itertools import count
from time import sleep, time

from celery.exceptions import TimeoutError
from celery.task.control import ping, flatten_reply, inspect
from celery.utils.imports import qualname

from celery.tests.utils import Case
  15. HOSTNAME = socket.gethostname()
  16. def say(msg):
  17. sys.stderr.write('%s\n' % msg)
  18. def try_while(fun, reason='Timed out', timeout=10, interval=0.5):
  19. time_start = time()
  20. for iterations in count(0):
  21. if time() - time_start >= timeout:
  22. raise TimeoutError()
  23. ret = fun()
  24. if ret:
  25. return ret
  26. class Worker(object):
  27. started = False
  28. next_worker_id = count(1).next
  29. _shutdown_called = False
  30. def __init__(self, hostname, loglevel='error'):
  31. self.hostname = hostname
  32. self.loglevel = loglevel
  33. def start(self):
  34. if not self.started:
  35. self._fork_and_exec()
  36. self.started = True
  37. def _fork_and_exec(self):
  38. pid = os.fork()
  39. if pid == 0:
  40. from celery import current_app
  41. current_app.worker_main(['celeryd', '--loglevel=INFO',
  42. '-n', self.hostname,
  43. '-P', 'solo'])
  44. os._exit(0)
  45. self.pid = pid
  46. def is_alive(self, timeout=1):
  47. r = ping(destination=[self.hostname],
  48. timeout=timeout)
  49. return self.hostname in flatten_reply(r)
  50. def wait_until_started(self, timeout=10, interval=0.5):
  51. try_while(lambda: self.is_alive(interval),
  52. "Worker won't start (after %s secs.)" % timeout,
  53. interval=interval, timeout=timeout)
  54. say('--WORKER %s IS ONLINE--' % self.hostname)
  55. def ensure_shutdown(self, timeout=10, interval=0.5):
  56. os.kill(self.pid, signal.SIGTERM)
  57. try_while(lambda: not self.is_alive(interval),
  58. "Worker won't shutdown (after %s secs.)" % timeout,
  59. timeout=10, interval=0.5)
  60. say('--WORKER %s IS SHUTDOWN--' % self.hostname)
  61. self._shutdown_called = True
  62. def ensure_started(self):
  63. self.start()
  64. self.wait_until_started()
  65. @classmethod
  66. def managed(cls, hostname=None, caller=None):
  67. hostname = hostname or socket.gethostname()
  68. if caller:
  69. hostname = '.'.join([qualname(caller), hostname])
  70. else:
  71. hostname += str(cls.next_worker_id())
  72. worker = cls(hostname)
  73. worker.ensure_started()
  74. stack = traceback.format_stack()
  75. @atexit.register
  76. def _ensure_shutdown_once():
  77. if not worker._shutdown_called:
  78. say('-- Found worker not stopped at shutdown: %s\n%s' % (
  79. worker.hostname,
  80. '\n'.join(stack)))
  81. worker.ensure_shutdown()
  82. return worker
  83. class WorkerCase(Case):
  84. hostname = HOSTNAME
  85. worker = None
  86. @classmethod
  87. def setUpClass(cls):
  88. logging.getLogger('amqplib').setLevel(logging.ERROR)
  89. cls.worker = Worker.managed(cls.hostname, caller=cls)
  90. @classmethod
  91. def tearDownClass(cls):
  92. cls.worker.ensure_shutdown()
  93. def assertWorkerAlive(self, timeout=1):
  94. self.assertTrue(self.worker.is_alive)
  95. def inspect(self, timeout=1):
  96. return inspect([self.worker.hostname], timeout=timeout)
  97. def my_response(self, response):
  98. return flatten_reply(response)[self.worker.hostname]
  99. def is_accepted(self, task_id, interval=0.5):
  100. active = self.inspect(timeout=interval).active()
  101. if active:
  102. for task in active[self.worker.hostname]:
  103. if task['id'] == task_id:
  104. return True
  105. return False
  106. def is_reserved(self, task_id, interval=0.5):
  107. reserved = self.inspect(timeout=interval).reserved()
  108. if reserved:
  109. for task in reserved[self.worker.hostname]:
  110. if task['id'] == task_id:
  111. return True
  112. return False
  113. def is_scheduled(self, task_id, interval=0.5):
  114. schedule = self.inspect(timeout=interval).scheduled()
  115. if schedule:
  116. for item in schedule[self.worker.hostname]:
  117. if item['request']['id'] == task_id:
  118. return True
  119. return False
  120. def is_received(self, task_id, interval=0.5):
  121. return (self.is_reserved(task_id, interval) or
  122. self.is_scheduled(task_id, interval) or
  123. self.is_accepted(task_id, interval))
  124. def ensure_accepted(self, task_id, interval=0.5, timeout=10):
  125. return try_while(lambda: self.is_accepted(task_id, interval),
  126. 'Task not accepted within timeout',
  127. interval=0.5, timeout=10)
  128. def ensure_received(self, task_id, interval=0.5, timeout=10):
  129. return try_while(lambda: self.is_received(task_id, interval),
  130. 'Task not receied within timeout',
  131. interval=0.5, timeout=10)
  132. def ensure_scheduled(self, task_id, interval=0.5, timeout=10):
  133. return try_while(lambda: self.is_scheduled(task_id, interval),
  134. 'Task not scheduled within timeout',
  135. interval=0.5, timeout=10)