소스 검색

Fixes event dispatcher bug

Ask Solem 12 년 전
부모
커밋
db72777857
1개의 변경된 파일10개의 추가작업 그리고 12개의 파일을 삭제
  1. 10 12
      celery/events/__init__.py

+ 10 - 12
celery/events/__init__.py

@@ -69,6 +69,7 @@ class EventDispatcher(object):
     You need to :meth:`close` this after use.
 
     """
+    DISABLED_TRANSPORTS = set(['sql'])
 
     def __init__(self, connection=None, hostname=None, enabled=True,
                  channel=None, buffer_while_offline=True, app=None,
@@ -88,6 +89,11 @@ class EventDispatcher(object):
         self.enabled = enabled
         if not connection and channel:
             self.connection = channel.connection.client
+        self.enabled = enabled
+        conninfo = self.connection or self.app.connection()
+        self.exchange = get_exchange(conninfo)
+        if conninfo.transport.driver_type in self.DISABLED_TRANSPORTS:
+            self.enabled = False
         if self.enabled:
             self.enable()
 
@@ -97,15 +103,9 @@ class EventDispatcher(object):
     def __exit__(self, *exc_info):
         self.close()
 
-    def get_exchange(self):
-        if self.connection:
-            return get_exchange(self.connection)
-        else:
-            return get_exchange(self.channel.connection.client)
-
     def enable(self):
         self.producer = Producer(self.channel or self.connection,
-                                 exchange=self.get_exchange(),
+                                 exchange=self.exchange,
                                  serializer=self.serializer)
         self.enabled = True
         for callback in self.on_enabled:
@@ -122,7 +122,7 @@ class EventDispatcher(object):
         with self.mutex:
             event = Event(type, hostname=self.hostname,
                           clock=self.app.clock.forward(), **fields)
-            exchange = get_exchange(producer.connection)
+            exchange = self.exchange
             producer.publish(
                 event,
                 routing_key=type.replace('-', '.'),
@@ -194,15 +194,13 @@ class EventReceiver(object):
         self.routing_key = routing_key
         self.node_id = node_id or uuid()
         self.queue_prefix = queue_prefix
+        self.exchange = get_exchange(self.connection or self.app.connection())
         self.queue = Queue('.'.join([self.queue_prefix, self.node_id]),
-                           exchange=self.get_exchange(),
+                           exchange=self.exchange,
                            routing_key=self.routing_key,
                            auto_delete=True,
                            durable=False)
 
-    def get_exchange(self):
-        return get_exchange(self.connection)
-
     def process(self, type, event):
         """Process the received event by dispatching it to the appropriate
         handler."""