Browse Source

celery.Celery is now a real class, not a factory

Ask Solem 13 years ago
parent
commit
3ddfeff221

+ 52 - 40
celery/__compat__.py

@@ -11,6 +11,8 @@ MODULE_DEPRECATED = """
 The module %s is deprecated and will be removed in a future version.
 """
 
+DEFAULT_ATTRS = frozenset(("__file__", "__path__", "__doc__", "__all__"))
+
 # im_func is no longer available in Py3.
 # instead the unbound method itself can be used.
 if sys.version_info[0] == 3:
@@ -21,6 +23,11 @@ else:
         return method.im_func
 
 
+def getappattr(path):
+    """Gets attribute from the current_app recursively,
+    e.g. getappattr("amqp.get_task_consumer")``."""
+    from celery import current_app
+    return reduce(lambda a, b: getattr(a, b), [current_app] + path)
 def _compat_task_decorator(*args, **kwargs):
     from celery import current_app
     kwargs.setdefault("accept_magic_kwargs", True)
@@ -33,8 +40,11 @@ def _compat_periodic_task_decorator(*args, **kwargs):
     return periodic_task(*args, **kwargs)
 
 
-modules = {
+COMPAT_MODULES = {
     "celery": {
+        "execute": {
+            "send_task": "send_task",
+        },
         "decorators": {
             "task": _compat_task_decorator,
             "periodic_task": _compat_periodic_task_decorator,
@@ -53,7 +63,7 @@ modules = {
             "TaskConsumer": "amqp.TaskConsumer",
             "establish_connection": "broker_connection",
             "with_connection": "with_default_connection",
-            "get_consumer_set": "amqp.get_task_consumer"
+            "get_consumer_set": "amqp.get_task_consumer",
         },
         "registry": {
             "tasks": "tasks",
@@ -62,27 +72,6 @@ modules = {
 }
 
 
-def rgetattr(obj, path):
-    return reduce(lambda a, b: getattr(a, b), [obj] + path)
-
-
-def get_compat(app, pkg, name, bases=(ModuleType, )):
-    from warnings import warn
-    from .exceptions import CDeprecationWarning
-
-    fqdn = '.'.join([pkg.__name__, name])
-    warn(CDeprecationWarning(MODULE_DEPRECATED % fqdn))
-
-    def build_attr(attr):
-        if isinstance(attr, basestring):
-            return Proxy(rgetattr, (app, attr.split('.')))
-        return attr
-    attrs = dict((name, build_attr(attr))
-                    for name, attr in modules[pkg.__name__][name].iteritems())
-    sys.modules[fqdn] = module = type(name, bases, attrs)(fqdn)
-    return module
-
-
 class class_property(object):
 
     def __init__(self, fget=None, fset=None):
@@ -115,41 +104,64 @@ class MagicModule(ModuleType):
     _compat_modules = ()
     _all_by_module = {}
     _direct = {}
+    _object_origins = {}
 
     def __getattr__(self, name):
-        origins = self._object_origins
-        if name in origins:
-            module = __import__(origins[name], None, None, [name])
-            for extra_name in self._all_by_module[module.__name__]:
-                setattr(self, extra_name, getattr(module, extra_name))
+        if name in self._object_origins:
+            module = __import__(self._object_origins[name], None, None, [name])
+            for item in self._all_by_module[module.__name__]:
+                setattr(self, item, getattr(module, item))
             return getattr(module, name)
         elif name in self._direct:
             module = __import__(self._direct[name], None, None, [name])
             setattr(self, name, module)
             return module
-        elif name in self._compat_modules:
-            setattr(self, name, get_compat(self.current_app, self, name))
         return ModuleType.__getattribute__(self, name)
 
     def __dir__(self):
-        return list(set(self.__all__
-                     + ("__file__", "__path__", "__doc__", "__all__")))
+        return list(set(self.__all__) + DEFAULT_ATTRS)
 
 
+def get_compat_module(pkg, name):
 
-def create_magic_module(name, compat_modules=(), by_module={}, direct={},
-        base=MagicModule, **attrs):
-    old_module = sys.modules[name]
+    def prepare(attr):
+        if isinstance(attr, basestring):
+            return Proxy(getappattr, (attr.split('.'), ))
+        return attr
+
+    return create_module(name, COMPAT_MODULES[pkg.__name__][name],
+                         pkg=pkg, prepare_attr=prepare)
+
+
+def create_module(name, attrs, cls_attrs=None, pkg=None,
+        bases=(MagicModule, ), prepare_attr=None):
+    fqdn = '.'.join([pkg.__name__, name]) if pkg else name
+    cls_attrs = {} if cls_attrs is None else cls_attrs
+
+    attrs = dict((attr_name, prepare_attr(attr) if prepare_attr else attr)
+                    for attr_name, attr in attrs.iteritems())
+    module = sys.modules[fqdn] = type(name, bases, cls_attrs)(fqdn)
+    module.__dict__.update(attrs)
+    return module
+
+
+def get_origins(defs):
     origins = {}
-    for module, items in by_module.iteritems():
-        for item in items:
-            origins[item] = module
+    for module, items in defs.iteritems():
+        origins.update(dict((item, module) for item in items))
+    return origins
+
+def recreate_module(name, compat_modules=(), by_module={}, direct={}, **attrs):
+    old_module = sys.modules[name]
+    origins = get_origins(by_module)
+    compat_modules = COMPAT_MODULES.get(name, ())
 
     cattrs = dict(_compat_modules=compat_modules,
                   _all_by_module=by_module, _direct=direct,
                   _object_origins=origins,
                   __all__=tuple(set(reduce(operator.add, map(tuple, [
                                 compat_modules, origins, direct, attrs])))))
-    new_module = sys.modules[name] = type(name, (base, ), cattrs)(name)
-    new_module.__dict__.update(attrs)
+    new_module = create_module(name, attrs, cls_attrs=cattrs)
+    new_module.__dict__.update(dict((mod, get_compat_module(new_module, mod))
+                                     for mod in compat_modules))
     return old_module, new_module

+ 4 - 23
celery/__init__.py

@@ -22,29 +22,13 @@ if sys.version_info < (2, 5):
         "Please use Celery versions 2.1.x or earlier.")
 
 # Lazy loading
-from types import ModuleType
-from .local import Proxy
-from .__compat__ import create_magic_module
+from .__compat__ import recreate_module
 
 
-def Celery(*args, **kwargs):
-    from .app import App
-    return App(*args, **kwargs)
-
-
-def _get_current_app():
-    from .app import current_app
-    return current_app()
-current_app = Proxy(_get_current_app)
-
-
-def bugreport():
-    return current_app.bugreport()
-
-
-old_module, new_module = create_magic_module(__name__,
-    compat_modules=("messaging", "log", "registry", "decorators"),
+old_module, new_module = recreate_module(__name__,
     by_module={
+        "celery.app": ["Celery", "bugreport"],
+        "celery.app.state": ["current_app", "current_task"],
         "celery.task.sets": ["chain", "group", "subtask"],
         "celery.task.chords": ["chord"],
     },
@@ -59,7 +43,4 @@ old_module, new_module = create_magic_module(__name__,
     __homepage__=__homepage__,
     __docformat__=__docformat__,
     VERSION=VERSION,
-    Celery=Celery,
-    current_app=current_app,
-    bugreport=bugreport,
 )

+ 11 - 8
celery/app/__init__.py

@@ -16,11 +16,11 @@ import os
 from celery.local import Proxy
 
 from . import state
-from .base import App, AppPickler  # noqa
+from .base import Celery, AppPickler  # noqa
 
 set_default_app = state.set_default_app
-current_app = state.current_app
-current_task = state.current_task
+current_app = state.get_current_app
+current_task = state.get_current_task
 default_app = Proxy(lambda: state.default_app)
 
 #: Returns the app provided or the default app if none.
@@ -34,18 +34,18 @@ app_or_default = None
 default_loader = os.environ.get("CELERY_LOADER") or "default"
 
 #: Global fallback app instance.
-set_default_app(App("default", loader=default_loader,
-                               set_as_current=False,
-                               accept_magic_kwargs=True))
+set_default_app(Celery("default", loader=default_loader,
+                                  set_as_current=False,
+                                  accept_magic_kwargs=True))
 
 
 def bugreport():
-    return current_app().bugreport()
+    return current_app.bugreport()
 
 
 def _app_or_default(app=None):
     if app is None:
-        return current_app()
+        return state.get_current_app()
     return app
 
 
@@ -78,3 +78,6 @@ if os.environ.get("CELERY_TRACE_APP"):  # pragma: no cover
     enable_trace()
 else:
     disable_trace()
+
+
+App = Celery  # XXX Compat

+ 2 - 1
celery/app/base.py

@@ -37,7 +37,7 @@ from .state import _tls
 from .utils import AppPickler, Settings, bugreport, _unpickle_app
 
 
-class App(object):
+class Celery(object):
     """Celery Application.
 
     :param main: Name of the main module if running as `__main__`.
@@ -522,3 +522,4 @@ class App(object):
         """
         self.finalize()
         return self._tasks
+App = Celery

+ 8 - 2
celery/app/state.py

@@ -2,6 +2,8 @@ from __future__ import absolute_import
 
 import threading
 
+from celery.local import Proxy
+
 default_app = None
 
 
@@ -21,9 +23,13 @@ def set_default_app(app):
     default_app = app
 
 
-def current_app():
+def get_current_app():
     return getattr(_tls, "current_app", None) or default_app
 
 
-def current_task():
+def get_current_task():
     return getattr(_tls, "current_task", None)
+
+
+current_app = Proxy(get_current_app)
+current_task = Proxy(get_current_task)

+ 2 - 2
celery/app/task/__init__.py

@@ -25,7 +25,7 @@ from celery.utils.functional import mattrgetter, maybe_list
 from celery.utils.imports import instantiate
 from celery.utils.mail import ErrorMail
 
-from celery.app.state import current_task
+from celery.app.state import get_current_task
 from celery.app.registry import _unpickle_task
 
 #: extracts options related to publishing a message from a dict.
@@ -544,7 +544,7 @@ class BaseTask(object):
                 publish.release()
 
         result = self.AsyncResult(task_id)
-        parent = current_task()
+        parent = get_current_task()
         if parent:
             parent.request.children.append(result)
         return result

+ 1 - 1
celery/backends/__init__.py

@@ -5,7 +5,7 @@ import sys
 
 from kombu.utils.url import _parse_url
 
-from celery import current_app
+from celery.app.state import current_app
 from celery.local import Proxy
 from celery.utils.imports import symbol_by_name
 from celery.utils.functional import memoize

+ 2 - 3
celery/bin/celery.py

@@ -249,9 +249,8 @@ apply = command(apply)
 class purge(Command):
 
     def run(self, *args, **kwargs):
-        app = current_app()
-        queues = len(app.amqp.queues.keys())
-        messages_removed = app.control.discard_all()
+        queues = len(current_app.amqp.queues.keys())
+        messages_removed = current_app.control.discard_all()
         if messages_removed:
             self.out("Purged %s %s from %s known task %s." % (
                 messages_removed, pluralize(messages_removed, "message"),

+ 0 - 6
celery/execute/__init__.py

@@ -1,6 +0,0 @@
-from __future__ import absolute_import
-
-from celery import current_app
-from celery.local import Proxy
-
-send_task = Proxy(lambda: current_app.send_task)

+ 1 - 1
celery/loaders/__init__.py

@@ -12,7 +12,7 @@
 """
 from __future__ import absolute_import
 
-from celery import current_app
+from celery.app.state import current_app
 from celery.utils import deprecated
 from celery.utils.imports import symbol_by_name
 

+ 2 - 5
celery/task/__init__.py

@@ -11,11 +11,8 @@
 """
 from __future__ import absolute_import
 
-import sys
-
-from celery import current_app
+from celery.app.state import current_app, current_task as current
 from celery.__compat__ import MagicModule, recreate_module
-from celery.app import current_task as _current_task
 from celery.local import Proxy
 
 
@@ -37,7 +34,7 @@ old_module, new_module = recreate_module(__name__,
     __file__=__file__,
     __path__=__path__,
     __doc__=__doc__,
-    current=Proxy(_current_task),
+    current=current,
     discard_all=Proxy(lambda: current_app.control.discard_all),
     backend_cleanup=Proxy(
         lambda: current_app.tasks["celery.backend_cleanup"]

+ 2 - 2
celery/task/sets.py

@@ -12,7 +12,7 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
-from itertools import chain
+from itertools import chain as _chain
 
 from kombu.utils import reprcall
 
@@ -109,7 +109,7 @@ class subtask(AttributeDict):
     def flatten_links(self):
         """Gives a recursive list of dependencies (unchain if you will,
         but with links intact)."""
-        return list(chain_from_iterable(chain([[self]],
+        return list(chain_from_iterable(_chain([[self]],
                 (link.flatten_links()
                     for link in maybe_list(self.options.get("link")) or []))))
 

+ 0 - 0
celery/execute/trace.py → celery/task/trace.py


+ 3 - 3
celery/tests/test_bin/test_celeryev.py

@@ -27,7 +27,7 @@ class test_EvCommand(Case):
         self.ev = celeryev.EvCommand(app=self.app)
 
     @patch("celery.events.dumper", "evdump", lambda **kw: "me dumper, you?")
-    @patch("celery.platforms", "set_process_title", proctitle)
+    @patch("celery.bin.celeryev", "set_process_title", proctitle)
     def test_run_dump(self):
         self.assertEqual(self.ev.run(dump=True), "me dumper, you?")
         self.assertIn("celeryev:dump", proctitle.last[0])
@@ -39,14 +39,14 @@ class test_EvCommand(Case):
             raise SkipTest("curses monitor requires curses")
 
         @patch("celery.events.cursesmon", "evtop", lambda **kw: "me top, you?")
-        @patch("celery.platforms", "set_process_title", proctitle)
+        @patch("celery.bin.celeryev", "set_process_title", proctitle)
         def _inner():
             self.assertEqual(self.ev.run(), "me top, you?")
             self.assertIn("celeryev:top", proctitle.last[0])
         return _inner()
 
     @patch("celery.events.snapshot", "evcam", lambda *a, **k: (a, k))
-    @patch("celery.platforms", "set_process_title", proctitle)
+    @patch("celery.bin.celeryev", "set_process_title", proctitle)
     def test_run_cam(self):
         a, kw = self.ev.run(camera="foo.bar.baz", logfile="logfile")
         self.assertEqual(a[0], "foo.bar.baz")

+ 0 - 0
celery/tests/test_task/test_execute_trace.py → celery/tests/test_task/test_task_trace.py