|  | @@ -0,0 +1,120 @@
 | 
	
		
			
				|  |  | +import atexit
 | 
	
		
			
				|  |  | +import logging
 | 
	
		
			
				|  |  | +import os
 | 
	
		
			
				|  |  | +import signal
 | 
	
		
			
				|  |  | +import socket
 | 
	
		
			
				|  |  | +import sys
 | 
	
		
			
				|  |  | +import traceback
 | 
	
		
			
				|  |  | +import unittest2 as unittest
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +from itertools import count
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +from celery.task.control import broadcast, ping
 | 
	
		
			
				|  |  | +from celery.utils import get_full_cls_name
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +HOSTNAME = socket.gethostname()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def say(msg):
 | 
	
		
			
				|  |  | +    sys.stderr.write("%s\n" % msg)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def flatten_response(response):
 | 
	
		
			
				|  |  | +    flat = {}
 | 
	
		
			
				|  |  | +    for item in response:
 | 
	
		
			
				|  |  | +        flat.update(item)
 | 
	
		
			
				|  |  | +    return flat
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class Worker(object):
 | 
	
		
			
				|  |  | +    started = False
 | 
	
		
			
				|  |  | +    next_worker_id = count(1).next
 | 
	
		
			
				|  |  | +    _shutdown_called = False
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __init__(self, hostname, loglevel="error"):
 | 
	
		
			
				|  |  | +        self.hostname = hostname
 | 
	
		
			
				|  |  | +        self.loglevel = loglevel
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def start(self):
 | 
	
		
			
				|  |  | +        if not self.started:
 | 
	
		
			
				|  |  | +            self._fork_and_exec()
 | 
	
		
			
				|  |  | +            self.started = True
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def _fork_and_exec(self):
 | 
	
		
			
				|  |  | +        pid = os.fork()
 | 
	
		
			
				|  |  | +        if pid == 0:
 | 
	
		
			
				|  |  | +            os.execv(sys.executable,
 | 
	
		
			
				|  |  | +                    [sys.executable] + ["-m", "celery.bin.celeryd",
 | 
	
		
			
				|  |  | +                                        "-l", self.loglevel,
 | 
	
		
			
				|  |  | +                                        "-n", self.hostname])
 | 
	
		
			
				|  |  | +            os.exit()
 | 
	
		
			
				|  |  | +        self.pid = pid
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def is_alive(self, timeout=1):
 | 
	
		
			
				|  |  | +        r = ping(destination=[self.hostname],
 | 
	
		
			
				|  |  | +                 timeout=timeout)
 | 
	
		
			
				|  |  | +        return self.hostname in flatten_response(r)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def wait_until_started(self, timeout=10, interval=0.2):
 | 
	
		
			
				|  |  | +        for iteration in count(0):
 | 
	
		
			
				|  |  | +            if iteration * interval >= timeout:
 | 
	
		
			
				|  |  | +                raise Exception(
 | 
	
		
			
				|  |  | +                        "Worker won't start (after %s secs.)" % timeout)
 | 
	
		
			
				|  |  | +            if self.is_alive(interval):
 | 
	
		
			
				|  |  | +                break
 | 
	
		
			
				|  |  | +        say("--WORKER %s IS ONLINE--" % self.hostname)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def ensure_shutdown(self, timeout=10, interval=0.5):
 | 
	
		
			
				|  |  | +        os.kill(self.pid, signal.SIGTERM)
 | 
	
		
			
				|  |  | +        for iteration in count(0):
 | 
	
		
			
				|  |  | +            if iteration * interval >= timeout:
 | 
	
		
			
				|  |  | +                raise Exception(
 | 
	
		
			
				|  |  | +                        "Worker won't shutdown (after %s secs.)" % timeout)
 | 
	
		
			
				|  |  | +            broadcast("shutdown", destination=[self.hostname])
 | 
	
		
			
				|  |  | +            if not self.is_alive(interval):
 | 
	
		
			
				|  |  | +                break
 | 
	
		
			
				|  |  | +        say("--WORKER %s IS SHUTDOWN--" % self.hostname)
 | 
	
		
			
				|  |  | +        self._shutdown_called = True
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def ensure_started(self):
 | 
	
		
			
				|  |  | +        self.start()
 | 
	
		
			
				|  |  | +        self.wait_until_started()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @classmethod
 | 
	
		
			
				|  |  | +    def managed(cls, hostname=None, caller=None):
 | 
	
		
			
				|  |  | +        hostname = hostname or socket.gethostname()
 | 
	
		
			
				|  |  | +        if caller:
 | 
	
		
			
				|  |  | +            hostname = ".".join([get_full_cls_name(caller), hostname])
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            hostname += str(cls.next_worker_id())
 | 
	
		
			
				|  |  | +        worker = cls(hostname)
 | 
	
		
			
				|  |  | +        worker.ensure_started()
 | 
	
		
			
				|  |  | +        stack = traceback.format_stack()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        @atexit.register
 | 
	
		
			
				|  |  | +        def _ensure_shutdown_once():
 | 
	
		
			
				|  |  | +            if not worker._shutdown_called:
 | 
	
		
			
				|  |  | +                say("-- Found worker not stopped at shutdown: %s\n%s" % (
 | 
	
		
			
				|  |  | +                        worker.hostname,
 | 
	
		
			
				|  |  | +                        "\n".join(stack)))
 | 
	
		
			
				|  |  | +                worker.ensure_shutdown()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        return worker
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class WorkerCase(unittest.TestCase):
 | 
	
		
			
				|  |  | +    hostname = HOSTNAME
 | 
	
		
			
				|  |  | +    worker = None
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @classmethod
 | 
	
		
			
				|  |  | +    def setUpClass(cls):
 | 
	
		
			
				|  |  | +        logging.getLogger("amqplib").setLevel(logging.ERROR)
 | 
	
		
			
				|  |  | +        cls.worker = Worker.managed(cls.hostname, caller=cls)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @classmethod
 | 
	
		
			
				|  |  | +    def tearDownClass(cls):
 | 
	
		
			
				|  |  | +        cls.worker.ensure_shutdown()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def assertWorkerAlive(self, timeout=1):
 | 
	
		
			
				|  |  | +        self.assertTrue(self.worker.is_alive)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def my_response(self, response):
 | 
	
		
			
				|  |  | +        return flatten_response(response)[self.worker.hostname]
 | 
	
		
			
				|  |  | +
 |