Преглед на файлове

Needs to copy buffer into bytes early so that librabbitmq does not release the buffer

Ask Solem преди 11 години
родител
ревизия
acee6680ce
променени са 6 файла, в които са добавени 22 реда и са изтрити 15 реда
  1. 1 0
      celery/concurrency/base.py
  2. 1 0
      celery/concurrency/solo.py
  3. 8 0
      celery/five.py
  4. 1 9
      celery/worker/consumer.py
  5. 3 4
      celery/worker/request.py
  6. 8 2
      celery/worker/strategy.py

+ 1 - 0
celery/concurrency/base.py

@@ -72,6 +72,7 @@ class BasePool(object):
     uses_semaphore = False
 
     task_join_will_block = True
+    body_can_be_buffer = False
 
     def __init__(self, limit=None, putlocks=True,
                  forking_enable=True, callbacks_propagate=(), **options):

+ 1 - 0
celery/concurrency/solo.py

@@ -17,6 +17,7 @@ __all__ = ['TaskPool']
 
 class TaskPool(BasePool):
     """Solo task pool (blocking, inline, fast)."""
+    body_can_be_buffer = True
 
     def __init__(self, *args, **kwargs):
         super(TaskPool, self).__init__(*args, **kwargs)

+ 8 - 0
celery/five.py

@@ -28,6 +28,14 @@ except ImportError:  # pragma: no cover
     def Counter():  # noqa
         return defaultdict(int)
 
+try:
+    buffer_t = buffer
+except NameError:  # pragma: no cover
+    # Py3 does not have buffer, but we only need isinstance.
+
+    class buffer_t(object):  # noqa
+        pass
+
 ############## py3k #########################################################
 import sys
 PY3 = sys.version_info[0] == 3

+ 1 - 9
celery/worker/consumer.py

@@ -35,7 +35,7 @@ from celery import bootsteps
 from celery.app.trace import build_tracer
 from celery.canvas import signature
 from celery.exceptions import InvalidTaskError
-from celery.five import items, values
+from celery.five import buffer_t, items, values
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.utils.text import truncate
@@ -44,14 +44,6 @@ from celery.utils.timeutils import humanize_seconds, rate
 from . import heartbeat, loops, pidbox
 from .state import task_reserved, maybe_shutdown, revoked, reserved_requests
 
-try:
-    buffer_t = buffer
-except NameError:  # pragma: no cover
-    # Py3 does not have buffer, but we only need isinstance.
-
-    class buffer_t(object):  # noqa
-        pass
-
 __all__ = [
     'Consumer', 'Connection', 'Events', 'Heart', 'Control',
     'Tasks', 'Evloop', 'Agent', 'Mingle', 'Gossip', 'dump_body',

+ 3 - 4
celery/worker/request.py

@@ -108,13 +108,13 @@ class Request(object):
     def __init__(self, message, on_ack=noop,
                  hostname=None, eventer=None, app=None,
                  connection_errors=None, request_dict=None,
-                 task=None, on_reject=noop, **opts):
+                 task=None, on_reject=noop, body=None, **opts):
         headers = message.headers
         self.app = app
         self.message = message
         name = self.name = headers['c_type']
         self.id = headers['id']
-        self.body = message.body
+        self.body = message.body if body is None else body
         self.content_type = message.content_type
         self.content_encoding = message.content_encoding
         eta = headers.get('eta')
@@ -192,8 +192,7 @@ class Request(object):
         soft_timeout = soft_timeout or task.soft_time_limit
         result = pool.apply_async(
             trace_task_ret,
-            args=(self.name, task_id, self.request_dict,
-                  bytes(body) if isinstance(body, buffer) else body,
+            args=(self.name, task_id, self.request_dict, self.body,
                   self.content_type, self.content_encoding),
             kwargs={'hostname': self.hostname, 'is_eager': False},
             accept_callback=self.on_accepted,

+ 8 - 2
celery/worker/strategy.py

@@ -12,6 +12,7 @@ import logging
 
 from kombu.async.timer import to_timestamp
 
+from celery.five import buffer_t
 from celery.utils.log import get_logger
 from celery.utils.timeutils import timezone
 
@@ -25,7 +26,7 @@ logger = get_logger(__name__)
 
 def default(task, app, consumer,
             info=logger.info, error=logger.error, task_reserved=task_reserved,
-            to_system_tz=timezone.to_system):
+            to_system_tz=timezone.to_system, bytes=bytes, buffer_t=buffer_t):
     hostname = consumer.hostname
     eventer = consumer.event_dispatcher
     ReqV2 = Request
@@ -40,14 +41,19 @@ def default(task, app, consumer,
     bucket = consumer.task_buckets[task.name]
     handle = consumer.on_task_request
     limit_task = consumer._limit_task
+    body_can_be_buffer = consumer.pool.body_can_be_buffer
 
     def task_message_handler(message, body, ack, reject, callbacks,
                              to_timestamp=to_timestamp):
         if body is None:
+            body = message.body
+            if not body_can_be_buffer:
+                body = bytes(body) if isinstance(body, buffer_t) else body
             req = ReqV2(message,
                         on_ack=ack, on_reject=reject, app=app,
                         hostname=hostname, eventer=eventer, task=task,
-                        connection_errors=connection_errors)
+                        connection_errors=connection_errors,
+                        body=body)
         else:
             req = ReqV1(body,
                         on_ack=ack, on_reject=reject, app=app,