Преглед изворни кода

Can now override the execution strategy for each task.

Where the execution strategy is a coroutine, which is sent
the message, body and ack method for each task.
Ask Solem пре 13 година
родитељ
комит
1515d94b39

+ 3 - 3
celery/app/__init__.py

@@ -115,14 +115,14 @@ class App(base.BaseApp):
 
     def Worker(self, **kwargs):
         """Create new :class:`~celery.apps.worker.Worker` instance."""
-        return instantiate("celery.apps.worker.Worker", app=self, **kwargs)
+        return instantiate("celery.apps.worker:Worker", app=self, **kwargs)
 
     def WorkController(self, **kwargs):
-        return instantiate("celery.worker.WorkController", app=self, **kwargs)
+        return instantiate("celery.worker:WorkController", app=self, **kwargs)
 
     def Beat(self, **kwargs):
         """Create new :class:`~celery.apps.beat.Beat` instance."""
-        return instantiate("celery.apps.beat.Beat", app=self, **kwargs)
+        return instantiate("celery.apps.beat:Beat", app=self, **kwargs)
 
     def TaskSet(self, *args, **kwargs):
         """Create new :class:`~celery.task.sets.TaskSet`."""

+ 5 - 5
celery/app/base.py

@@ -72,12 +72,12 @@ class BaseApp(object):
     IS_OSX = platforms.IS_OSX
     IS_WINDOWS = platforms.IS_WINDOWS
 
-    amqp_cls = "celery.app.amqp.AMQP"
+    amqp_cls = "celery.app.amqp:AMQP"
     backend_cls = None
-    events_cls = "celery.events.Events"
-    loader_cls = "celery.loaders.app.AppLoader"
-    log_cls = "celery.log.Logging"
-    control_cls = "celery.task.control.Control"
+    events_cls = "celery.events:Events"
+    loader_cls = "celery.loaders.app:AppLoader"
+    log_cls = "celery.log:Logging"
+    control_cls = "celery.task.control:Control"
 
     _pool = None
 

+ 6 - 12
celery/app/task/__init__.py

@@ -20,7 +20,7 @@ from ...exceptions import MaxRetriesExceededError, RetryTaskError
 from ...execute.trace import TaskTrace
 from ...registry import tasks, _unpickle_task
 from ...result import EagerResult
