Browse Source

Tasks now automatically registered in the registry, also automatic naming is now always working with relative imports.

Ask Solem 15 years ago
parent
commit
9374d1a0ae

+ 0 - 6
celery/decorators.py

@@ -6,10 +6,6 @@ from inspect import getargspec
 def task(**options):
 def task(**options):
     """Make a task out of any callable.
     """Make a task out of any callable.
 
 
-        :keyword autoregister: Automatically register the task in the
-            task registry.
-
-
         Examples:
         Examples:
 
 
             >>> @task()
             >>> @task()
@@ -37,7 +33,6 @@ def task(**options):
 
 
     def _create_task_cls(fun):
     def _create_task_cls(fun):
         name = options.pop("name", None)
         name = options.pop("name", None)
-        autoregister = options.pop("autoregister", True)
 
 
         cls_name = fun.__name__
         cls_name = fun.__name__
 
 
@@ -51,7 +46,6 @@ def task(**options):
         cls_dict["__module__"] = fun.__module__
         cls_dict["__module__"] = fun.__module__
 
 
         task = type(cls_name, (Task, ), cls_dict)()
         task = type(cls_name, (Task, ), cls_dict)()
-        autoregister and tasks.register(task)
 
 
         return task
         return task
 
 

+ 52 - 17
celery/task/base.py

@@ -3,12 +3,51 @@ from celery import conf
 from celery.messaging import TaskPublisher, TaskConsumer
 from celery.messaging import TaskPublisher, TaskConsumer
 from celery.log import setup_logger
 from celery.log import setup_logger
 from celery.result import TaskSetResult, EagerResult
 from celery.result import TaskSetResult, EagerResult
-from celery.execute import apply_async, delay_task, apply
+from celery.execute import apply_async, apply
 from celery.utils import gen_unique_id, get_full_cls_name
 from celery.utils import gen_unique_id, get_full_cls_name
 from celery.registry import tasks
 from celery.registry import tasks
 from celery.serialization import pickle
 from celery.serialization import pickle
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
 from datetime import timedelta
 from datetime import timedelta
+import sys
+
+
+class TaskType(type):
+    """Metaclass for tasks.
+
+    Automatically registers the task in the task registry, except
+    if the ``abstract`` attribute is set.
+
+    If no ``name`` attribute is provieded, the name is automatically
+    set to the name of the module it was defined in, and the class name.
+
+    """
+
+    def __new__(cls, name, bases, attrs):
+        super_new = super(TaskType, cls).__new__
+        task_module = attrs["__module__"]
+
+        # Abstract class, remove the abstract attribute so the
+        # any class inheriting from this won't be abstract by default.
+        if attrs.pop("abstract", None):
+            return super_new(cls, name, bases, attrs)
+
+        # Automatically generate missing name.
+        if not attrs.get("name"):
+            task_module = sys.modules[task_module]
+            task_name = ".".join([task_module.__name__, name])
+            attrs["name"] = task_name
+
+        # Because of the way import happens (recursively)
+        # we may or may not be the first time the task tries to register
+        # with the framework. There should only be one class for each task
+        # name, so we always return the registered version.
+
+        task_name = attrs["name"]
+        if task_name not in tasks:
+            task_cls = super_new(cls, name, bases, attrs)
+            tasks.register(task_cls)
+        return tasks[task_name].__class__
 
 
 
 
 class Task(object):
 class Task(object):
@@ -26,7 +65,12 @@ class Task(object):
 
 
         *REQUIRED* All subclasses of :class:`Task` has to define the
         *REQUIRED* All subclasses of :class:`Task` has to define the
         :attr:`name` attribute. This is the name of the task, registered
         :attr:`name` attribute. This is the name of the task, registered
-        in the task registry, and passed to :func:`delay_task`.
+        in the task registry, and passed on to the workers.
+
+    .. attribute:: abstract
+
+        Abstract classes are not registered in the task registry, so they're
+        only used for making new tasks by subclassing.
 
 
     .. attribute:: type
     .. attribute:: type
 
 
@@ -108,7 +152,6 @@ class Task(object):
         ...         logger.info("Running MyTask with arg some_arg=%s" %
         ...         logger.info("Running MyTask with arg some_arg=%s" %
         ...                     some_arg))
         ...                     some_arg))
         ...         return 42
         ...         return 42
-        ... tasks.register(MyTask)
 
 
     You can delay the task using the classmethod :meth:`delay`...
     You can delay the task using the classmethod :meth:`delay`...
 
 
@@ -118,15 +161,12 @@ class Task(object):
         >>> result.result
         >>> result.result
         42
         42
 
 
