|  | @@ -9,19 +9,25 @@ import unittest2 as unittest
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from itertools import count
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from celery.task.control import broadcast, ping
 | 
	
		
			
				|  |  | +from celery.exceptions import TimeoutError
 | 
	
		
			
				|  |  | +from celery.task.control import broadcast, ping, flatten_reply, inspect
 | 
	
		
			
				|  |  |  from celery.utils import get_full_cls_name
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  HOSTNAME = socket.gethostname()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  def say(msg):
 | 
	
		
			
				|  |  |      sys.stderr.write("%s\n" % msg)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -def flatten_response(response):
 | 
	
		
			
				|  |  | -    flat = {}
 | 
	
		
			
				|  |  | -    for item in response:
 | 
	
		
			
				|  |  | -        flat.update(item)
 | 
	
		
			
				|  |  | -    return flat
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def try_while(fun, reason="Timed out", timeout=10, interval=0.5):
 | 
	
		
			
				|  |  | +    for iterations in count(0):
 | 
	
		
			
				|  |  | +        if iterations * interval >= timeout:
 | 
	
		
			
				|  |  | +            raise TimeoutError()
 | 
	
		
			
				|  |  | +        ret = fun()
 | 
	
		
			
				|  |  | +        if ret:
 | 
	
		
			
				|  |  | +            return ret
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class Worker(object):
 | 
	
		
			
				|  |  |      started = False
 | 
	
	
		
			
				|  | @@ -50,26 +56,19 @@ class Worker(object):
 | 
	
		
			
				|  |  |      def is_alive(self, timeout=1):
 | 
	
		
			
				|  |  |          r = ping(destination=[self.hostname],
 | 
	
		
			
				|  |  |                   timeout=timeout)
 | 
	
		
			
				|  |  | -        return self.hostname in flatten_response(r)
 | 
	
		
			
				|  |  | +        return self.hostname in flatten_reply(r)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def wait_until_started(self, timeout=10, interval=0.2):
 | 
	
		
			
				|  |  | -        for iteration in count(0):
 | 
	
		
			
				|  |  | -            if iteration * interval >= timeout:
 | 
	
		
			
				|  |  | -                raise Exception(
 | 
	
		
			
				|  |  | -                        "Worker won't start (after %s secs.)" % timeout)
 | 
	
		
			
				|  |  | -            if self.is_alive(interval):
 | 
	
		
			
				|  |  | -                break
 | 
	
		
			
				|  |  | +        try_while(lambda: self.is_alive(interval),
 | 
	
		
			
				|  |  | +                "Worker won't start (after %s secs.)" % timeout,
 | 
	
		
			
				|  |  | +                interval=0.2, timeout=10)
 | 
	
		
			
				|  |  |          say("--WORKER %s IS ONLINE--" % self.hostname)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def ensure_shutdown(self, timeout=10, interval=0.5):
 | 
	
		
			
				|  |  |          os.kill(self.pid, signal.SIGTERM)
 | 
	
		
			
				|  |  | -        for iteration in count(0):
 | 
	
		
			
				|  |  | -            if iteration * interval >= timeout:
 | 
	
		
			
				|  |  | -                raise Exception(
 | 
	
		
			
				|  |  | -                        "Worker won't shutdown (after %s secs.)" % timeout)
 | 
	
		
			
				|  |  | -            broadcast("shutdown", destination=[self.hostname])
 | 
	
		
			
				|  |  | -            if not self.is_alive(interval):
 | 
	
		
			
				|  |  | -                break
 | 
	
		
			
				|  |  | +        try_while(lambda: not self.is_alive(interval),
 | 
	
		
			
				|  |  | +                  "Worker won't shutdown (after %s secs.)" % timeout,
 | 
	
		
			
				|  |  | +                  timeout=10, interval=0.5)
 | 
	
		
			
				|  |  |          say("--WORKER %s IS SHUTDOWN--" % self.hostname)
 | 
	
		
			
				|  |  |          self._shutdown_called = True
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -115,6 +114,53 @@ class WorkerCase(unittest.TestCase):
 | 
	
		
			
				|  |  |      def assertWorkerAlive(self, timeout=1):
 | 
	
		
			
				|  |  |          self.assertTrue(self.worker.is_alive)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def my_response(self, response):
 | 
	
		
			
				|  |  | -        return flatten_response(response)[self.worker.hostname]
 | 
	
		
			
				|  |  | +    def inspect(self, timeout=1):
 | 
	
		
			
				|  |  | +        return inspect(self.worker.hostname, timeout=timeout)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def my_response(self, response):
 | 
	
		
			
				|  |  | +        return flatten_reply(response)[self.worker.hostname]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def is_accepted(self, task_id, interval=0.5):
 | 
	
		
			
				|  |  | +        active = self.inspect(timeout=interval).active()
 | 
	
		
			
				|  |  | +        if active:
 | 
	
		
			
				|  |  | +            for task in active:
 | 
	
		
			
				|  |  | +                if task["id"] == task_id:
 | 
	
		
			
				|  |  | +                    return True
 | 
	
		
			
				|  |  | +        return False
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def is_reserved(self, task_id, interval=0.5):
 | 
	
		
			
				|  |  | +        reserved = self.inspect(timeout=interval).reserved()
 | 
	
		
			
				|  |  | +        if reserved:
 | 
	
		
			
				|  |  | +            for task in reserved:
 | 
	
		
			
				|  |  | +                if task["id"] == task_id:
 | 
	
		
			
				|  |  | +                    return True
 | 
	
		
			
				|  |  | +        return False
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def is_scheduled(self, task_id, interval=0.5):
 | 
	
		
			
				|  |  | +        schedule = self.inspect(timeout=interval).scheduled()
 | 
	
		
			
				|  |  | +        if schedule:
 | 
	
		
			
				|  |  | +            for item in schedule:
 | 
	
		
			
				|  |  | +                if item["request"]["id"] == task_id:
 | 
	
		
			
				|  |  | +                    return True
 | 
	
		
			
				|  |  | +        return False
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def is_received(self, task_id, interval=0.5):
 | 
	
		
			
				|  |  | +        return (self.is_reserved(task_id, interval) or
 | 
	
		
			
				|  |  | +                self.is_scheduled(task_id, interval) or
 | 
	
		
			
				|  |  | +                self.is_accepted(task_id, interval))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def ensure_accepted(self, task_id, interval=0.5, timeout=10):
 | 
	
		
			
				|  |  | +        return try_while(lambda: self.is_accepted(task_id, interval),
 | 
	
		
			
				|  |  | +                         "Task not accepted within timeout",
 | 
	
		
			
				|  |  | +                         interval=0.5, timeout=10)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def ensure_received(self, task_id, interval=0.5, timeout=10):
 | 
	
		
			
				|  |  | +        return try_while(lambda: self.is_received(task_id, interval),
 | 
	
		
			
				|  |  | +                        "Task not receied within timeout",
 | 
	
		
			
				|  |  | +                        interval=0.5, timeout=10)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def ensure_scheduled(self, task_id, interval=0.5, timeout=10):
 | 
	
		
			
				|  |  | +        return try_while(lambda: self.is_scheduled(task_id, interval),
 | 
	
		
			
				|  |  | +                        "Task not scheduled within timeout",
 | 
	
		
			
				|  |  | +                        interval=0.5, timeout=10)
 |