case.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import atexit
  2. import logging
  3. import os
  4. import signal
  5. import socket
  6. import sys
  7. import traceback
  8. import unittest2 as unittest
  9. from itertools import count
  10. from celery.exceptions import TimeoutError
  11. from celery.task.control import broadcast, ping, flatten_reply, inspect
  12. from celery.utils import get_full_cls_name
  # Name of the machine running the tests; used as the default base name
  # for the celeryd workers spawned below.
  HOSTNAME = socket.gethostname()


  def say(msg):
      """Write *msg* followed by a newline to standard error."""
      line = "%s\n" % msg
      sys.stderr.write(line)
  def try_while(fun, reason="Timed out", timeout=10, interval=0.5):
      """Call *fun* repeatedly until it returns a true value, and return it.

      The budget is counted in attempts, not wall-clock time. Attempt
      number ``i`` is made only while ``i * interval < timeout``. No sleep
      happens between attempts, so the budget only approximates *timeout*
      seconds when each call to *fun* itself blocks for roughly *interval*
      seconds. Callers in this file arrange that by passing *interval* on
      as the ping/inspect timeout.

      :param fun: zero-argument callable polled for a true result.
      :param reason: message attached to the raised TimeoutError.
      :param timeout: total budget, in the same units as *interval*.
      :param interval: budget consumed per attempt; must be positive.
      :returns: the first true value returned by *fun*.
      :raises ValueError: if *interval* is not positive (otherwise the
          loop could never run out of budget).
      :raises TimeoutError: if *fun* never returns a true value in budget.
      """
      if interval <= 0:
          raise ValueError("interval must be positive, got %r" % (interval,))
      for attempt in count(0):
          if attempt * interval >= timeout:
              # Carry the caller's reason so the failure says what timed out.
              raise TimeoutError(reason)
          ret = fun()
          if ret:
              return ret
  class Worker(object):
      """A celeryd worker process spawned for functional tests.

      The worker is started by forking and exec'ing ``celery.bin.celeryd``
      with the given node name. Its liveness is probed with ``ping``
      requests addressed to that node name.
      """

      # Set to True on the instance once start() has forked the process.
      started = False
      # Counter shared by all workers, used to make anonymous managed()
      # worker names unique.
      _worker_ids = count(1)
      # Set to True once ensure_shutdown() has confirmed the worker stopped;
      # checked by the atexit hook installed in managed().
      _shutdown_called = False

      def __init__(self, hostname, loglevel="error"):
          """Describe a worker; nothing runs until start() is called.

          :param hostname: node name passed to celeryd via ``-n``.
          :param loglevel: log level passed to celeryd via ``-l``.
          """
          self.hostname = hostname
          self.loglevel = loglevel

      @classmethod
      def next_worker_id(cls):
          """Return the next integer (1, 2, ...) from the shared counter.

          Uses the ``next()`` builtin, which works on Python 2.6+ and 3.
          ``count(1).next`` only exists on Python 2.
          """
          return next(cls._worker_ids)

      def start(self):
          """Fork and exec the worker process unless it was already started."""
          if not self.started:
              self._fork_and_exec()
              self.started = True

      def _fork_and_exec(self):
          """Fork. The child execs celeryd; the parent records the child pid."""
          pid = os.fork()
          if pid == 0:
              # Child process. It must never fall back into the parent's
              # code (the rest of the test run, the parent's atexit hooks)
              # if exec fails, so report the error and exit unconditionally.
              try:
                  os.execv(sys.executable,
                           [sys.executable] + ["-m", "celery.bin.celeryd",
                                               "-l", self.loglevel,
                                               "-n", self.hostname])
              except Exception:
                  traceback.print_exc()
                  sys.stderr.flush()
              finally:
                  # os._exit skips atexit handlers inherited from the parent.
                  # The previous os.exit() does not exist at all.
                  os._exit(1)
          self.pid = pid

      def is_alive(self, timeout=1):
          """Return True if this worker answers a ping within *timeout* secs."""
          r = ping(destination=[self.hostname],
                   timeout=timeout)
          return self.hostname in flatten_reply(r)

      def wait_until_started(self, timeout=10, interval=0.2):
          """Block until the worker answers pings.

          Each ping waits *interval* seconds, and polling stops after a
          budget of *timeout*. try_while raises TimeoutError if the worker
          never answers.
          """
          try_while(lambda: self.is_alive(interval),
                    "Worker won't start (after %s secs.)" % timeout,
                    interval=interval, timeout=timeout)
          say("--WORKER %s IS ONLINE--" % self.hostname)

      def ensure_shutdown(self, timeout=10, interval=0.5):
          """Send SIGTERM and block until the worker stops answering pings.

          Marks the worker as shut down only after that succeeds, so the
          atexit hook from managed() retries if this call raised.
          """
          os.kill(self.pid, signal.SIGTERM)
          try_while(lambda: not self.is_alive(interval),
                    "Worker won't shutdown (after %s secs.)" % timeout,
                    interval=interval, timeout=timeout)
          say("--WORKER %s IS SHUTDOWN--" % self.hostname)
          self._shutdown_called = True

      def ensure_started(self):
          """Start the worker if needed and wait until it answers pings."""
          self.start()
          self.wait_until_started()

      @classmethod
      def managed(cls, hostname=None, caller=None):
          """Start a worker and guarantee it is stopped at interpreter exit.

          :param hostname: base node name; defaults to this machine's name.
          :param caller: if given, its full class name is prefixed to the
              node name. Otherwise a unique counter value is appended.
          :returns: the started worker.
          """
          hostname = hostname or socket.gethostname()
          if caller:
              hostname = ".".join([get_full_cls_name(caller), hostname])
          else:
              hostname += str(cls.next_worker_id())
          worker = cls(hostname)
          worker.ensure_started()
          # Remember where the worker was created so a leaked worker can
          # be traced back to its creator at exit.
          stack = traceback.format_stack()

          @atexit.register
          def _ensure_shutdown_once():
              if not worker._shutdown_called:
                  say("-- Found worker not stopped at shutdown: %s\n%s" % (
                      worker.hostname,
                      "\n".join(stack)))
                  worker.ensure_shutdown()
          return worker
  class WorkerCase(unittest.TestCase):
      """TestCase base class that runs a dedicated celeryd worker per class.

      setUpClass starts one worker via Worker.managed(), named after the
      test class, and tearDownClass shuts it down. The helper methods
      query that worker through the remote-control ``inspect`` API.
      """

      # Base node name; Worker.managed() prefixes it with the class name.
      hostname = HOSTNAME
      # The Worker started by setUpClass (None until then).
      worker = None

      @classmethod
      def setUpClass(cls):
          # Silence amqplib's logging below ERROR during the tests.
          logging.getLogger("amqplib").setLevel(logging.ERROR)
          cls.worker = Worker.managed(cls.hostname, caller=cls)

      @classmethod
      def tearDownClass(cls):
          cls.worker.ensure_shutdown()

      def assertWorkerAlive(self, timeout=1):
          """Fail unless the worker answers a ping within *timeout* seconds."""
          # is_alive must be *called*: the bound method object is always
          # truthy, so asserting on it could never fail.
          self.assertTrue(self.worker.is_alive(timeout))

      def inspect(self, timeout=1):
          """Return an ``inspect`` proxy addressed only to this case's worker."""
          # The bare name resolves to the module-level celery ``inspect``,
          # not to this method.
          return inspect(self.worker.hostname, timeout=timeout)

      def my_response(self, response):
          """Extract this worker's entry from a flattened broadcast reply."""
          return flatten_reply(response)[self.worker.hostname]

      # NOTE(review): the is_* helpers assume that inspect replies for a
      # single-node destination are that node's list of task dicts (or a
      # false value when there is no reply). TODO confirm against the
      # celery version in use.

      def is_accepted(self, task_id, interval=0.5):
          """Return True if *task_id* is among the worker's active tasks."""
          active = self.inspect(timeout=interval).active()
          return any(task["id"] == task_id for task in active or ())

      def is_reserved(self, task_id, interval=0.5):
          """Return True if *task_id* is among the worker's reserved tasks."""
          reserved = self.inspect(timeout=interval).reserved()
          return any(task["id"] == task_id for task in reserved or ())

      def is_scheduled(self, task_id, interval=0.5):
          """Return True if *task_id* is in the worker's schedule."""
          schedule = self.inspect(timeout=interval).scheduled()
          return any(item["request"]["id"] == task_id
                     for item in schedule or ())

      def is_received(self, task_id, interval=0.5):
          """Return True if the task is reserved, scheduled or active."""
          return (self.is_reserved(task_id, interval) or
                  self.is_scheduled(task_id, interval) or
                  self.is_accepted(task_id, interval))

      def ensure_accepted(self, task_id, interval=0.5, timeout=10):
          """Poll until the task is active; raises TimeoutError otherwise."""
          return try_while(lambda: self.is_accepted(task_id, interval),
                           "Task not accepted within timeout",
                           interval=interval, timeout=timeout)

      def ensure_received(self, task_id, interval=0.5, timeout=10):
          """Poll until the task is received; raises TimeoutError otherwise."""
          return try_while(lambda: self.is_received(task_id, interval),
                           "Task not received within timeout",
                           interval=interval, timeout=timeout)

      def ensure_scheduled(self, task_id, interval=0.5, timeout=10):
          """Poll until the task is scheduled; raises TimeoutError otherwise."""
          return try_while(lambda: self.is_scheduled(task_id, interval),
                           "Task not scheduled within timeout",
                           interval=interval, timeout=timeout)