Browse Source

Now supports task message protocol 1.0 again

Ask Solem 11 years ago
parent
commit
f838ace597

+ 1 - 1
celery/app/amqp.py

@@ -296,7 +296,7 @@ class AMQP(object):
         return task_message(
             headers={
                 'lang': 'py',
-                'c_type': name,
+                'task': name,
                 'id': task_id,
                 'eta': eta,
                 'expires': expires,

+ 0 - 2
celery/app/base.py

@@ -9,12 +9,10 @@
 from __future__ import absolute_import
 
 import os
-import sys
 import threading
 import warnings
 
 from collections import defaultdict, deque
-from contextlib import contextmanager
 from copy import deepcopy
 from operator import attrgetter
 

+ 9 - 4
celery/app/trace.py

@@ -478,11 +478,16 @@ def _fast_trace_task_v1(task, uuid, args, kwargs, request={}, _loc=_localized):
 
 def _fast_trace_task(task, uuid, request, body, content_type,
                      content_encoding, loads=loads_message, _loc=_localized,
-                     **extra_request):
+                     hostname=None, **_):
     tasks, accept = _loc
-    args, kwargs = loads(body, content_type, content_encoding,
-                         accept=accept)
-    request.update(args=args, kwargs=kwargs, **extra_request)
+    if content_type:
+        args, kwargs = loads(body, content_type, content_encoding,
+                             accept=accept)
+    else:
+        args, kwargs = body
+    request.update({
+        'args': args, 'kwargs': kwargs, 'hostname': hostname,
+    })
     R, I, T, Rstr = tasks[task].__trace__(
         uuid, args, kwargs, request,
     )

+ 1 - 1
celery/tests/worker/test_loops.py

@@ -155,7 +155,7 @@ class test_asynloop(AppCase):
 
     def test_on_task_message_missing_name(self):
         x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
-        msg.headers.pop('c_type')
+        msg.headers.pop('task')
         on_task(msg)
         x.on_unknown_message.assert_called_with(((2, 2), {}), msg)
 

+ 3 - 1
celery/utils/objects.py

@@ -74,7 +74,9 @@ class FallbackContext(object):
     def __enter__(self):
         if self.provided is not None:
             return self.provided
-        context = self._context = self.fallback(*self.fb_args, **self.fb_kwargs).__enter__()
+        context = self._context = self.fallback(
+            *self.fb_args, **self.fb_kwargs
+        ).__enter__()
         return context
 
     def __exit__(self, *exc_info):

+ 17 - 9
celery/worker/consumer.py

@@ -445,24 +445,32 @@ class Consumer(object):
             # will defer deserializing the message body to the pool.
             payload = None
             try:
-                type_ = message.headers['c_type']   # protocol v2
+                type_ = message.headers['task']                # protocol v2
             except TypeError:
                 return on_unknown_message(None, message)
             except KeyError:
                 payload = message.payload
                 try:
-                    type_ = payload['task']         # protocol v1
+                    type_, payload = payload['task'], payload  # protocol v1
                 except (TypeError, KeyError):
                     return on_unknown_message(payload, message)
             try:
-                strategies[type_](
-                    message, None,
-                    message.ack_log_error, message.reject_log_error, callbacks,
-                )
+                strategy = strategies[type_]
             except KeyError as exc:
-                on_unknown_task(payload, message, exc)
-            except InvalidTaskError as exc:
-                on_invalid_task(payload, message, exc)
+                return on_unknown_task(payload, message, exc)
+            else:
+                try:
+                    strategy(
+                        message, payload, message.ack_log_error,
+                        message.reject_log_error, callbacks,
+                    )
+                except InvalidTaskError as exc:
+                    return on_invalid_task(payload, message, exc)
+                except MemoryError:
+                    raise
+                except Exception as exc:
+                    # XXX handle as internal error?
+                    return on_invalid_task(payload, message, exc)
 
         return on_task_received
 

+ 93 - 50
celery/worker/request.py

@@ -44,9 +44,6 @@ debug, info, warn, error = (logger.debug, logger.info,
 _does_info = False
 _does_debug = False
 
-#: Max length of result representation
-RESULT_MAXLEN = 128
-
 
 def __optimize__():
     # this is also called by celery.app.trace.setup_worker_optimizations
@@ -65,75 +62,63 @@ task_accepted = state.task_accepted
 task_ready = state.task_ready
 revoked_tasks = state.revoked
 
-#: Use when no message object passed to :class:`Request`.
-DEFAULT_FIELDS = {
-    'headers': None,
-    'reply_to': None,
-    'correlation_id': None,
-    'delivery_info': {
-        'exchange': None,
-        'routing_key': None,
-        'priority': 0,
-        'redelivered': False,
-    },
-}
-
-
-class RequestV1(object):
-    if not IS_PYPY:
-        __slots__ = (
-            'app', 'message', 'name', 'id', 'root_id', 'parent_id',
-            'on_ack', 'hostname', 'eventer', 'connection_errors', 'task',
-            'eta', 'expires', 'request_dict', 'acknowledged', 'on_reject',
-            'utc', 'time_start', 'worker_pid', '_already_revoked',
-            '_terminate_on_ack', '_apply_result',
-            '_tzlocal', '__weakref__', '__dict__',
-        )
-
 
 class Request(object):
     """A request for task execution."""
-    utc = True
+    acknowledged = False
+    time_start = None
+    worker_pid = None
+    timeouts = (None, None)
+    _already_revoked = False
+    _terminate_on_ack = None
+    _apply_result = None
+    _tzlocal = None
+
     if not IS_PYPY:  # pragma: no cover
         __slots__ = (
             'app', 'name', 'id', 'on_ack', 'body',
             'hostname', 'eventer', 'connection_errors', 'task', 'eta',
-            'expires', 'request_dict', 'acknowledged', 'on_reject',
-            'utc', 'time_start', 'worker_pid', 'timeouts',
+            'expires', 'request_dict', 'on_reject', 'utc',
             'content_type', 'content_encoding',
-            '_already_revoked', '_terminate_on_ack', '_apply_result',
-            '_tzlocal', '__weakref__', '__dict__',
+            '__weakref__', '__dict__',
         )
 
     def __init__(self, message, on_ack=noop,
                  hostname=None, eventer=None, app=None,
                  connection_errors=None, request_dict=None,
-                 task=None, on_reject=noop, body=None, **opts):
-        headers = message.headers
+                 task=None, on_reject=noop, body=None,
+                 headers=None, decoded=False, utc=True,
+                 maybe_make_aware=maybe_make_aware,
+                 maybe_iso8601=maybe_iso8601, **opts):
+        if headers is None:
+            headers = message.headers
+        if body is None:
+            body = message.body
         self.app = app
         self.message = message
-        name = self.name = headers['c_type']
+        self.body = body
+        self.utc = utc
+        if decoded:
+            self.content_type = self.content_encoding = None
+        else:
+            self.content_type, self.content_encoding = (
+                message.content_type, message.content_encoding,
+                )
+
+        name = self.name = headers['task']
         self.id = headers['id']
-        self.body = message.body if body is None else body
-        self.content_type = message.content_type
-        self.content_encoding = message.content_encoding
-        eta = headers.get('eta')
-        expires = headers.get('expires')
-        self.timeouts = (headers['timeouts'] if 'timeouts' in headers
-                         else (None, None))
+        if 'timeouts' in headers:
+            self.timeouts = headers['timeouts']
         self.on_ack = on_ack
         self.on_reject = on_reject
         self.hostname = hostname or socket.gethostname()
         self.eventer = eventer
         self.connection_errors = connection_errors or ()
         self.task = task or self.app.tasks[name]
-        self.acknowledged = self._already_revoked = False
-        self.time_start = self.worker_pid = self._terminate_on_ack = None
-        self._apply_result = None
-        self._tzlocal = None
 
         # timezone means the message is timezone-aware, and the only timezone
         # supported at this point is UTC.
+        eta = headers.get('eta')
         if eta is not None:
             try:
                 eta = maybe_iso8601(eta)
@@ -143,6 +128,8 @@ class Request(object):
             self.eta = maybe_make_aware(eta, self.tzlocal)
         else:
             self.eta = None
+
+        expires = headers.get('expires')
         if expires is not None:
             try:
                 expires = maybe_iso8601(expires)
@@ -186,15 +173,13 @@ class Request(object):
         if self.revoked():
             raise TaskRevokedError(task_id)
 
-        body = self.body
         timeout, soft_timeout = self.timeouts
         timeout = timeout or task.time_limit
         soft_timeout = soft_timeout or task.soft_time_limit
         result = pool.apply_async(
             trace_task_ret,
             args=(self.name, task_id, self.request_dict, self.body,
-                  self.content_type, self.content_encoding),
-            kwargs={'hostname': self.hostname, 'is_eager': False},
+                  self.content_type, self.content_encoding, self.hostname),
             accept_callback=self.on_accepted,
             timeout_callback=self.on_timeout,
             callback=self.on_success,
@@ -449,3 +434,61 @@ class Request(object):
     def correlation_id(self):
         # used similarly to reply_to
         return self.request_dict['correlation_id']
+
+
+def create_request_cls(base, task, pool, hostname, eventer,
+                       ref=ref, revoked_tasks=revoked_tasks,
+                       task_ready=task_ready):
+    from celery.app.trace import trace_task_ret as trace
+    default_time_limit = task.time_limit
+    default_soft_time_limit = task.soft_time_limit
+    apply_async = pool.apply_async
+    acks_late = task.acks_late
+    std_kwargs = {'hostname': hostname, 'is_eager': False}
+    events = eventer and eventer.enabled
+
+    class Request(base):
+
+        def execute_using_pool(self, pool, **kwargs):
+            task_id = self.id
+            if (self.expires or task_id in revoked_tasks) and self.revoked():
+                raise TaskRevokedError(task_id)
+
+            timeout, soft_timeout = self.timeouts
+            timeout = timeout or default_time_limit
+            soft_timeout = soft_timeout or default_soft_time_limit
+            result = apply_async(
+                trace,
+                args=(self.name, task_id, self.request_dict, self.body,
+                      self.content_type, self.content_encoding),
+                kwargs=std_kwargs,
+                accept_callback=self.on_accepted,
+                timeout_callback=self.on_timeout,
+                callback=self.on_success,
+                error_callback=self.on_failure,
+                soft_timeout=soft_timeout,
+                timeout=timeout,
+                correlation_id=task_id,
+            )
+            # cannot create weakref to None
+            self._apply_result = ref(result) if result is not None else result
+            return result
+
+        def on_success(self, failed__retval__runtime, **kwargs):
+            failed, retval, runtime = failed__retval__runtime
+            if failed:
+                if isinstance(retval.exception, (
+                        SystemExit, KeyboardInterrupt)):
+                    raise retval.exception
+                return self.on_failure(retval, return_ok=True)
+            task_ready(self)
+
+            if acks_late:
+                self.acknowledge()
+
+            if events:
+                self.send_event(
+                    'task-succeeded', result=retval, runtime=runtime,
+                )
+
+    return Request

+ 37 - 16
celery/worker/strategy.py

@@ -12,11 +12,12 @@ import logging
 
 from kombu.async.timer import to_timestamp
 
+from celery.exceptions import InvalidTaskError
 from celery.five import buffer_t
 from celery.utils.log import get_logger
 from celery.utils.timeutils import timezone
 
-from .request import Request, RequestV1
+from .request import Request, create_request_cls
 from .state import task_reserved
 
 __all__ = ['default']
@@ -24,13 +25,31 @@ __all__ = ['default']
 logger = get_logger(__name__)
 
 
+def proto1_to_proto2(message, body):
+    """Converts Task message protocol 1 arguments to protocol 2.
+
+    Returns tuple of ``(body, headers, already_decoded_status, utc)``
+
+    """
+    try:
+        args, kwargs = body['args'], body['kwargs']
+        kwargs.items
+    except KeyError:
+        raise InvalidTaskError('Message does not have args/kwargs')
+    except AttributeError:
+        raise InvalidTaskError(
+            'Task keyword arguments must be a mapping',
+        )
+    body['headers'] = message.headers
+    return (args, kwargs), body, True, body.get('utc', True)
+
+
 def default(task, app, consumer,
             info=logger.info, error=logger.error, task_reserved=task_reserved,
-            to_system_tz=timezone.to_system, bytes=bytes, buffer_t=buffer_t):
+            to_system_tz=timezone.to_system, bytes=bytes, buffer_t=buffer_t,
+            proto1_to_proto2=proto1_to_proto2):
     hostname = consumer.hostname
     eventer = consumer.event_dispatcher