-from ...utils import fun_takes_kwargs, mattrgetter, uuid
+from ...utils import fun_takes_kwargs, instantiate, mattrgetter, uuid
 from ...utils.mail import ErrorMail
 
 extract_exec_options = mattrgetter("queue", "routing_key",
@@ -246,6 +246,9 @@ class BaseTask(object):
     #: The type of task *(no longer used)*.
     type = "regular"
 
+    #: Execution strategy used, or the qualified name of one.
+    Strategy = "celery.worker.strategy:default"
+
     def __call__(self, *args, **kwargs):
         return self.run(*args, **kwargs)
 
@@ -256,17 +259,8 @@ class BaseTask(object):
         """The body of the task executed by workers."""
         raise NotImplementedError("Tasks must define the run method.")
 
-    def execution_strategy(self, app, logger, hostname, eventer):
-        from celery.worker.job import TaskRequest
-        create = TaskRequest.from_message
-
-        def handle_message(message, body, ack):
-            return create(message, body, ack,
-                          app=app, logger=logger,
-                          hostname=hostname, eventer=eventer)
-
-        return handle_message
-
+    def start_strategy(self, app, consumer):
+        return instantiate(self.Strategy, self, app, consumer)
 
     @classmethod
     def get_logger(self, loglevel=None, logfile=None, propagate=False,

+ 8 - 8
celery/backends/__init__.py

@@ -7,14 +7,14 @@ from ..utils import get_cls_by_name
 from ..utils.functional import memoize
 
 BACKEND_ALIASES = {
-    "amqp": "celery.backends.amqp.AMQPBackend",
-    "cache": "celery.backends.cache.CacheBackend",
-    "redis": "celery.backends.redis.RedisBackend",
-    "mongodb": "celery.backends.mongodb.MongoBackend",
-    "tyrant": "celery.backends.tyrant.TyrantBackend",
-    "database": "celery.backends.database.DatabaseBackend",
-    "cassandra": "celery.backends.cassandra.CassandraBackend",
-    "disabled": "celery.backends.base.DisabledBackend",
+    "amqp": "celery.backends.amqp:AMQPBackend",
+    "cache": "celery.backends.cache:CacheBackend",
+    "redis": "celery.backends.redis:RedisBackend",
+    "mongodb": "celery.backends.mongodb:MongoBackend",
+    "tyrant": "celery.backends.tyrant:TyrantBackend",
+    "database": "celery.backends.database:DatabaseBackend",
+    "cassandra": "celery.backends.cassandra:CassandraBackend",
+    "disabled": "celery.backends.base:DisabledBackend",
 }
 
 

+ 1 - 1
celery/bin/celerybeat.py

@@ -77,7 +77,7 @@ class BeatCommand(Command):
                 default=None,
                 action="store", dest="scheduler_cls",
                 help="Scheduler class. Default is "
-                     "celery.beat.PersistentScheduler"),
+                     "celery.beat:PersistentScheduler"),
             Option('-l', '--loglevel',
                 default=conf.CELERYBEAT_LOG_LEVEL,
                 action="store", dest="loglevel",

+ 1 - 1
celery/bin/celeryd.py

@@ -140,7 +140,7 @@ class WorkerCommand(Command):
                 default=None,
                 action="store", dest="scheduler_cls",
                 help="Scheduler class. Default is "
-                     "celery.beat.PersistentScheduler"),
+                     "celery.beat:PersistentScheduler"),
             Option('-S', '--statedb', default=conf.CELERYD_STATE_DB,
                 action="store", dest="db",
                 help="Path to the state database. The extension '.db' will "

+ 5 - 5
celery/concurrency/__init__.py

@@ -4,11 +4,11 @@ from __future__ import absolute_import
 from ..utils import get_cls_by_name
 
 ALIASES = {
-    "processes": "celery.concurrency.processes.TaskPool",
-    "eventlet": "celery.concurrency.eventlet.TaskPool",
-    "gevent": "celery.concurrency.gevent.TaskPool",
-    "threads": "celery.concurrency.threads.TaskPool",
-    "solo": "celery.concurrency.solo.TaskPool",
+    "processes": "celery.concurrency.processes:TaskPool",
+    "eventlet": "celery.concurrency.eventlet:TaskPool",
+    "gevent": "celery.concurrency.gevent:TaskPool",
+    "threads": "celery.concurrency.threads:TaskPool",
+    "solo": "celery.concurrency.solo:TaskPool",
 }
 
 

+ 4 - 0
celery/exceptions.py

@@ -84,6 +84,10 @@ class NotConfigured(UserWarning):
     """Celery has not been configured, as no config module has been found."""
 
 
+class InvalidTaskError(Exception):
+    """The task has invalid data or is not properly constructed."""
+
+
 class CPendingDeprecationWarning(PendingDeprecationWarning):
     pass
 

+ 3 - 3
celery/loaders/__init__.py

@@ -15,9 +15,9 @@ from __future__ import absolute_import
 from .. import current_app
 from ..utils import deprecated, get_cls_by_name
 
-LOADER_ALIASES = {"app": "celery.loaders.app.AppLoader",
-                  "default": "celery.loaders.default.Loader",
-                  "django": "djcelery.loaders.DjangoLoader"}
+LOADER_ALIASES = {"app": "celery.loaders.app:AppLoader",
+                  "default": "celery.loaders.default:Loader",
+                  "django": "djcelery.loaders:DjangoLoader"}
 
 
 def get_loader_cls(loader):

+ 4 - 4
celery/platforms.py

@@ -561,18 +561,18 @@ def set_process_title(progname, info=None):
     return proctitle
 
 
-def set_mp_process_title(progname, info=None, hostname=None):
+def set_mp_process_title(progname, info=None, hostname=None, rate_limit=False):
     """Set the ps name using the multiprocessing process name.
 
     Only works if :mod:`setproctitle` is installed.
 
     """
-    if _setps_bucket.can_consume(1):
+    if not rate_limit or _setps_bucket.can_consume(1):
         if hostname:
             progname = "%s@%s" % (progname, hostname.split(".")[0])
         if current_process is not None:
-            return set_process_title("%s:%s" % (progname,
-                                                current_process().name), info=info)
+            return set_process_title(
+                "%s:%s" % (progname, current_process().name), info=info)
         else:
             return set_process_title(progname, info=info)
 

+ 4 - 0
celery/tests/test_worker/__init__.py

@@ -297,6 +297,7 @@ class test_Consumer(unittest.TestCase):
                                    eta=datetime.now().isoformat())
         l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
