Browse Source

[Taskv2] callbacks, errbacks, chord and chain moved to message body to avoid header limitations.

Ask Solem 11 years ago
parent
commit
2a60655140

+ 8 - 5
celery/app/amqp.py

@@ -300,11 +300,7 @@ class AMQP(object):
                 'id': task_id,
                 'eta': eta,
                 'expires': expires,
-                'callbacks': callbacks,
-                'errbacks': errbacks,
-                'chain': None,  # TODO
                 'group': group_id,
-                'chord': chord,
                 'retries': retries,
                 'timelimit': [time_limit, soft_time_limit],
                 'root_id': root_id,
@@ -314,7 +310,14 @@ class AMQP(object):
                 'correlation_id': task_id,
                 'reply_to': reply_to or '',
             },
-            body=(args, kwargs),
+            body=(
+                args, kwargs, {
+                    'callbacks': callbacks,
+                    'errbacks': errbacks,
+                    'chain': None,  # TODO
+                    'chord': chord,
+                },
+            ),
             sent_event={
                 'uuid': task_id,
                 'root': root_id,

+ 25 - 18
celery/app/trace.py

@@ -466,36 +466,40 @@ def _trace_task_ret(name, uuid, request, body, content_type,
                     content_encoding, loads=loads_message, app=None,
                     **extra_request):
     app = app or current_app._get_current_object()
-    accept = prepare_accept_content(app.conf.CELERY_ACCEPT_CONTENT)
-    args, kwargs = loads(body, content_type, content_encoding, accept=accept)
-    request.update(args=args, kwargs=kwargs, **extra_request)
+    embed = None
+    if content_type:
+        accept = prepare_accept_content(app.conf.CELERY_ACCEPT_CONTENT)
+        args, kwargs, embed = loads(
+            body, content_type, content_encoding, accept=accept,
+        )
+    else:
+        args, kwargs = body
+    hostname = socket.gethostname()
+    request.update({
+        'args': args, 'kwargs': kwargs,
+        'hostname': hostname, 'is_eager': False,
+    }, **embed or {})
     R, I, T, Rstr = trace_task(app.tasks[name],
                                uuid, args, kwargs, request, app=app)
     return (1, R, T) if I else (0, Rstr, T)
 trace_task_ret = _trace_task_ret
 
 
-def _fast_trace_task_v1(task, uuid, args, kwargs, request={}, _loc=_localized):
-    # setup_worker_optimizations will point trace_task_ret to here,
-    # so this is the function used in the worker.
-    tasks, _ = _loc
-    R, I, T, Rstr = tasks[task].__trace__(uuid, args, kwargs, request)[0]
-    # exception instance if error, else result text
-    return (1, R, T) if I else (0, Rstr, T)
-
-
 def _fast_trace_task(task, uuid, request, body, content_type,
                      content_encoding, loads=loads_message, _loc=_localized,
                      hostname=None, **_):
-    tasks, accept = _loc
+    embed = None
+    tasks, accept, hostname = _loc
     if content_type:
-        args, kwargs = loads(body, content_type, content_encoding,
-                             accept=accept)
+        args, kwargs, embed = loads(
+            body, content_type, content_encoding, accept=accept,
+        )
     else:
         args, kwargs = body
     request.update({
-        'args': args, 'kwargs': kwargs, 'hostname': hostname,
-    })
+        'args': args, 'kwargs': kwargs,
+        'hostname': hostname, 'is_eager': False,
+    }, **embed or {})
     R, I, T, Rstr = tasks[task].__trace__(
         uuid, args, kwargs, request,
     )
@@ -515,9 +519,11 @@ def report_internal_error(task, exc):
         del(_tb)
 
 
-def setup_worker_optimizations(app):
+def setup_worker_optimizations(app, hostname=None):
     global trace_task_ret
 
+    hostname = hostname or socket.gethostname()
+
     # make sure custom Task.__call__ methods that call super
     # will not mess up the request/task stack.
     _install_stack_protection()
@@ -538,6 +544,7 @@ def setup_worker_optimizations(app):
     _localized[:] = [
         app._tasks,
         prepare_accept_content(app.conf.CELERY_ACCEPT_CONTENT),
+        hostname,
     ]
 
     trace_task_ret = _fast_trace_task

+ 2 - 2
celery/apps/worker.py

