Browse Source

Tests passing

Ask Solem 12 years ago
parent
commit
9a6792d5b2

+ 9 - 2
celery/app/builtins.py

@@ -69,15 +69,22 @@ def add_unlock_chord_task(app):
     """
     from celery.canvas import subtask
     from celery.exceptions import ChordError
+    from celery.result import from_serializable
 
     @app.task(name='celery.chord_unlock', max_retries=None,
               default_retry_delay=1, ignore_result=True, _force_evaluate=True)
     def unlock_chord(group_id, callback, interval=None, propagate=True,
-                     max_retries=None, result=None):
+                     max_retries=None, result=None,
+                     Result=app.AsyncResult, GroupResult=app.GroupResult,
+                     from_serializable=from_serializable):
         if interval is None:
             interval = unlock_chord.default_retry_delay
-        deps = app.GroupResult(group_id, map(app.AsyncResult, result))
+        deps = GroupResult(
+            group_id,
+            [from_serializable(r, Result=Result) for r in result],
+        )
         j = deps.join_native if deps.supports_native_join else deps.join
+
         if deps.ready():
             callback = subtask(callback)
             try:

+ 20 - 14
celery/events/__init__.py

@@ -79,7 +79,7 @@ class EventDispatcher(object):
         self.hostname = hostname or socket.gethostname()
         self.buffer_while_offline = buffer_while_offline
         self.mutex = threading.Lock()
-        self.publisher = None
+        self.producer = None
         self._outbound_buffer = deque()
         self.serializer = serializer or self.app.conf.CELERY_EVENT_SERIALIZER
         self.on_enabled = set()
@@ -104,9 +104,9 @@ class EventDispatcher(object):
             return get_exchange(self.channel.connection.client)
 
     def enable(self):
-        self.publisher = Producer(self.channel or self.connection,
-                                  exchange=self.get_exchange(),
-                                  serializer=self.serializer)
+        self.producer = Producer(self.channel or self.connection,
+                                 exchange=self.get_exchange(),
+                                 serializer=self.serializer)
         self.enabled = True
         for callback in self.on_enabled:
             callback()
@@ -142,7 +142,7 @@ class EventDispatcher(object):
         """
         if self.enabled:
             try:
-                self._send(type, fields, self.producer)
+                self.publish(type, fields, self.producer)
             except Exception, exc:
                 if not self.buffer_while_offline:
                     raise
@@ -162,7 +162,14 @@ class EventDispatcher(object):
     def close(self):
         """Close the event dispatcher."""
         self.mutex.locked() and self.mutex.release()
-        self.publisher = None
+        self.producer = None
+
+    def _get_publisher(self):
+        return self.producer
+
+    def _set_publisher(self, producer):
+        self.producer = producer
+    publisher = property(_get_publisher, _set_publisher)  # XXX compat
 
 
 class EventReceiver(object):
@@ -208,23 +215,22 @@ class EventReceiver(object):
         consumer = Consumer(self.connection,
                             queues=[self.queue], no_ack=True)
         consumer.register_callback(self._receive)
-        with consumer:
-            if wakeup:
-                self.wakeup_workers(channel=consumer.channel)
-            yield consumer
-
-    def itercapture(self, limit=None, timeout=None, wakeup=True):
-        consumer = self.consumer(wakeup=wakeup)
         consumer.consume()
         try:
+            if wakeup:
+                self.wakeup_workers(channel=consumer.channel)
             yield consumer
-            self.drain_events(limit=limit, timeout=timeout)
         finally:
             try:
                 consumer.cancel()
             except self.connection.connection_errors:
                 pass
 
