瀏覽代碼

Refactor TaskRegistry

Ask Solem 15 年之前
父節點
當前提交
560817d20c
共有 3 個文件被更改,包括 21 次插入27 次删除
  1. 1 2
      celery/beat.py
  2. 14 20
      celery/registry.py
  3. 6 5
      celery/tests/test_registry.py

+ 1 - 2
celery/beat.py

@@ -129,8 +129,7 @@ class Scheduler(UserDict):
 
     def schedule_registry(self):
         """Add the current contents of the registry to the schedule."""
-        periodic_tasks = self.registry.get_all_periodic()
-        for name, task in self.registry.get_all_periodic().items():
+        for name, task in self.registry.periodic().items():
             if name not in self.schedule:
                 self.logger.debug(
                         "Scheduler: Adding periodic task %s to schedule" % (

+ 14 - 20
celery/registry.py

@@ -2,7 +2,7 @@
 import inspect
 from UserDict import UserDict
 
-from celery.exceptions import NotRegistered, AlreadyRegistered
+from celery.exceptions import NotRegistered
 
 
 class TaskRegistry(UserDict):
@@ -13,11 +13,20 @@ class TaskRegistry(UserDict):
     def __init__(self):
         self.data = {}
 
+    def regular(self):
+        """Get all regular task types."""
+        return self.filter_types("regular")
+
+    def periodic(self):
+        """Get all periodic task types."""
+        return self.filter_types("periodic")
+
     def register(self, task):
         """Register a task in the task registry.
 
-        The task will be automatically instantiated if it's a class
-        not an instance.
+        The task will be automatically instantiated if not already an
+        instance.
+
         """
 
         task = inspect.isclass(task) and task() or task
@@ -28,7 +37,7 @@ class TaskRegistry(UserDict):
         """Unregister task by name.
 
         :param name: name of the task to unregister, or a
-            :class:`celery.task.Task` class with a valid ``name`` attribute.
+            :class:`celery.task.base.Task` with a valid ``name`` attribute.
 
         :raises celery.exceptions.NotRegistered: if the task has not
             been registered.
@@ -38,28 +47,12 @@ class TaskRegistry(UserDict):
             name = name.name
         self.pop(name)
 
-    def get_all(self):
-        """Get all task types."""
-        return self.data
-
     def filter_types(self, type):
         """Return all tasks of a specific type."""
         return dict((task_name, task)
                         for task_name, task in self.data.items()
                             if task.type == type)
 
-    def get_all_regular(self):
-        """Get all regular task types."""
-        return self.filter_types(type="regular")
-
-    def get_all_periodic(self):
-        """Get all periodic task types."""
-        return self.filter_types(type="periodic")
-
-    def get_task(self, name):
-        """Get task by name."""
-        return self.data[name]
-
     def __getitem__(self, key):
         try:
             return UserDict.__getitem__(self, key)
@@ -72,6 +65,7 @@ class TaskRegistry(UserDict):
         except KeyError, exc:
             raise self.NotRegistered(exc)
 
+
 """
 .. data:: tasks
 

+ 6 - 5
celery/tests/test_registry.py

@@ -1,4 +1,5 @@
 import unittest
+
 from celery import registry
 from celery.task import Task, PeriodicTask
 
@@ -38,21 +39,21 @@ class TestTaskRegistry(unittest.TestCase):
         self.assertRegisterUnregisterCls(r, TestTask)
         self.assertRegisterUnregisterCls(r, TestPeriodicTask)
 
-        tasks = r.get_all()
+        tasks = r.all()
         self.assertTrue(isinstance(tasks.get(TestTask.name), TestTask))
         self.assertTrue(isinstance(tasks.get(TestPeriodicTask.name),
                                    TestPeriodicTask))
 
-        regular = r.get_all_regular()
+        regular = r.regular()
         self.assertTrue(TestTask.name in regular)
         self.assertFalse(TestPeriodicTask.name in regular)
 
-        periodic = r.get_all_periodic()
+        periodic = r.periodic()
         self.assertFalse(TestTask.name in periodic)
         self.assertTrue(TestPeriodicTask.name in periodic)
 
-        self.assertTrue(isinstance(r.get_task(TestTask.name), TestTask))
-        self.assertTrue(isinstance(r.get_task(TestPeriodicTask.name),
+        self.assertTrue(isinstance(r[TestTask.name], TestTask))
+        self.assertTrue(isinstance(r[TestPeriodicTask.name],
                                    TestPeriodicTask))
 
         r.unregister(TestTask)