Browse Source

Merge branch 'autopool'

Ask Solem 14 years ago
parent
commit
15efd49613

+ 23 - 9
celery/app/amqp.py

@@ -183,6 +183,7 @@ class TaskPublisher(messaging.Publisher):
             event_dispatcher=None, retry=None, retry_policy=None,
             queue=None, now=None, retries=0, **kwargs):
         """Send task message."""
+
         connection = self.connection
         _retry_policy = self.retry_policy
         if retry_policy:  # merge default and custom policy
@@ -248,14 +249,15 @@ class TaskPublisher(messaging.Publisher):


 class PublisherPool(Resource):

-    def __init__(self, limit=None, app=None):
+    def __init__(self, app=None):
         self.app = app
-        self.connections = self.app.broker_connection().Pool(limit=limit)
-        super(PublisherPool, self).__init__(limit=limit)
+        super(PublisherPool, self).__init__(limit=self.app.pool.limit)

     def create_publisher(self):
-        return self.app.amqp.TaskPublisher(self.connections.acquire(),
-                                           auto_declare=False)
+        conn = self.app.pool.acquire(block=True)
+        pub = self.app.amqp.TaskPublisher(conn, auto_declare=False)
+        conn._publisher_chan = pub.channel
+        return pub

     def new(self):
         return promise(self.create_publisher)
@@ -266,7 +268,18 @@ class PublisherPool(Resource):
                 self._resource.put_nowait(self.new())

     def prepare(self, publisher):
-        return maybe_promise(publisher)
+        pub = maybe_promise(publisher)
+        if not pub.connection:
+            pub.connection = self.app.pool.acquire(block=True)
+            if not getattr(pub.connection, "_publisher_chan", None):
+                pub.connection._publisher_chan = pub.connection.channel()
+            pub.revive(pub.connection._publisher_chan)
+        return pub
+
+    def release(self, resource):
+        resource.connection.release()
+        resource.connection = None
+        super(PublisherPool, self).release(resource)


 class AMQP(object):
@@ -327,9 +340,6 @@ class AMQP(object):
                     "app": self}
         return TaskPublisher(*args, **self.app.merge(defaults, kwargs))

-    def PublisherPool(self, limit=None):
-        return PublisherPool(limit=limit, app=self.app)
-
     def get_task_consumer(self, connection, queues=None, **kwargs):
         """Return consumer configured to consume from all known task
         queues."""
@@ -353,3 +363,7 @@ class AMQP(object):
         if self._rtable is None:
             self.flush_routes()
         return self._rtable
+
+    @cached_property
+    def publisher_pool(self):
+        return PublisherPool(app=self.app)

+ 21 - 2
celery/app/base.py

