Explorar el Código

Adds @task(bare=True) which mean the worker will execute the task directly, without calling callbacks, writing results etc.

Ask Solem hace 12 años
padre
commit
54d160a0b5
Se han modificado 5 ficheros con 58 adiciones y 44 borrados
  1. 8 2
      celery/app/amqp.py
  2. 1 34
      celery/app/task.py
  3. 30 4
      celery/task/trace.py
  4. 18 3
      celery/worker/job.py
  5. 1 1
      funtests/benchmarks/bench_worker.py

+ 8 - 2
celery/app/amqp.py

@@ -29,6 +29,9 @@ QUEUE_FORMAT = """
 . %(name)s exchange:%(exchange)s(%(exchange_type)s) binding:%(routing_key)s
 """
 
+TASK_BARE = 0x004
+TASK_DEFAULT = 0
+
 
 class Queues(dict):
     """Queue name⇒ declaration mapping.
@@ -155,7 +158,7 @@ class TaskProducer(Producer):
             queue=None, now=None, retries=0, chord=None, callbacks=None,
             errbacks=None, mandatory=None, priority=None, immediate=None,
             routing_key=None, serializer=None, delivery_mode=None,
-            compression=None, **kwargs):
+            compression=None, bare=False, **kwargs):
         """Send task message."""
         # merge default and custom policy
         _rp = (dict(self.retry_policy, **retry_policy) if retry_policy
@@ -175,6 +178,8 @@ class TaskProducer(Producer):
             expires = now + timedelta(seconds=expires)
         eta = eta and eta.isoformat()
         expires = expires and expires.isoformat()
+        flags = TASK_DEFAULT
+        flags |= TASK_BARE if bare else 0
 
         body = {"task": task_name,
                 "id": task_id,
@@ -185,7 +190,8 @@ class TaskProducer(Producer):
                 "expires": expires,
                 "utc": self.utc,
                 "callbacks": callbacks,
-                "errbacks": errbacks}
+                "errbacks": errbacks,
+                "flags": flags}
         if taskset_id:
             body["taskset"] = taskset_id
         if chord:

+ 1 - 34
celery/app/task.py

@@ -40,7 +40,7 @@ extract_exec_options = mattrgetter("queue", "routing_key",
                                    "exchange", "immediate",
                                    "mandatory", "priority",
                                    "serializer", "delivery_mode",
-                                   "compression", "expires")
+                                   "compression", "expires", "bare")
 
 
 class Context(object):
@@ -789,24 +789,6 @@ class BaseTask(object):
         """
         pass
 
-    def after_return(self, status, retval, task_id, args, kwargs, einfo):
-        """Handler called after the task returns.
-
-        :param status: Current task state.
-        :param retval: Task return value/exception.
-        :param task_id: Unique id of the task.
-        :param args: Original arguments for the task that failed.
-        :param kwargs: Original keyword arguments for the task
-                       that failed.
-
-        :keyword einfo: :class:`~celery.datastructures.ExceptionInfo`
-                        instance, containing the traceback (if any).
-
-        The return value of this handler is ignored.
-
-        """
-        pass
-
     def on_failure(self, exc, task_id, args, kwargs, einfo):
         """Error handler.
 
@@ -830,21 +812,6 @@ class BaseTask(object):
         if self.send_error_emails and not self.disable_error_emails:
             self.ErrorMail(self, **kwargs).send(context, exc)
 
-    def on_success(self, retval, task_id, args, kwargs):
-        """Success handler.
-
-        Run by the worker if the task executes successfully.
-
-        :param retval: The return value of the task.
-        :param task_id: Unique id of the executed task.
-        :param args: Original arguments for the executed task.
-        :param kwargs: Original keyword arguments for the executed task.
-
-        The return value of this handler is ignored.
-
-        """
-        pass
-
     def execute(self, request, pool, loglevel, logfile, **kwargs):
         """The method the worker calls to execute the task.
 

+ 30 - 4
celery/task/trace.py

@@ -128,6 +128,30 @@ class TraceInfo(object):
             del(tb)
 
 