+        l.update_strategies()
 
         l.receive_message(m.decode(), m)
         self.assertTrue(m.acknowledged)
@@ -308,6 +309,7 @@ class test_Consumer(unittest.TestCase):
                            send_events=False)
         m = create_message(Mock(), task=foo_task.name,
                            args=(1, 2), kwargs="foobarbaz", id=1)
+        l.update_strategies()
         l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
 
@@ -336,6 +338,7 @@ class test_Consumer(unittest.TestCase):
                            send_events=False)
         m = create_message(Mock(), task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
+        l.update_strategies()
 
         l.event_dispatcher = Mock()
         l.receive_message(m.decode(), m)
@@ -463,6 +466,7 @@ class test_Consumer(unittest.TestCase):
         l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
         l.event_dispatcher = Mock()
         l.enabled = False
+        l.update_strategies()
         l.receive_message(m.decode(), m)
         l.eta_schedule.stop()
 

+ 4 - 5
celery/tests/test_worker/test_worker_job.py

@@ -19,15 +19,14 @@ from celery.app import app_or_default
 from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
 from celery.task import task as task_dec
-from celery.exceptions import RetryTaskError, NotRegistered, WorkerLostError
+from celery.exceptions import (RetryTaskError, NotRegistered,
+                               WorkerLostError, InvalidTaskError)
 from celery.log import setup_logger
 from celery.result import AsyncResult
 from celery.task.base import Task
 from celery.utils import uuid
-from celery.utils.encoding import from_utf8
-from celery.worker.job import (WorkerTaskTrace, TaskRequest,
-                               InvalidTaskError, execute_and_trace,
-                               default_encode)
+from celery.utils.encoding import from_utf8, default_encode
+from celery.worker.job import WorkerTaskTrace, TaskRequest, execute_and_trace
 from celery.worker.state import revoked
 
 from celery.tests.compat import catch_warnings

+ 18 - 6
celery/utils/__init__.py

@@ -264,10 +264,16 @@ def mattrgetter(*attrs):
                                 for attr in attrs)
 
 
-def get_full_cls_name(cls):
-    """With a class, get its full module and class name."""
-    return ".".join([cls.__module__,
-                     cls.__name__])
+if sys.version_info >= (3, 3):
+
+    def qualname(obj):
+        return obj.__qualname__
+
+else:
+
+    def qualname(obj):  # noqa
+        return '.'.join([obj.__module__, obj.__name__])
+get_full_cls_name = qualname
 
 
 def fun_takes_kwargs(fun, kwlist=[]):
@@ -299,7 +305,8 @@ def fun_takes_kwargs(fun, kwlist=[]):
     return filter(partial(operator.contains, args), kwlist)
 
 