-    ReqV2 = Request
-    ReqV1 = RequestV1
     connection_errors = consumer.connection_errors
     _does_info = logger.isEnabledFor(logging.INFO)
     events = eventer and eventer.enabled
@@ -42,25 +61,27 @@ def default(task, app, consumer,
     handle = consumer.on_task_request
     limit_task = consumer._limit_task
     body_can_be_buffer = consumer.pool.body_can_be_buffer
+    Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
+
+    revoked_tasks = consumer.controller.state.revoked
 
     def task_message_handler(message, body, ack, reject, callbacks,
                              to_timestamp=to_timestamp):
         if body is None:
-            body = message.body
+            body, headers, decoded, utc = (
+                message.body, message.headers, False, True,
+            )
             if not body_can_be_buffer:
                 body = bytes(body) if isinstance(body, buffer_t) else body
-            req = ReqV2(message,
-                        on_ack=ack, on_reject=reject, app=app,
-                        hostname=hostname, eventer=eventer, task=task,
-                        connection_errors=connection_errors,
-                        body=body)
         else:
-            req = ReqV1(body,
-                        on_ack=ack, on_reject=reject, app=app,
-                        hostname=hostname, eventer=eventer, task=task,
-                        connection_errors=connection_errors,
-                        message=message)
-        if req.revoked():
+            body, headers, decoded, utc = proto1_to_proto2(message, body)
+        req = Req(
+            message,
+            on_ack=ack, on_reject=reject, app=app, hostname=hostname,
+            eventer=eventer, task=task, connection_errors=connection_errors,
+            body=body, headers=headers, decoded=decoded, utc=utc,
+        )
+        if (req.expires or req.id in revoked_tasks) and req.revoked():
             return
 
         if _does_info:

+ 6 - 6
docs/internals/protov2.rst

@@ -28,9 +28,9 @@ Notes
 
     - Java/C, etc. can use a thrift/protobuf document as the body
 
-- Dispatches to actor based on ``c_type``, ``c_meth`` headers
+- Dispatches to actor based on ``task``, ``meth`` headers
 
-    ``c_meth`` is unused by python, but may be used in the future
+    ``meth`` is unused by python, but may be used in the future
     to specify class+method pairs.
 
 - Chain gains a dedicated field.
@@ -52,7 +52,7 @@ Notes
 
 - ``root_id`` and ``parent_id`` fields helps keep track of workflows.
 
-- ``c_shadow`` lets you specify a different name for logs, monitors
+- ``shadow`` lets you specify a different name for logs, monitors
   can be used for e.g. meta tasks that calls any function::
 
     from celery.utils.imports import qualname
@@ -108,8 +108,8 @@ Definition
         'parent_id': (uuid)parent_id,
 
         # optional
-        'c_meth': (string)unused,
-        'c_shadow': (string)replace_name,
+        'meth': (string)unused,
+        'shadow': (string)replace_name,
         'eta': (iso8601)eta,
         'expires'; (iso8601)expires,
         'callbacks': (list)Signature,
@@ -135,7 +135,7 @@ Example
         message=json.dumps([[2, 2], {}]),
         application_headers={
             'lang': 'py',
-            'c_type': 'proj.tasks.add',
+            'task': 'proj.tasks.add',
             'chain': [
                 # reversed chain list
                 {'task': 'proj.tasks.add', 'args': (8, )},

+ 4 - 0
funtests/stress/stress/templates.py

@@ -125,3 +125,7 @@ class sqs(default):
     BROKER_TRANSPORT_OPTIONS = {
         'region': os.environ.get('AWS_REGION', 'us-east-1'),
     }
+
+@template()
+class proto1(default):
+    CELERY_TASK_PROTOCOL = 1