@@ -112,7 +112,7 @@ EXTRA_INFO_FMT = """
 class Worker(WorkController):
 
     def on_before_init(self, **kwargs):
-        trace.setup_worker_optimizations(self.app)
+        trace.setup_worker_optimizations(self.app, self.hostname)
 
         # this signal can be used to set up configuration for
         # workers by name.
@@ -144,7 +144,7 @@ class Worker(WorkController):
         self._custom_logging = self.setup_logging()
         # apply task execution optimizations
         # -- This will finalize the app!
-        trace.setup_worker_optimizations(self.app)
+        trace.setup_worker_optimizations(self.app, self.hostname)
 
     def on_start(self):
         if not self._custom_logging and self.redirect_stdouts:

+ 1 - 1
celery/concurrency/prefork.py

@@ -68,7 +68,7 @@ def process_initializer(app, hostname):
                   hostname=hostname)
     if os.environ.get('FORKED_BY_MULTIPROCESSING'):
         # pool did execv after fork
-        trace.setup_worker_optimizations(app)
+        trace.setup_worker_optimizations(app, hostname)
     else:
         app.set_current()
         set_default_app(app)

+ 1 - 3
celery/worker/request.py

@@ -179,7 +179,7 @@ class Request(object):
         result = pool.apply_async(
             trace_task_ret,
             args=(self.name, task_id, self.request_dict, self.body,
-                  self.content_type, self.content_encoding, self.hostname),
+                  self.content_type, self.content_encoding),
             accept_callback=self.on_accepted,
             timeout_callback=self.on_timeout,
             callback=self.on_success,
@@ -444,7 +444,6 @@ def create_request_cls(base, task, pool, hostname, eventer,
     default_soft_time_limit = task.soft_time_limit
     apply_async = pool.apply_async
     acks_late = task.acks_late
-    std_kwargs = {'hostname': hostname, 'is_eager': False}
     events = eventer and eventer.enabled
 
     class Request(base):
@@ -461,7 +460,6 @@ def create_request_cls(base, task, pool, hostname, eventer,
                 trace,
                 args=(self.name, task_id, self.request_dict, self.body,
                       self.content_type, self.content_encoding),
-                kwargs=std_kwargs,
                 accept_callback=self.on_accepted,
                 timeout_callback=self.on_timeout,
                 callback=self.on_success,

+ 29 - 39
docs/internals/protov2.rst

@@ -21,7 +21,7 @@ Notes
 
 - Body is only for language specific data.
 
-    - Python stores args/kwargs in body.
+    - Python stores args/kwargs and embedded signatures in body.
 
     - If a message uses raw encoding then the raw data
       will be passed as a single argument to the function.
@@ -43,7 +43,7 @@ Notes
     when sending the next message::
 
         execute_task(message)
-        chain = message.headers['chain']
+        chain = embed['chain']
         if chain:
             sig = maybe_signature(chain.pop())
             sig.apply_async(chain=chain)
@@ -74,16 +74,6 @@ Notes
         return fun(*args, **kwargs)
 
 
-
-Undecided
----------
-
-- May consider moving callbacks/errbacks/chain into body.
-
-    Will huge lists in headers cause overhead?
-    The downside of keeping them in the body is that intermediates
-    won't be able to introspect these values.
-
 Definition
 ==========
 
@@ -93,35 +83,40 @@ Definition
     # 'class' header existing means protocol is v2
 
     properties = {
-        'correlation_id': (uuid)task_id,
-        'content_type': (string)mime,
-        'content_encoding': (string)encoding,
+        'correlation_id': uuid task_id,
+        'content_type': string mimetype,
+        'content_encoding': string encoding,
 
         # optional
-        'reply_to': (string)queue_or_url,
+        'reply_to': string queue_or_url,
     }
     headers = {
-        'lang': (string)'py'
-        'task': (string)task,
-        'id': (uuid)task_id,
-        'root_id': (uuid)root_id,
-        'parent_id': (uuid)parent_id,
+        'lang': string 'py',
+        'task': string task,
+        'id': uuid task_id,
+        'root_id': uuid root_id,
+        'parent_id': uuid parent_id,
+        'group': uuid group_id,
 
         # optional
-        'meth': (string)unused,
-        'shadow': (string)replace_name,
-        'eta': (iso8601)eta,
-        'expires'; (iso8601)expires,
-        'callbacks': (list)Signature,
-        'errbacks': (list)Signature,
-        'chain': (list)Signature,  # non-recursive, reversed list of signatures
-        'group': (uuid)group_id,
-        'chord': (uuid)chord_id,
-        'retries': (int)retries,
-        'timelimit': (tuple)(soft, hard),
+        'meth': string method_name,
+        'shadow': string alias_name,
+        'eta':  iso8601 eta,
+        'expires': iso8601 expires,
+        'retries': int retries,
+        'timelimit': (soft, hard),
     }
 
-    body = (args, kwargs)
+    body = (
+        object[] args,
+        Mapping kwargs,
+        Mapping embed {
+            'callbacks': Signature[] callbacks,
+            'errbacks': Signature[] errbacks,
+            'chain': Signature[] chain,
+            'chord': Signature chord_callback,
+        }
+    )
 
 Example
 =======
@@ -132,15 +127,10 @@ Example
 
     task_id = uuid()
     basic_publish(
-        message=json.dumps([[2, 2], {}]),
+        message=json.dumps(([2, 2], {}, None)),
         application_headers={
             'lang': 'py',
             'task': 'proj.tasks.add',
-            'chain': [
-                # reversed chain list
-                {'task': 'proj.tasks.add', 'args': (8, )},
-                {'task': 'proj.tasks.add', 'args': (4, )},
-            ]
         },
         properties={
             'correlation_id': task_id,