-def get_cls_by_name(name, aliases={}, imp=None, package=None, **kwargs):
+def get_cls_by_name(name, aliases={}, imp=None, package=None,
+        sep='.', **kwargs):
     """Get class by name.
 
     The name should be the full dot-separated path to the class::
@@ -311,6 +318,10 @@ def get_cls_by_name(name, aliases={}, imp=None, package=None, **kwargs):
         celery.concurrency.processes.TaskPool
                                     ^- class name
 
+    or using ':' to separate module and symbol::
+
+        celery.concurrency.processes:TaskPool
+
     If `aliases` is provided, a dict containing short name/long name
     mappings, the name is looked up in the aliases first.
 
@@ -336,7 +347,8 @@ def get_cls_by_name(name, aliases={}, imp=None, package=None, **kwargs):
         return name                                 # already a class
 
     name = aliases.get(name) or name
-    module_name, _, cls_name = name.rpartition(".")
+    sep = ':' if ':' in name else sep
+    module_name, _, cls_name = name.rpartition(sep)
     if not module_name and package:
         module_name = package
     try:

+ 140 - 0
celery/utils/coroutine.py

@@ -0,0 +1,140 @@
+from functools import wraps
+from Queue import Queue
+
+from celery.utils import cached_property
+
+
+def coroutine(fun):
+    """Decorator that turns a generator into a coroutine that is
+    started automatically, and that can send values back to the caller.
+
+    **Example coroutine that returns values to caller**::
+
+        @coroutine
+        def adder(self):
+            while 1:
+                x, y = (yield)
+                self.give(x + y)
+
+        >>> c = adder()
+
+        # call sends value and returns the result.
+        >>> c.call(4, 4)
+        8
+
+        # or you can send the value and get the result later.
+        >>> c.send(4, 4)
+        >>> c.get()
+        8
+
+
+    **Example sink (input-only coroutine)**::
+
+        @coroutine
+        def uniq():
+            seen = set()
+            while 1:
+                line = (yield)
+                if line not in seen:
+                    seen.add(line)
+                    print(line)
+
+        >>> u = uniq()
+        >>> for line in [1, 2, 2, 3]:
+        ...     u.send1(line)
+        1
+        2
+        3
+
+    **Example chaining coroutines**::
+
+        @coroutine
+        def uniq(callback):
+            seen = set()
+            while 1:
+                line = (yield)
+                if line not in seen:
+                    callback.send1(line)
+                    seen.add(line)
+
+        @coroutine
+        def uppercaser(callback):
+            while 1:
+                line = (yield)
+                callback.send1(str(line).upper())
+
+        @coroutine
+        def printer():
+            while 1:
+                line = (yield)
+                print(line)
+
+        >>> pipe = uniq(uppercaser(printer()))
+        >>> for line in file("AUTHORS").readlines():
+                pipe.send1(line)
+
+    """
+    @wraps(fun)
+    def start(*args, **kwargs):
+        return Coroutine.start_from(fun, *args, **kwargs)
+    return start
+
+
+class Coroutine(object):
+    _gen = None
+    started = False
+
+    def bind(self, generator):
+        self._gen = generator
+
+    def _next(self):
+        return self._gen.next()
+    next = __next__ = _next
+
+    def start(self):
+        if self.started:
+            raise ValueError("coroutine already started")
+        self.next()
+        self.started = True
+        return self
+
+    def send1(self, value):
+        return self._gen.send(value)
+
+    def call1(self, value, timeout=None):
+        self.send1(value)
+        return self.get(timeout=timeout)
+
+    def send(self, *args):
+        return self._gen.send(args)
+
+    def call(self, *args, **opts):
+        self.send(*args)
+        return self.get(**opts)
+
+    @classmethod
+    def start_from(cls, fun, *args, **kwargs):
+        coro = cls()
+        coro.bind(fun(coro, *args, **kwargs))
+        return coro.start()
+
+    @cached_property
+    def __output__(self):
+        return Queue()
+
+    @property
+    def give(self):
+        return self.__output__.put_nowait
+
+    @property
+    def get(self):
+        return self.__output__.get
+
+if __name__ == "__main__":
+
+    @coroutine
+    def adder(self):
+        while 1:
+            x, y = (yield)
+            self.give(x + y)
+
+    x = adder()
+    for i in xrange(10):
+        print(x.call(i, i))

+ 3 - 0
celery/utils/encoding.py

@@ -39,6 +39,9 @@ if is_py3k:
             return str_to_bytes(s)
         return s
 
+    def default_encode(obj):
+        return obj
+
     str_t = str
     bytes_t = bytes
 

+ 8 - 14
celery/worker/consumer.py

@@ -85,13 +85,12 @@ import warnings
 
 from ..app import app_or_default
 from ..datastructures import AttributeDict
-from ..exceptions import NotRegistered
+from ..exceptions import InvalidTaskError
 from ..registry import tasks
 from ..utils import noop
 from ..utils import timer2
 from ..utils.encoding import safe_repr
 from . import state
-from .job import TaskRequest, InvalidTaskError
 from .control import Panel
 from .heartbeat import Heart
 
@@ -298,13 +297,10 @@ class Consumer(object):
         self._does_info = self.logger.isEnabledFor(logging.INFO)
         self.strategies = {}
 
-    def update_strategies(self, eventer):
+    def update_strategies(self):
         S = self.strategies
         for task in tasks.itervalues():
-            S[task.name] = task.execution_strategy(self.app,
-                                                   self.logger,
-                                                   self.hostname,
-                                                   eventer)
+            S[task.name] = task.start_strategy(self.app, self)
 
     def start(self):
         """Start the consumer.
@@ -434,8 +430,8 @@ class Consumer(object):
             return
 
         try:
-            task = self.strategies[name](message, body, ack)
-        except NotRegistered, exc:
+            self.strategies[name].send(message, body, ack)
+        except KeyError, exc:
             self.logger.error(UNKNOWN_TASK_ERROR, exc, safe_repr(body),
                               exc_info=sys.exc_info())
             ack()
@@ -443,8 +439,6 @@ class Consumer(object):
             self.logger.error(INVALID_TASK_ERROR, str(exc), safe_repr(body),
                               exc_info=sys.exc_info())
             ack()
-        else:
-            self.on_task(task)
 
     def maybe_conn_error(self, fun):
         """Applies function but ignores any connection or channel
@@ -609,12 +603,12 @@ class Consumer(object):
             self.event_dispatcher.copy_buffer(prev_event_dispatcher)
             self.event_dispatcher.flush()
 
-        # reload all task's execution strategies.
-        self.update_strategies(self.event_dispatcher)
-
         # Restart heartbeat thread.
         self.restart_heartbeat()
 
+        # reload all task's execution strategies.
+        self.update_strategies()
+
         # We're back!
         self._state = RUN
 

+ 7 - 21
celery/worker/job.py

@@ -23,13 +23,13 @@ from datetime import datetime
 
 from .. import current_app
 from .. import exceptions
-from .. import platforms
 from .. import registry
 from ..app import app_or_default
 from ..datastructures import ExceptionInfo
 from ..execute.trace import TaskTrace
+from ..platforms import set_mp_process_title as setps
 from ..utils import noop, kwdict, fun_takes_kwargs, truncate_text
-from ..utils.encoding import safe_repr, safe_str, default_encoding
+from ..utils.encoding import safe_repr, safe_str, default_encode
 from ..utils.timeutils import maybe_iso8601, timezone
 from ..utils.serialization import get_pickleable_exception
 
@@ -40,21 +40,6 @@ from . import state
 WANTED_DELIVERY_INFO = ("exchange", "routing_key", "consumer_tag", )
 
 
-class InvalidTaskError(Exception):
-    """The task has invalid data or is not properly constructed."""
-    pass
-
-
-if sys.version_info >= (3, 0):
-
-    def default_encode(obj):
-        return obj
-else:
-
-    def default_encode(obj):  # noqa
-        return unicode(obj, default_encoding())
-
-
 class WorkerTaskTrace(TaskTrace):
     """Wraps the task in a jail, catches all exceptions, and
     saves the status and result of the task execution to the task
@@ -164,11 +149,11 @@ def execute_and_trace(task_name, *args, **kwargs):
 
     """
     hostname = kwargs.get("hostname")
-    platforms.set_mp_process_title("celeryd", task_name, hostname=hostname)
+    setps("celeryd", task_name, hostname, rate_limit=True)
     try:
         return WorkerTaskTrace(task_name, *args, **kwargs).execute_safe()
     finally:
-        platforms.set_mp_process_title("celeryd", "-idle-", hostname)
+        setps("celeryd", "-idle-", hostname, rate_limit=True)
 
 
 class TaskRequest(object):
@@ -291,12 +276,13 @@ class TaskRequest(object):
 
         kwargs = body.get("kwargs", {})
         if not hasattr(kwargs, "items"):
-            raise InvalidTaskError("Task keyword arguments is not a mapping.")
+            raise exceptions.InvalidTaskError(
+                    "Task keyword arguments is not a mapping.")
         try:
             task_name = body["task"]
             task_id = body["id"]
         except KeyError, exc:
-            raise InvalidTaskError(
+            raise exceptions.InvalidTaskError(
                 "Task message is missing required field %r" % (exc, ))
 
         return cls(task_name=task_name,

+ 21 - 0
celery/worker/strategy.py

@@ -0,0 +1,21 @@
+from .job import TaskRequest
+
+from ..utils.coroutine import coroutine
+
+
+def default(task, app, consumer):
+
+    @coroutine
+    def task_message_handler(self):
+        logger = consumer.logger
+        hostname = consumer.hostname
+        eventer = consumer.event_dispatcher
+        Request = TaskRequest.from_message
+        handle = consumer.on_task
+
+        while 1:
+            M, B, A = (yield)
+            handle(Request(M, B, A, app=app, logger=logger,
+                                    hostname=hostname, eventer=eventer))
+
+    return task_message_handler()