-    ...or using the :func:`delay_task` function, by passing the name of
-    the task.
-
-        >>> from celery.task import delay_task
-        >>> result = delay_task(MyTask.name, some_arg="foo")
-
 
 
     """
     """
+    __metaclass__ = TaskType
+
     name = None
     name = None
+    abstract = True
     type = "regular"
     type = "regular"
     exchange = None
     exchange = None
     routing_key = None
     routing_key = None
@@ -253,7 +293,8 @@ class Task(object):
 
 
     @classmethod
     @classmethod
     def delay(cls, *args, **kwargs):
     def delay(cls, *args, **kwargs):
-        """Delay this task for execution by the ``celery`` daemon(s).
+        """Shortcut to :meth:`apply_async` but with star arguments,
+        and doesn't support the extra options.
 
 
         :param \*args: positional arguments passed on to the task.
         :param \*args: positional arguments passed on to the task.
 
 
@@ -261,8 +302,6 @@ class Task(object):
 
 
         :rtype: :class:`celery.result.AsyncResult`
         :rtype: :class:`celery.result.AsyncResult`
 
 
-        See :func:`celery.execute.delay_task`.
-
         """
         """
         return apply_async(cls, args, kwargs)
         return apply_async(cls, args, kwargs)
 
 
@@ -429,7 +468,6 @@ class ExecuteRemoteTask(Task):
         """
         """
         callable_ = pickle.loads(ser_callable)
         callable_ = pickle.loads(ser_callable)
         return callable_(*fargs, **fkwargs)
         return callable_(*fargs, **fkwargs)
-tasks.register(ExecuteRemoteTask)
 
 
 
 
 class AsynchronousMapTask(Task):
 class AsynchronousMapTask(Task):
@@ -441,7 +479,6 @@ class AsynchronousMapTask(Task):
         """The method run by ``celeryd``."""
         """The method run by ``celeryd``."""
         timeout = kwargs.get("timeout")
         timeout = kwargs.get("timeout")
         return TaskSet.map(pickle.loads(serfunc), args, timeout=timeout)
         return TaskSet.map(pickle.loads(serfunc), args, timeout=timeout)
-tasks.register(AsynchronousMapTask)
 
 
 
 
 class TaskSet(object):
 class TaskSet(object):
@@ -599,8 +636,6 @@ class PeriodicTask(Task):
     :raises NotImplementedError: if the :attr:`run_every` attribute is
     :raises NotImplementedError: if the :attr:`run_every` attribute is
         not defined.
         not defined.
 
 
-    You have to register the periodic task in the task registry.
-
     Example
     Example
 
 
         >>> from celery.task import tasks, PeriodicTask
         >>> from celery.task import tasks, PeriodicTask
@@ -612,9 +647,9 @@ class PeriodicTask(Task):
         ...     def run(self, **kwargs):
         ...     def run(self, **kwargs):
         ...         logger = self.get_logger(**kwargs)
         ...         logger = self.get_logger(**kwargs)
         ...         logger.info("Running MyPeriodicTask")
         ...         logger.info("Running MyPeriodicTask")
-        >>> tasks.register(MyPeriodicTask)
 
 
     """
     """
+    abstract = True
     run_every = timedelta(days=1)
     run_every = timedelta(days=1)
     type = "periodic"
     type = "periodic"
 
 

+ 0 - 2
celery/task/builtins.py

@@ -20,7 +20,6 @@ class DeleteExpiredTaskMetaTask(PeriodicTask):
         logger = self.get_logger(**kwargs)
         logger = self.get_logger(**kwargs)
         logger.info("Deleting expired task meta objects...")
         logger.info("Deleting expired task meta objects...")
         default_backend.cleanup()
         default_backend.cleanup()
-tasks.register(DeleteExpiredTaskMetaTask)
 
 
 
 
 class PingTask(Task):
 class PingTask(Task):
@@ -30,4 +29,3 @@ class PingTask(Task):
     def run(self, **kwargs):
     def run(self, **kwargs):
         """:returns: the string ``"pong"``."""
         """:returns: the string ``"pong"``."""
         return "pong"
         return "pong"
-tasks.register(PingTask)

+ 0 - 1
celery/task/rest.py

@@ -135,7 +135,6 @@ class RESTProxyTask(BaseTask):
         logger = self.get_logger(**kwargs)
         logger = self.get_logger(**kwargs)
         proxy = RESTProxy(url, kwargs, logger)
         proxy = RESTProxy(url, kwargs, logger)
         return proxy.execute()
         return proxy.execute()
-tasks.register(RESTProxyTask)
 
 
 
 
 def task_response(fun, *args, **kwargs):
 def task_response(fun, *args, **kwargs):

+ 0 - 1
celery/tests/test_backends/test_database.py

