case.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import atexit
  2. import logging
  3. import os
  4. import signal
  5. import socket
  6. import sys
  7. import traceback
  8. from itertools import count
  9. from time import time
  10. from celery.exceptions import TimeoutError
  11. from celery.task.control import ping, flatten_reply, inspect
  12. from celery.utils import get_full_cls_name
  13. from celery.tests.utils import unittest
  14. HOSTNAME = socket.gethostname()
  15. def say(msg):
  16. sys.stderr.write("%s\n" % msg)
  17. def try_while(fun, reason="Timed out", timeout=10, interval=0.5):
  18. time_start = time()
  19. for iterations in count(0):
  20. if time() - time_start >= timeout:
  21. raise TimeoutError()
  22. ret = fun()
  23. if ret:
  24. return ret
  25. class Worker(object):
  26. started = False
  27. next_worker_id = count(1).next
  28. _shutdown_called = False
  29. def __init__(self, hostname, loglevel="error"):
  30. self.hostname = hostname
  31. self.loglevel = loglevel
  32. def start(self):
  33. if not self.started:
  34. self._fork_and_exec()
  35. self.started = True
  36. def _fork_and_exec(self):
  37. pid = os.fork()
  38. if pid == 0:
  39. from celery import current_app
  40. current_app.worker_main(["celeryd", "--loglevel=DEBUG",
  41. "-n", self.hostname])
  42. os._exit(0)
  43. self.pid = pid
  44. def is_alive(self, timeout=1):
  45. r = ping(destination=[self.hostname],
  46. timeout=timeout)
  47. return self.hostname in flatten_reply(r)
  48. def wait_until_started(self, timeout=10, interval=0.5):
  49. try_while(lambda: self.is_alive(interval),
  50. "Worker won't start (after %s secs.)" % timeout,
  51. interval=interval, timeout=timeout)
  52. say("--WORKER %s IS ONLINE--" % self.hostname)
  53. def ensure_shutdown(self, timeout=10, interval=0.5):
  54. os.kill(self.pid, signal.SIGTERM)
  55. try_while(lambda: not self.is_alive(interval),
  56. "Worker won't shutdown (after %s secs.)" % timeout,
  57. timeout=10, interval=0.5)
  58. say("--WORKER %s IS SHUTDOWN--" % self.hostname)
  59. self._shutdown_called = True
  60. def ensure_started(self):
  61. self.start()
  62. self.wait_until_started()
  63. @classmethod
  64. def managed(cls, hostname=None, caller=None):
  65. hostname = hostname or socket.gethostname()
  66. if caller:
  67. hostname = ".".join([get_full_cls_name(caller), hostname])
  68. else:
  69. hostname += str(cls.next_worker_id())
  70. worker = cls(hostname)
  71. worker.ensure_started()
  72. stack = traceback.format_stack()
  73. @atexit.register
  74. def _ensure_shutdown_once():
  75. if not worker._shutdown_called:
  76. say("-- Found worker not stopped at shutdown: %s\n%s" % (
  77. worker.hostname,
  78. "\n".join(stack)))
  79. worker.ensure_shutdown()
  80. return worker
  81. class WorkerCase(unittest.TestCase):
  82. hostname = HOSTNAME
  83. worker = None
  84. @classmethod
  85. def setUpClass(cls):
  86. logging.getLogger("amqplib").setLevel(logging.ERROR)
  87. cls.worker = Worker.managed(cls.hostname, caller=cls)
  88. @classmethod
  89. def tearDownClass(cls):
  90. cls.worker.ensure_shutdown()
  91. def assertWorkerAlive(self, timeout=1):
  92. self.assertTrue(self.worker.is_alive)
  93. def inspect(self, timeout=1):
  94. return inspect([self.worker.hostname], timeout=timeout)
  95. def my_response(self, response):
  96. return flatten_reply(response)[self.worker.hostname]
  97. def is_accepted(self, task_id, interval=0.5):
  98. active = self.inspect(timeout=interval).active()
  99. if active:
  100. for task in active[self.worker.hostname]:
  101. if task["id"] == task_id:
  102. return True
  103. return False
  104. def is_reserved(self, task_id, interval=0.5):
  105. reserved = self.inspect(timeout=interval).reserved()
  106. if reserved:
  107. for task in reserved[self.worker.hostname]:
  108. if task["id"] == task_id:
  109. return True
  110. return False
  111. def is_scheduled(self, task_id, interval=0.5):
  112. schedule = self.inspect(timeout=interval).scheduled()
  113. if schedule:
  114. for item in schedule[self.worker.hostname]:
  115. if item["request"]["id"] == task_id:
  116. return True
  117. return False
  118. def is_received(self, task_id, interval=0.5):
  119. return (self.is_reserved(task_id, interval) or
  120. self.is_scheduled(task_id, interval) or
  121. self.is_accepted(task_id, interval))
  122. def ensure_accepted(self, task_id, interval=0.5, timeout=10):
  123. return try_while(lambda: self.is_accepted(task_id, interval),
  124. "Task not accepted within timeout",
  125. interval=0.5, timeout=10)
  126. def ensure_received(self, task_id, interval=0.5, timeout=10):
  127. return try_while(lambda: self.is_received(task_id, interval),
  128. "Task not receied within timeout",
  129. interval=0.5, timeout=10)
  130. def ensure_scheduled(self, task_id, interval=0.5, timeout=10):
  131. return try_while(lambda: self.is_scheduled(task_id, interval),
  132. "Task not scheduled within timeout",
  133. interval=0.5, timeout=10)