+def execute_bare(task, uuid, args, kwargs, request=None):
+    R = I = None
+    kwargs = kwdict(kwargs)
+    try:
+        try:
+            R = retval = task(*args, **kwargs)
+            state = SUCCESS
+        except Exception, exc:
+            I = Info(FAILURE, exc)
+            state, retval = I.state, I.retval
+            R = I.handle_error_state(task)
+        except BaseException, exc:
+            raise
+        except:  # pragma: no cover
+            # For Python2.5 where raising strings are still allowed
+            # (but deprecated)
+            I = Info(FAILURE, None)
+            state, retval = I.state, I.retval
+            R = I.handle_error_state(task, eager=eager)
+    except Exception, exc:
+        R = report_internal_error(task, exc)
+    return R
+
+
 def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         Info=TraceInfo, eager=False, propagate=False):
     # If the task doesn't define a custom __call__ method
@@ -146,8 +170,8 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
     loader_task_init = loader.on_task_init
     loader_cleanup = loader.on_process_cleanup
 
-    task_on_success = task.on_success
-    task_after_return = task.after_return
+    task_on_success = getattr(task, "on_success", None)
+    task_after_return = getattr(task, "after_return", None)
 
     store_result = backend.store_result
     backend_cleanup = backend.process_cleanup
@@ -215,14 +239,16 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     # stored, so that result.children is populated.
                     [subtask(callback).apply_async((retval, ))
                         for callback in task_request.callbacks or []]
-                    task_on_success(retval, uuid, args, kwargs)
+                    if task_on_success:
+                        task_on_success(retval, uuid, args, kwargs)
                     if success_receivers:
                         send_success(sender=task, result=retval)
 
                 # -* POST *-
                 if task_request.chord:
                     on_chord_part_return(task)
-                task_after_return(state, retval, uuid, args, kwargs, None)
+                if task_after_return:
+                    task_after_return(state, retval, uuid, args, kwargs, None)
                 if postrun_receivers:
                     send_postrun(sender=task, task_id=uuid, task=task,
                                  args=args, kwargs=kwargs,

+ 18 - 3
celery/worker/job.py

@@ -26,7 +26,12 @@ from celery import current_app
 from celery import exceptions
 from celery.app import app_or_default
 from celery.datastructures import ExceptionInfo
-from celery.task.trace import build_tracer, trace_task, report_internal_error
+from celery.task.trace import (
+    build_tracer,
+    trace_task,
+    report_internal_error,
+    execute_bare,
+)
 from celery.platforms import set_mp_process_title as setps
 from celery.utils import fun_takes_kwargs
 from celery.utils.functional import noop
@@ -78,7 +83,7 @@ class Request(object):
                  "on_ack", "delivery_info", "hostname",
                  "callbacks", "errbacks",
                  "eventer", "connection_errors",
-                 "task", "eta", "expires",
+                 "task", "eta", "expires", "bare",
                  "request_dict", "acknowledged", "success_msg",
                  "error_msg", "retry_msg", "time_start", "worker_pid",
                  "_already_revoked", "_terminate_on_ack", "_tzlocal")
@@ -120,6 +125,7 @@ class Request(object):
         eta = body.get("eta")
         expires = body.get("expires")
         utc = body.get("utc", False)
+        self.flags = body.get("flags", False)
         self.on_ack = on_ack
         self.hostname = hostname or socket.gethostname()
         self.eventer = eventer
@@ -194,10 +200,19 @@ class Request(object):
         :keyword logfile: The logfile used by the task.
 
         """
+        task = self.task
+        if self.flags & 0x004:
+            return pool.apply_async(execute_bare,
+                    args=(self.task, self.id, self.args, self.kwargs),
+                    accept_callback=self.on_accepted,
+                    timeout_callback=self.on_timeout,
+                    callback=self.on_success,
+                    error_callback=self.on_failure,
+                    soft_timeout=task.soft_time_limit,
+                    timeout=task.time_limit)
         if self.revoked():
             return
 
-        task = self.task
         hostname = self.hostname
         kwargs = self.kwargs
         if self.task.accept_magic_kwargs:

+ 1 - 1
funtests/benchmarks/bench_worker.py

@@ -46,7 +46,7 @@ def tdiff(then):
     return time.time() - then
 
 
-@celery.task(cur=0, time_start=None, queue="bench.worker")
+@celery.task(cur=0, time_start=None, queue="bench.worker", bare=True)
 def it(_, n):
     i = it.cur  # use internal counter, as ordering can be skewed
                 # by previous runs, or the broker.