@@ -8,6 +8,7 @@ Application Base Class.
 :license: BSD, see LICENSE for more details.

 """
+import os
 import platform as _platform

 from copy import deepcopy
@@ -90,6 +91,8 @@ class BaseApp(object):
     log_cls = "celery.log.Logging"
     control_cls = "celery.task.control.Control"

+    _pool = None
+
     def __init__(self, main=None, loader=None, backend=None,
             amqp=None, events=None, log=None, control=None,
             set_as_current=True, accept_magic_kwargs=False):
@@ -243,8 +246,8 @@ class BaseApp(object):
             connection = kwargs.get("connection")
             timeout = kwargs.get("connect_timeout")
             kwargs["connection"] = conn = connection or \
-                    self.broker_connection(connect_timeout=timeout)
-            close_connection = not connection and conn.close or None
+                    self.pool.acquire(block=True)
+            close_connection = not connection and conn.release or None

             try:
                 return fun(*args, **kwargs)
@@ -300,6 +303,22 @@ class BaseApp(object):
         return ConfigurationView({},
                 [self.prepare_config(self.loader.conf), deepcopy(DEFAULTS)])

+    def _after_fork(self, obj_):
+        if self._pool:
+            self._pool.force_close_all()
+            self._pool = None
+
+    @property
+    def pool(self):
+        if self._pool is None:
+            try:
+                from multiprocessing.util import register_after_fork
+                register_after_fork(self, self._after_fork)
+            except ImportError:
+                pass
+            self._pool = self.broker_connection().Pool(self.conf.BROKER_POOL_LIMIT)
+        return self._pool
+
     @cached_property
     def amqp(self):
         """Sending/receiving messages.  See :class:`~celery.app.amqp.AMQP`."""

+ 4 - 3
celery/app/defaults.py

@@ -50,6 +50,7 @@ NAMESPACES = {
         "CONNECTION_TIMEOUT": Option(4, type="int"),
         "CONNECTION_RETRY": Option(True, type="bool"),
         "CONNECTION_MAX_RETRIES": Option(100, type="int"),
+        "POOL_LIMIT": Option(None, type="int"),
         "INSIST": Option(False, type="bool"),
         "USE_SSL": Option(False, type="bool"),
         "TRANSPORT_OPTIONS": Option({}, type="dict")
@@ -91,11 +92,11 @@ NAMESPACES = {
         "SEND_TASK_SENT_EVENT": Option(False, type="bool"),
         "STORE_ERRORS_EVEN_IF_IGNORED": Option(False, type="bool"),
         "TASK_ERROR_WHITELIST": Option((), type="tuple"),
-        "TASK_PUBLISH_RETRY": Option(False, type="bool"),
+        "TASK_PUBLISH_RETRY": Option(True, type="bool"),
         "TASK_PUBLISH_RETRY_POLICY": Option({
-                "max_retries": 3,
+                "max_retries": 100,
                 "interval_start": 0,
-                "interval_max": 0.2,
+                "interval_max": 1,
                 "interval_step": 0.2}, type="dict"),
         "TASK_RESULT_EXPIRES": Option(timedelta(days=1), type="int"),
         "TASK_SERIALIZER": Option("pickle"),

+ 4 - 23
celery/backends/amqp.py

@@ -35,9 +35,6 @@ class AMQPBackend(BaseDictBackend):

     BacklogLimitExceeded = BacklogLimitExceeded

-    _pool = None
-    _pool_owner_pid = None
-
     def __init__(self, connection=None, exchange=None, exchange_type=None,
             persistent=None, serializer=None, auto_delete=True,
             expires=None, connection_max=None, **kwargs):
@@ -109,7 +106,7 @@ class AMQPBackend(BaseDictBackend):
         """Send task return value and status."""
         self.mutex.acquire()
         try:
-            conn = self.pool.acquire(block=True)
+            conn = self.app.pool.acquire(block=True)
             try:

                 def errback(error, delay):
@@ -159,7 +156,7 @@ class AMQPBackend(BaseDictBackend):
             return self.wait_for(task_id, timeout, cache)

     def poll(self, task_id, backlog_limit=100):
-        conn = self.pool.acquire(block=True)
+        conn = self.app.pool.acquire(block=True)
         channel = conn.channel()
         try:
             binding = self._create_binding(task_id)(channel)
@@ -203,7 +200,7 @@ class AMQPBackend(BaseDictBackend):
         return results

     def consume(self, task_id, timeout=None):
-        conn = self.pool.acquire(block=True)
+        conn = self.app.pool.acquire(block=True)
         channel = conn.channel()
         try:
             binding = self._create_binding(task_id)
@@ -218,7 +215,7 @@ class AMQPBackend(BaseDictBackend):
             conn.release()

     def get_many(self, task_ids, timeout=None):
-        conn = self.pool.acquire(block=True)
+        conn = self.app.pool.acquire(block=True)
         channel = conn.channel()
         try:
             ids = set(task_ids)
@@ -273,19 +270,3 @@ class AMQPBackend(BaseDictBackend):
         """Get the result of a taskset."""
         raise NotImplementedError(
                 "restore_taskset is not supported by this backend.")
-
-    def _reset_after_fork(self, *args):
-        if self._pool:
-            self._pool.force_close_all()
-            self._pool = None
-
-    @property
-    def pool(self):
-        if self._pool is None:
-            self._pool = self.app.broker_connection().Pool(self.connection_max)
-            try:
-                from multiprocessing.util import register_after_fork
-                register_after_fork(self, self._reset_after_fork)
-            except ImportError:
-                pass
-        return self._pool

+ 2 - 5
celery/task/base.py

@@ -436,9 +436,7 @@ class BaseTask(object):
         exchange_type = options.get("exchange_type")
         expires = expires or self.expires

-        publish = publisher or self.get_publisher(connection,
-                                                  exchange=exchange,
-                                                  exchange_type=exchange_type)
+        publish = publisher or self.app.amqp.publisher_pool.acquire(block=True)
         evd = None
         if conf.CELERY_SEND_TASK_SENT_EVENT:
             evd = self.app.events.Dispatcher(channel=publish.channel,
@@ -453,8 +451,7 @@ class BaseTask(object):
                                          **options)
         finally:
             if not publisher:
-                publish.close()
-                publish.connection.close()
+                publish.release()

         return self.AsyncResult(task_id)
 
 

+ 2 - 1
celery/tests/test_backends/test_amqp.py

@@ -4,6 +4,7 @@ import sys
 from datetime import timedelta
 from Queue import Empty, Queue

+from celery import current_app
 from celery import states
 from celery.app import app_or_default
 from celery.backends.amqp import AMQPBackend
@@ -202,7 +203,7 @@ class test_AMQPBackend(unittest.TestCase):
                 pass

         b = self.create_backend()
-        conn = b.pool.acquire(block=False)
+        conn = current_app.pool.acquire(block=False)
         channel = conn.channel()
         try:
             binding = b._create_binding(gen_unique_id())