فهرست منبع

Needs to copy buffer into bytes early so that librabbitmq does not release the buffer

Ask Solem 11 سال پیش
والد
کامیت
acee6680ce
6فایلهای تغییر یافته به همراه22 افزوده شده و 15 حذف شده
  1. 1 0
      celery/concurrency/base.py
  2. 1 0
      celery/concurrency/solo.py
  3. 8 0
      celery/five.py
  4. 1 9
      celery/worker/consumer.py
  5. 3 4
      celery/worker/request.py
  6. 8 2
      celery/worker/strategy.py

+ 1 - 0
celery/concurrency/base.py

@@ -72,6 +72,7 @@ class BasePool(object):
     uses_semaphore = False
 
     task_join_will_block = True
+    body_can_be_buffer = False
 
     def __init__(self, limit=None, putlocks=True,
                  forking_enable=True, callbacks_propagate=(), **options):

+ 1 - 0
celery/concurrency/solo.py

@@ -17,6 +17,7 @@ __all__ = ['TaskPool']
 
 class TaskPool(BasePool):
     """Solo task pool (blocking, inline, fast)."""
+    body_can_be_buffer = True
 
     def __init__(self, *args, **kwargs):
         super(TaskPool, self).__init__(*args, **kwargs)

+ 8 - 0
celery/five.py

@@ -28,6 +28,14 @@ except ImportError:  # pragma: no cover
     def Counter():  # noqa
         return defaultdict(int)
 
+try:
+    buffer_t = buffer
+except NameError:  # pragma: no cover
+    # Py3 does not have buffer, but we only need isinstance.
+
+    class buffer_t(object):  # noqa
+        pass
+
 ############## py3k #########################################################
 import sys
 PY3 = sys.version_info[0] == 3

+ 1 - 9
celery/worker/consumer.py

@@ -35,7 +35,7 @@ from celery import bootsteps
 from celery.app.trace import build_tracer
 from celery.canvas import signature
 from celery.exceptions import InvalidTaskError
-from celery.five import items, values
+from celery.five import buffer_t, items, values
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.utils.text import truncate
@@ -44,14 +44,6 @@ from celery.utils.timeutils import humanize_seconds, rate
 from . import heartbeat, loops, pidbox
 from .state import task_reserved, maybe_shutdown, revoked, reserved_requests
 
-try:
-    buffer_t = buffer
-except NameError:  # pragma: no cover
-    # Py3 does not have buffer, but we only need isinstance.
-
-    class buffer_t(object):  # noqa
-        pass
-
 __all__ = [
     'Consumer', 'Connection', 'Events', 'Heart', 'Control',
     'Tasks', 'Evloop', 'Agent', 'Mingle', 'Gossip', 'dump_body',

+ 3 - 4
celery/worker/request.py

@@ -108,13 +108,13 @@ class Request(object):
     def __init__(self, message, on_ack=noop,
                  hostname=None, eventer=None, app=None,
                  connection_errors=None, request_dict=None,
-                 task=None, on_reject=noop, **opts):
+                 task=None, on_reject=noop, body=None, **opts):
         headers = message.headers
         self.app = app
         self.message = message
         name = self.name = headers['c_type']
         self.id = headers['id']
-        self.body = message.body
+        self.body = message.body if body is None else body
         self.content_type = message.content_type
         self.content_encoding = message.content_encoding
         eta = headers.get('eta')
@@ -192,8 +192,7 @@ class Request(object):
         soft_timeout = soft_timeout or task.soft_time_limit
         result = pool.apply_async(
             trace_task_ret,
-            args=(self.name, task_id, self.request_dict,
-                  bytes(body) if isinstance(body, buffer) else body,
+            args=(self.name, task_id, self.request_dict, self.body,
                   self.content_type, self.content_encoding),
             kwargs={'hostname': self.hostname, 'is_eager': False},
             accept_callback=self.on_accepted,

+ 8 - 2
celery/worker/strategy.py

@@ -12,6 +12,7 @@ import logging
 
 from kombu.async.timer import to_timestamp
 
+from celery.five import buffer_t
 from celery.utils.log import get_logger
 from celery.utils.timeutils import timezone
 
@@ -25,7 +26,7 @@ logger = get_logger(__name__)
 
 def default(task, app, consumer,
             info=logger.info, error=logger.error, task_reserved=task_reserved,
-            to_system_tz=timezone.to_system):
+            to_system_tz=timezone.to_system, bytes=bytes, buffer_t=buffer_t):
     hostname = consumer.hostname
     eventer = consumer.event_dispatcher
     ReqV2 = Request
@@ -40,14 +41,19 @@ def default(task, app, consumer,
     bucket = consumer.task_buckets[task.name]
     handle = consumer.on_task_request
     limit_task = consumer._limit_task
+    body_can_be_buffer = consumer.pool.body_can_be_buffer
 
     def task_message_handler(message, body, ack, reject, callbacks,
                              to_timestamp=to_timestamp):
         if body is None:
+            body = message.body
+            if not body_can_be_buffer:
+                body = bytes(body) if isinstance(body, buffer_t) else body
             req = ReqV2(message,
                         on_ack=ack, on_reject=reject, app=app,
                         hostname=hostname, eventer=eventer, task=task,
-                        connection_errors=connection_errors)
+                        connection_errors=connection_errors,
+                        body=body)
         else:
             req = ReqV1(body,
                         on_ack=ack, on_reject=reject, app=app,