+    def itercapture(self, limit=None, timeout=None, wakeup=True):
+        with self.consumer(wakeup=wakeup) as consumer:
+            yield consumer
+            self.drain_events(limit=limit, timeout=timeout)
+
     def capture(self, limit=None, timeout=None, wakeup=True):
         """Open up a consumer capturing events.
 

+ 0 - 1
celery/events/dumper.py

@@ -17,7 +17,6 @@ from celery.app import app_or_default
 from celery.datastructures import LRUCache
 from celery.utils.timeutils import humanize_seconds
 
-
 TASK_NAMES = LRUCache(limit=0xFFF)
 
 HUMAN_TYPES = {'worker-offline': 'shutdown',

+ 12 - 12
celery/result.py

@@ -25,17 +25,6 @@ from .datastructures import DependencyGraph
 from .exceptions import IncompleteStream, TimeoutError
 
 
-def from_serializable(r):
-    # earlier backends may just pickle, so check if
-    # result is already prepared.
-    if not isinstance(r, ResultBase):
-        id, nodes = r
-        if nodes:
-            return GroupResult(id, [AsyncResult(id) for id, _ in nodes])
-        return AsyncResult(id)
-    return r
-
-
 class ResultBase(object):
     """Base class for all results"""
 
@@ -198,7 +187,7 @@ class AsyncResult(ResultBase):
 
     def __str__(self):
         """`str(self) -> self.id`"""
-        return self.id
+        return str(self.id)
 
     def __hash__(self):
         """`hash(self) -> hash(self.id)`"""
@@ -717,3 +706,14 @@ class EagerResult(AsyncResult):
     @property
     def supports_native_join(self):
         return False
+
+
+def from_serializable(r, Result=AsyncResult):
+    # earlier backends may just pickle, so check if
+    # result is already prepared.
+    if not isinstance(r, ResultBase):
+        id, nodes = r
+        if nodes:
+            return GroupResult(id, [Result(id) for id, _ in nodes])
+        return AsyncResult(id)
+    return r

+ 2 - 1
celery/tests/app/test_app.py

@@ -431,7 +431,7 @@ class test_App(Case):
         class Dispatcher(object):
             sent = []
 
-            def send(self, type, **fields):
+            def publish(self, type, fields, *args, **kwargs):
                 self.sent.append((type, fields))
 
         conn = self.app.connection()
@@ -447,6 +447,7 @@ class test_App(Case):
 
         prod = self.app.amqp.TaskProducer(
             conn, exchange=Exchange('foo_exchange'),
+            send_sent_event=True,
         )
 
         dispatcher = Dispatcher()

+ 16 - 4
celery/tests/backends/test_amqp.py

@@ -137,6 +137,8 @@ class test_AMQPBackend(AppCase):
         results = Queue()
 
         class Message(object):
+            acked = 0
+            requeued = 0
 
             def __init__(self, **merge):
                 self.payload = dict({'status': states.STARTED,
@@ -145,6 +147,12 @@ class test_AMQPBackend(AppCase):
                 self.content_type = 'application/x-python-serialize'
                 self.content_encoding = 'binary'
 
+            def ack(self, *args, **kwargs):
+                self.acked += 1
+
+            def requeue(self, *args, **kwargs):
+                self.requeued += 1
+
         class MockBinding(object):
 
             def __init__(self, *args, **kwargs):
@@ -172,9 +180,13 @@ class test_AMQPBackend(AppCase):
         backend._republish = Mock()
 
         # FFWD's to the latest state.
-        results.put(Message(status=states.RECEIVED, seq=1))
-        results.put(Message(status=states.STARTED, seq=2))
-        results.put(Message(status=states.FAILURE, seq=3))
+        state_messages = [
+            Message(status=states.RECEIVED, seq=1),
+            Message(status=states.STARTED, seq=2),
+            Message(status=states.FAILURE, seq=3),
+        ]
+        for state_message in state_messages:
+            results.put(state_message)
         r1 = backend.get_task_meta(uuid())
         self.assertDictContainsSubset({'status': states.FAILURE,
                                        'seq': 3}, r1,
@@ -186,7 +198,7 @@ class test_AMQPBackend(AppCase):
         backend.get_task_meta(tid)
         self.assertIn(tid, backend._cache, 'Caches last known state')
 
-        self.assertTrue(backend._republish.called)
+        self.assertTrue(state_messages[-1].requeued)
 
         # Returns cache if no new states.
         results.queue.clear()

+ 3 - 3
celery/tests/backends/test_cache.py

@@ -77,12 +77,12 @@ class test_CacheBackend(Case):
 
             tb.on_chord_apply(task.request.group, [])
 
-            self.assertFalse(deps.join.called)
+            self.assertFalse(deps.join_native.called)
             tb.on_chord_part_return(task)
-            self.assertFalse(deps.join.called)
+            self.assertFalse(deps.join_native.called)
 
             tb.on_chord_part_return(task)
-            deps.join.assert_called_with(propagate=False)
+            deps.join_native.assert_called_with(propagate=True)
             deps.delete.assert_called_with()
 
         finally:

+ 1 - 1
celery/tests/backends/test_redis.py

@@ -172,7 +172,7 @@ class test_RedisBackend(Case):
 
             b.client.incr.return_value = len(deps)
             b.on_chord_part_return(task)
-            deps.join.assert_called_with(propagate=False)
+            deps.join_native.assert_called_with(propagate=True)
             deps.delete.assert_called_with()
 
             self.assertTrue(b.client.expire.call_count)

+ 0 - 1
celery/tests/bin/test_celeryevdump.py

@@ -43,6 +43,5 @@ class test_Dumper(Case):
 
     @patch('celery.events.EventReceiver.capture')
     def test_evdump(self, capture):
-        evdump()
         capture.side_effect = KeyboardInterrupt()
         evdump()

+ 20 - 10
celery/tests/events/test_events.py

@@ -42,8 +42,10 @@ class test_EventDispatcher(AppCase):
 
     def test_send(self):
         producer = MockProducer()
-        eventer = self.app.events.Dispatcher(object(), enabled=False)
-        eventer.publisher = producer
+        producer.connection = self.app.connection()
+        eventer = self.app.events.Dispatcher(object(), enabled=False,
+                                             buffer_while_offline=False)
+        eventer.producer = producer
         eventer.enabled = True
         eventer.send('World War II', ended=True)
         self.assertTrue(producer.has_event('World War II'))
@@ -53,14 +55,14 @@ class test_EventDispatcher(AppCase):
 
         evs = ('Event 1', 'Event 2', 'Event 3')
         eventer.enabled = True
-        eventer.publisher.raise_on_publish = True
+        eventer.producer.raise_on_publish = True
         eventer.buffer_while_offline = False
         with self.assertRaises(KeyError):
             eventer.send('Event X')
         eventer.buffer_while_offline = True
         for ev in evs:
             eventer.send(ev)
-        eventer.publisher.raise_on_publish = False
+        eventer.producer.raise_on_publish = False
         eventer.flush()
         for ev in evs:
             self.assertTrue(producer.has_event(ev))
@@ -99,22 +101,30 @@ class test_EventDispatcher(AppCase):
                                                      enabled=True,
                                                      channel=channel)
             self.assertTrue(dispatcher.enabled)
-            self.assertTrue(dispatcher.publisher.channel)
-            self.assertEqual(dispatcher.publisher.serializer,
+            self.assertTrue(dispatcher.producer.channel)
+            self.assertEqual(dispatcher.producer.serializer,
                              self.app.conf.CELERY_EVENT_SERIALIZER)
 
-            created_channel = dispatcher.publisher.channel
+            created_channel = dispatcher.producer.channel
             dispatcher.disable()
-            dispatcher.disable()  # Disable with no active publisher
+            dispatcher.disable()  # Disable with no active producer
             dispatcher2.disable()
             self.assertFalse(dispatcher.enabled)
-            self.assertIsNone(dispatcher.publisher)
+            self.assertIsNone(dispatcher.producer)
             self.assertFalse(dispatcher2.channel.closed,
                              'does not close manually provided channel')
 
             dispatcher.enable()
             self.assertTrue(dispatcher.enabled)
-            self.assertTrue(dispatcher.publisher)
+            self.assertTrue(dispatcher.producer)
+
+            # XXX test compat attribute
+            self.assertIs(dispatcher.publisher, dispatcher.producer)
+            prev, dispatcher.publisher = dispatcher.producer, 42
+            try:
+                self.assertEqual(dispatcher.producer, 42)
+            finally:
+                dispatcher.producer = prev
         finally:
             channel.close()
             connection.close()

+ 4 - 2
celery/tests/tasks/test_chord.py

@@ -70,7 +70,8 @@ class test_unlock_chord_task(AppCase):
                 subtask, canvas.maybe_subtask = canvas.maybe_subtask, passthru
                 try:
                     unlock('group_id', callback_s,
-                           result=map(AsyncResult, [1, 2, 3]))
+                           result=map(AsyncResult, ['1', '2', '3']),
+                           GroupResult=AlwaysReady)
                 finally:
                     canvas.maybe_subtask = subtask
                 callback.apply_async.assert_called_with(([2, 4, 8, 6], ), {})
@@ -90,7 +91,8 @@ class test_unlock_chord_task(AppCase):
             try:
                 callback = Mock()
                 unlock('group_id', callback, interval=10, max_retries=30,
-                       result=map(AsyncResult, [1, 2, 3]))
+                       result=map(AsyncResult, [1, 2, 3]),
+                       GroupResult=NeverReady)
                 self.assertFalse(callback.delay.call_count)
                 # did retry
                 unlock.retry.assert_called_with(countdown=10, max_retries=30)

+ 2 - 1
celery/tests/tasks/test_result.py

@@ -334,7 +334,8 @@ class SimpleBackend(object):
             self.ids = ids
 
         def get_many(self, *args, **kwargs):
-            return ((id, {'result': i}) for i, id in enumerate(self.ids))
+            return ((id, {'result': i, 'status': states.SUCCESS})
+                    for i, id in enumerate(self.ids))
 
 
 class test_TaskSetResult(AppCase):

+ 8 - 19
celery/tests/tasks/test_tasks.py

@@ -383,25 +383,14 @@ class test_tasks(Case):
     def test_send_task_sent_event(self):
         T1 = self.createTask('c.unittest.t.t1')
         app = T1.app
-        conn = app.connection()
-        chan = conn.channel()
-        app.conf.CELERY_SEND_TASK_SENT_EVENT = True
-        dispatcher = [None]
-
-        class Prod(object):
-            channel = chan
-
-            def publish_task(self, *args, **kwargs):
-                dispatcher[0] = kwargs.get('event_dispatcher')
-
-        try:
-            T1.apply_async(producer=Prod())
-        finally:
-            app.conf.CELERY_SEND_TASK_SENT_EVENT = False
-            chan.close()
-            conn.close()
-
-        self.assertTrue(dispatcher[0])
+        with app.connection() as conn:
+            app.conf.CELERY_SEND_TASK_SENT_EVENT = True
+            del(app.amqp.__dict__['TaskProducer'])
+            try:
+                self.assertTrue(app.amqp.TaskProducer(conn).send_sent_event)
+            finally:
+                app.conf.CELERY_SEND_TASK_SENT_EVENT = False
+                del(app.amqp.__dict__['TaskProducer'])
 
     def test_get_publisher(self):
         connection = app_or_default().connection()

+ 4 - 2
celery/tests/worker/test_worker.py

@@ -1073,6 +1073,7 @@ class test_WorkController(AppCase):
         P = w.pool_cls.return_value = Mock()
         P.timers = {Mock(): 30}
         w.use_eventloop = True
+        w.consumer.restart_count = -1
         pool = Pool(w)
         pool.create(w)
         self.assertIsInstance(w.semaphore, BoundedSemaphore)
@@ -1106,6 +1107,7 @@ class test_WorkController(AppCase):
         cbs['on_timeout_set'](result, None, 10)
         cbs['on_timeout_set'](result, None, None)
 
-        P.did_start_ok.return_value = False
         with self.assertRaises(WorkerLostError):
-            pool.on_poll_init(P, hub)
+            P.did_start_ok.return_value = False
+            w.consumer.restart_count = 0
+            pool.on_poll_init(P, w, hub)