# test_worker.py (3.1 KB)
import unittest
from datetime import datetime, timedelta
from multiprocessing import get_logger
from Queue import Queue, Empty

from carrot.backends.base import BaseMessage
from carrot.connection import AMQPConnection

from celery import registry
from celery.messaging import TaskConsumer
from celery.utils import pickle, gen_unique_id
from celery.worker import AMQPListener, WorkController
from celery.worker.job import TaskWrapper
  12. def foo_task(x, y, z, **kwargs):
  13. return x * y * z
  14. registry.tasks.register(foo_task, name="c.u.foo")
  15. class MockBackend(object):
  16. _acked = False
  17. def ack(self, delivery_tag):
  18. self._acked = True
  19. def create_message(backend, **data):
  20. data["id"] = gen_unique_id()
  21. return BaseMessage(backend, body=pickle.dumps(dict(**data)),
  22. content_type="application/x-python-serialize",
  23. content_encoding="binary")
  24. class TestAMQPListener(unittest.TestCase):
  25. def setUp(self):
  26. self.bucket_queue = Queue()
  27. self.hold_queue = Queue()
  28. self.logger = get_logger()
  29. self.logger.setLevel(0)
  30. def test_connection(self):
  31. l = AMQPListener(self.bucket_queue, self.hold_queue, self.logger)
  32. c = l.reset_connection()
  33. self.assertTrue(isinstance(c, TaskConsumer))
  34. self.assertTrue(c is l.task_consumer)
  35. self.assertTrue(isinstance(l.amqp_connection, AMQPConnection))
  36. l.close_connection()
  37. self.assertTrue(l.amqp_connection is None)
  38. self.assertTrue(l.task_consumer is None)
  39. c = l.reset_connection()
  40. self.assertTrue(isinstance(c, TaskConsumer))
  41. self.assertTrue(c is l.task_consumer)
  42. self.assertTrue(isinstance(l.amqp_connection, AMQPConnection))
  43. l.stop()
  44. self.assertTrue(l.amqp_connection is None)
  45. self.assertTrue(l.task_consumer is None)
  46. def test_receieve_message(self):
  47. l = AMQPListener(self.bucket_queue, self.hold_queue, self.logger)
  48. backend = MockBackend()
  49. m = create_message(backend, task="c.u.foo", args=[2, 4, 8], kwargs={})
  50. l.receive_message(m.decode(), m)
  51. in_bucket = self.bucket_queue.get_nowait()
  52. self.assertTrue(isinstance(in_bucket, TaskWrapper))
  53. self.assertEquals(in_bucket.task_name, "c.u.foo")
  54. self.assertEquals(in_bucket.execute(), 2 * 4 * 8)
  55. self.assertRaises(Empty, self.hold_queue.get_nowait)
  56. def test_receieve_message_eta(self):
  57. l = AMQPListener(self.bucket_queue, self.hold_queue, self.logger)
  58. backend = MockBackend()
  59. m = create_message(backend, task="c.u.foo", args=[2, 4, 8], kwargs={},
  60. eta=datetime.now() + timedelta(days=1))
  61. l.receive_message(m.decode(), m)
  62. in_hold = self.hold_queue.get_nowait()
  63. self.assertEquals(len(in_hold), 2)
  64. task, eta = in_hold
  65. self.assertTrue(isinstance(task, TaskWrapper))
  66. self.assertTrue(isinstance(eta, datetime))
  67. self.assertEquals(task.task_name, "c.u.foo")
  68. self.assertEquals(task.execute(), 2 * 4 * 8)
  69. self.assertRaises(Empty, self.bucket_queue.get_nowait)