case.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. from __future__ import absolute_import
  2. import atexit
  3. import logging
  4. import os
  5. import signal
  6. import socket
  7. import sys
  8. import traceback
  9. from itertools import count
  10. from time import time
  11. from celery.exceptions import TimeoutError
  12. from celery.task.control import ping, flatten_reply, inspect
  13. from celery.utils import get_full_cls_name
  14. from celery.tests.utils import unittest
# Host name of the machine running the tests, resolved once at import time.
# WorkerCase uses it as the default ``hostname`` for its managed worker.
HOSTNAME = socket.gethostname()
  16. def say(msg):
  17. sys.stderr.write("%s\n" % msg)
  18. def try_while(fun, reason="Timed out", timeout=10, interval=0.5):
  19. time_start = time()
  20. for iterations in count(0):
  21. if time() - time_start >= timeout:
  22. raise TimeoutError()
  23. ret = fun()
  24. if ret:
  25. return ret
  26. class Worker(object):
  27. started = False
  28. next_worker_id = count(1).next
  29. _shutdown_called = False
  30. def __init__(self, hostname, loglevel="error"):
  31. self.hostname = hostname
  32. self.loglevel = loglevel
  33. def start(self):
  34. if not self.started:
  35. self._fork_and_exec()
  36. self.started = True
  37. def _fork_and_exec(self):
  38. pid = os.fork()
  39. if pid == 0:
  40. from celery import current_app
  41. current_app.worker_main(["celeryd", "--loglevel=DEBUG",
  42. "-n", self.hostname])
  43. os._exit(0)
  44. self.pid = pid
  45. def is_alive(self, timeout=1):
  46. r = ping(destination=[self.hostname],
  47. timeout=timeout)
  48. return self.hostname in flatten_reply(r)
  49. def wait_until_started(self, timeout=10, interval=0.5):
  50. try_while(lambda: self.is_alive(interval),
  51. "Worker won't start (after %s secs.)" % timeout,
  52. interval=interval, timeout=timeout)
  53. say("--WORKER %s IS ONLINE--" % self.hostname)
  54. def ensure_shutdown(self, timeout=10, interval=0.5):
  55. os.kill(self.pid, signal.SIGTERM)
  56. try_while(lambda: not self.is_alive(interval),
  57. "Worker won't shutdown (after %s secs.)" % timeout,
  58. timeout=10, interval=0.5)
  59. say("--WORKER %s IS SHUTDOWN--" % self.hostname)
  60. self._shutdown_called = True
  61. def ensure_started(self):
  62. self.start()
  63. self.wait_until_started()
  64. @classmethod
  65. def managed(cls, hostname=None, caller=None):
  66. hostname = hostname or socket.gethostname()
  67. if caller:
  68. hostname = ".".join([get_full_cls_name(caller), hostname])
  69. else:
  70. hostname += str(cls.next_worker_id())
  71. worker = cls(hostname)
  72. worker.ensure_started()
  73. stack = traceback.format_stack()
  74. @atexit.register
  75. def _ensure_shutdown_once():
  76. if not worker._shutdown_called:
  77. say("-- Found worker not stopped at shutdown: %s\n%s" % (
  78. worker.hostname,
  79. "\n".join(stack)))
  80. worker.ensure_shutdown()
  81. return worker
  82. class WorkerCase(unittest.TestCase):
  83. hostname = HOSTNAME
  84. worker = None
  85. @classmethod
  86. def setUpClass(cls):
  87. logging.getLogger("amqplib").setLevel(logging.ERROR)
  88. cls.worker = Worker.managed(cls.hostname, caller=cls)
  89. @classmethod
  90. def tearDownClass(cls):
  91. cls.worker.ensure_shutdown()
  92. def assertWorkerAlive(self, timeout=1):
  93. self.assertTrue(self.worker.is_alive)
  94. def inspect(self, timeout=1):
  95. return inspect([self.worker.hostname], timeout=timeout)
  96. def my_response(self, response):
  97. return flatten_reply(response)[self.worker.hostname]
  98. def is_accepted(self, task_id, interval=0.5):
  99. active = self.inspect(timeout=interval).active()
  100. if active:
  101. for task in active[self.worker.hostname]:
  102. if task["id"] == task_id:
  103. return True
  104. return False
  105. def is_reserved(self, task_id, interval=0.5):
  106. reserved = self.inspect(timeout=interval).reserved()
  107. if reserved:
  108. for task in reserved[self.worker.hostname]:
  109. if task["id"] == task_id:
  110. return True
  111. return False
  112. def is_scheduled(self, task_id, interval=0.5):
  113. schedule = self.inspect(timeout=interval).scheduled()
  114. if schedule:
  115. for item in schedule[self.worker.hostname]:
  116. if item["request"]["id"] == task_id:
  117. return True
  118. return False
  119. def is_received(self, task_id, interval=0.5):
  120. return (self.is_reserved(task_id, interval) or
  121. self.is_scheduled(task_id, interval) or
  122. self.is_accepted(task_id, interval))
  123. def ensure_accepted(self, task_id, interval=0.5, timeout=10):
  124. return try_while(lambda: self.is_accepted(task_id, interval),
  125. "Task not accepted within timeout",
  126. interval=0.5, timeout=10)
  127. def ensure_received(self, task_id, interval=0.5, timeout=10):
  128. return try_while(lambda: self.is_received(task_id, interval),
  129. "Task not receied within timeout",
  130. interval=0.5, timeout=10)
  131. def ensure_scheduled(self, task_id, interval=0.5, timeout=10):
  132. return try_while(lambda: self.is_scheduled(task_id, interval),
  133. "Task not scheduled within timeout",
  134. interval=0.5, timeout=10)