@@ -19,7 +19,6 @@ class MyPeriodicTask(PeriodicTask):
 
 
     def run(self, **kwargs):
     def run(self, **kwargs):
         return 42
         return 42
-registry.tasks.register(MyPeriodicTask)
 
 
 
 
 class TestDatabaseBackend(unittest.TestCase):
 class TestDatabaseBackend(unittest.TestCase):

+ 0 - 1
celery/tests/test_models.py

@@ -52,7 +52,6 @@ class TestModels(unittest.TestCase):
         self.assertFalse(m1 in TaskMeta.objects.all())
         self.assertFalse(m1 in TaskMeta.objects.all())
 
 
     def test_periodic_taskmeta(self):
     def test_periodic_taskmeta(self):
-        tasks.register(TestPeriodicTask)
         p = self.createPeriodicTaskMeta(TestPeriodicTask.name)
         p = self.createPeriodicTaskMeta(TestPeriodicTask.name)
         # check that repr works.
         # check that repr works.
         self.assertTrue(unicode(p).startswith("<PeriodicTask:"))
         self.assertTrue(unicode(p).startswith("<PeriodicTask:"))

+ 2 - 13
celery/tests/test_task.py

@@ -122,9 +122,10 @@ class TestTaskRetries(unittest.TestCase):
 class TestCeleryTasks(unittest.TestCase):
 class TestCeleryTasks(unittest.TestCase):
 
 
     def createTaskCls(self, cls_name, task_name=None):
     def createTaskCls(self, cls_name, task_name=None):
-        attrs = {}
+        attrs = {"__module__": self.__module__}
         if task_name:
         if task_name:
             attrs["name"] = task_name
             attrs["name"] = task_name
+
         cls = type(cls_name, (task.Task, ), attrs)
         cls = type(cls_name, (task.Task, ), attrs)
         cls.run = return_True
         cls.run = return_True
         return cls
         return cls
@@ -169,7 +170,6 @@ class TestCeleryTasks(unittest.TestCase):
         task_kwargs = task_data.get("kwargs", {})
         task_kwargs = task_data.get("kwargs", {})
         if test_eta:
         if test_eta:
             self.assertTrue(isinstance(task_data.get("eta"), datetime))
             self.assertTrue(isinstance(task_data.get("eta"), datetime))
-            print("TASK_KWARGS: %s" % task_kwargs)
         for arg_name, arg_value in kwargs.items():
         for arg_name, arg_value in kwargs.items():
             self.assertEquals(task_kwargs.get(arg_name), arg_value)
             self.assertEquals(task_kwargs.get(arg_name), arg_value)
 
 
@@ -193,7 +193,6 @@ class TestCeleryTasks(unittest.TestCase):
         T2 = self.createTaskCls("T2")
         T2 = self.createTaskCls("T2")
         self.assertEquals(T2().name, "celery.tests.test_task.T2")
         self.assertEquals(T2().name, "celery.tests.test_task.T2")
 
 
-        registry.tasks.register(T1)
         t1 = T1()
         t1 = T1()
         consumer = t1.get_consumer()
         consumer = t1.get_consumer()
         self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
         self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
@@ -313,13 +312,3 @@ class TestTaskApply(unittest.TestCase):
         self.assertFalse(f.is_done())
         self.assertFalse(f.is_done())
         self.assertTrue(f.traceback)
         self.assertTrue(f.traceback)
         self.assertRaises(KeyError, f.get)
         self.assertRaises(KeyError, f.get)
-
-
-class TestPeriodicTask(unittest.TestCase):
-
-    def test_interface(self):
-
-        class MyPeriodicTask(task.PeriodicTask):
-            run_every = None
-
-        self.assertRaises(NotImplementedError, MyPeriodicTask)

+ 0 - 1
celery/worker/job.py

@@ -148,7 +148,6 @@ class TaskWrapper(object):
                             "task_retries": self.retries}
                             "task_retries": self.retries}
         fun = self.task.run
         fun = self.task.run
         supported_keys = fun_takes_kwargs(fun, default_kwargs)
         supported_keys = fun_takes_kwargs(fun, default_kwargs)
-        print("TASK_NAME: %s SUP: %s" % (self.task_name, supported_keys))
         extend_with = dict((key, val) for key, val in default_kwargs.items()
         extend_with = dict((key, val) for key, val in default_kwargs.items()
                                 if key in supported_keys)
                                 if key in supported_keys)
         kwargs.update(extend_with)
         kwargs.update(extend_with)

+ 0 - 1
testproj/someapp/tasks.py

@@ -6,4 +6,3 @@ class SomeAppTask(Task):
 
 
     def run(self, **kwargs):
     def run(self, **kwargs):
         return 42
         return 42
-tasks.register(SomeAppTask)